diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index a5070c937c6..b827468d719 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -9,5 +9,5 @@ jobs: name: "Gradle wrapper validation" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: gradle/wrapper-validation-action@v1 diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml index cb648e7fc73..ba37d4db4be 100644 --- a/.github/workflows/lock.yml +++ b/.github/workflows/lock.yml @@ -13,8 +13,8 @@ jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v2 + - uses: dessant/lock-threads@v3 with: github-token: ${{ github.token }} - issue-lock-inactive-days: 90 - pr-lock-inactive-days: 90 + issue-inactive-days: 90 + pr-inactive-days: 90 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 609d0841494..4788ebfc7f0 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -21,14 +21,14 @@ jobs: fail-fast: false # Should swap to true if we grow a large matrix steps: - - uses: actions/checkout@v2 - - uses: actions/setup-java@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-java@v3 with: java-version: ${{ matrix.jre }} distribution: 'temurin' - name: Gradle cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | ~/.gradle/caches @@ -37,7 +37,7 @@ jobs: restore-keys: | ${{ runner.os }}-gradle- - name: Maven cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | ~/.m2/repository @@ -46,7 +46,7 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - name: Protobuf cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /tmp/protobuf-cache key: ${{ runner.os }}-maven-${{ hashFiles('buildscripts/make_dependencies.sh') }} @@ -55,7 +55,7 @@ jobs: run: buildscripts/kokoro/unix.sh - name: Post Failure Upload Test Reports to Artifacts if: ${{ failure() }} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: Test Reports (JRE ${{ matrix.jre }}) path: ./*/build/reports/tests/** @@ -67,4 +67,4 @@ jobs: if: matrix.jre == 8 # Upload once, instead of for each job in the matrix run: ./gradlew :grpc-all:coveralls -x compileJava - name: Codecov - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 diff --git a/BUILD.bazel b/BUILD.bazel index 4bf8cdbc9b5..40c04022673 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -51,3 +51,21 @@ java_library( "@com_google_guava_guava//jar", ], ) + +java_plugin( + name = "auto_value", + generates_api = 1, + processor_class = "com.google.auto.value.processor.AutoValueProcessor", + deps = ["@com_google_auto_value_auto_value//jar"], +) + +java_library( + name = "auto_value_annotations", + exported_plugins = [":auto_value"], + neverlink = 1, + visibility = ["//:__subpackages__"], + exports = [ + "@com_google_auto_value_auto_value_annotations//jar", + "@org_apache_tomcat_annotations_api//jar", # @Generated for Java 9+ + ], +) diff --git a/COMPILING.md b/COMPILING.md index aef4cf61b66..3c5ad537e07 100644 --- a/COMPILING.md +++ b/COMPILING.md @@ -44,11 +44,11 @@ This section is only necessary if you are making changes to the code generation. Most users only need to use `skipCodegen=true` as discussed above. ### Build Protobuf -The codegen plugin is C++ code and requires protobuf 3.19.2 or later. +The codegen plugin is C++ code and requires protobuf 21.7 or later. For Linux, Mac and MinGW: ``` -$ PROTOBUF_VERSION=3.19.2 +$ PROTOBUF_VERSION=21.7 $ curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v$PROTOBUF_VERSION/protobuf-all-$PROTOBUF_VERSION.tar.gz $ tar xzf protobuf-all-$PROTOBUF_VERSION.tar.gz $ cd protobuf-$PROTOBUF_VERSION @@ -132,23 +132,21 @@ of Android SDK being installed is shown at `Android SDK Location` at the same pa The default is `$HOME/Library/Android/sdk` for Mac OS and `$HOME/Android/Sdk` for Linux. You can change this to a custom location. -### Install via Android SDK Manager -Go to [Android SDK](https://developer.android.com/studio) and navigate to __Command line tools only__. -Download and unzip the package for your build machine OS into somewhere easy to find -(e.g., `$HOME/Android/sdk`). This will be your Android SDK home directory. -The Android SDK Manager tool is in `tools/bin/sdkmanager`. - -Run the `sdkmanager` tool: -``` -$ tools/bin/sdkmanager --update -$ tools/bin/sdkmanager "platforms;android-28" -``` -This installs Android SDK 28 into `platforms/android-28` of your Android SDK home directory. -More usage of `sdkmanager` can be found at [Android User Guide](https://developer.android.com/studio/command-line/sdkmanager). - - -After Android SDK is installed, you need to set the `ANDROID_HOME` environment variable to your -Android SDK home directory: -``` -$ export ANDROID_HOME= +### Install via Command line tools only +Go to [Android SDK](https://developer.android.com/studio#command-tools) and +download the commandlinetools package for your build machine OS. Decide where +you want the Android SDK to be stored. `$HOME/Library/Android/sdk` is typical on +Mac OS and `$HOME/Android/Sdk` for Linux. + +```sh +export ANDROID_HOME=$HOME/Android/Sdk # Adjust to your liking +mkdir $HOME/Android +mkdir $ANDROID_HOME +mkdir $ANDROID_HOME/cmdline-tools +unzip -d $ANDROID_HOME/cmdline-tools DOWNLOADS/commandlinetools-*.zip +mv $ANDROID_HOME/cmdline-tools/cmdline-tools $ANDROID_HOME/cmdline-tools/latest +# Android SDK is now ready. Now accept licenses so the build can auto-download packages +$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager --licenses + +# Add 'export ANDROID_HOME=$HOME/Android/Sdk' to your .bashrc or equivalent ``` diff --git a/MAINTAINERS.md b/MAINTAINERS.md index f657dccc405..f05542e1987 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,12 +8,13 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [dapengzhang0](https://github.com/dapengzhang0), Google LLC - [ejona86](https://github.com/ejona86), Google LLC +- [jdcormie](https://github.com/jdcormie), Google LLC +- [larry-safran](https://github.com/larry-safran), Google LLC +- [markb74](https://github.com/markb74), Google LLC - [ran-su](https://github.com/ran-su), Google LLC - [sanjaypujare](https://github.com/sanjaypujare), Google LLC - [sergiitk](https://github.com/sergiitk), Google LLC -- [srini100](https://github.com/srini100), Google LLC - [temawi](https://github.com/temawi), Google LLC - [YifeiZhuang](https://github.com/YifeiZhuang), Google LLC - [zhangkun83](https://github.com/zhangkun83), Google LLC @@ -21,11 +22,13 @@ for general contribution guidelines. ## Emeritus Maintainers (in alphabetical order) - [carl-mastrangelo](https://github.com/carl-mastrangelo), Google LLC - [creamsoup](https://github.com/creamsoup), Google LLC +- [dapengzhang0](https://github.com/dapengzhang0), Google LLC - [ericgribkoff](https://github.com/ericgribkoff), Google LLC - [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC - [jtattermusch](https://github.com/jtattermusch), Google LLC - [louiscryan](https://github.com/louiscryan), Google LLC - [nicolasnoble](https://github.com/nicolasnoble), Google LLC - [nmittler](https://github.com/nmittler), Google LLC +- [srini100](https://github.com/srini100), Google LLC - [voidzcy](https://github.com/voidzcy), Google LLC - [zpencer](https://github.com/zpencer), Google LLC diff --git a/README.md b/README.md index d19c0868269..2ae0da3f12a 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,6 @@ gRPC-Java - An RPC library and framework ======================================== -gRPC-Java works with JDK 8. gRPC-Java clients are supported on Android API -levels 19 and up (KitKat and later). Deploying gRPC servers on an Android -device is not supported. - -TLS usage typically requires using Java 8, or Play Services Dynamic Security -Provider on Android. Please see the [Security Readme](SECURITY.md). - @@ -24,6 +17,26 @@ Provider on Android. Please see the [Security Readme](SECURITY.md). [![Line Coverage Status](https://coveralls.io/repos/grpc/grpc-java/badge.svg?branch=master&service=github)](https://coveralls.io/github/grpc/grpc-java?branch=master) [![Branch-adjusted Line Coverage Status](https://codecov.io/gh/grpc/grpc-java/branch/master/graph/badge.svg)](https://codecov.io/gh/grpc/grpc-java) +Supported Platforms +------------------- + +gRPC-Java supports Java 8 and later. Android minSdkVersion 19 (KitKat) and +later are supported with [Java 8 language desugaring][android-java-8]. + +TLS usage on Android typically requires Play Services Dynamic Security Provider. +Please see the [Security Readme](SECURITY.md). + +Older Java versions are not directly supported, but a branch remains available +for fixes and releases. See [gRFC P5 JDK Version Support +Policy][P5-jdk-version-support]. + +Java version | gRPC Branch +------------ | ----------- +7 | 1.41.x + +[android-java-8]: https://developer.android.com/studio/write/java8-support#supported_features +[P5-jdk-version-support]: https://github.com/grpc/proposal/blob/master/P5-jdk-version-support.md#proposal + Getting Started --------------- @@ -31,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.44.1/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.44.1/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.53.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.53.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,18 +56,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.44.1 + 1.53.0 runtime io.grpc grpc-protobuf - 1.44.1 + 1.53.0 io.grpc grpc-stub - 1.44.1 + 1.53.0 org.apache.tomcat @@ -66,23 +79,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.44.1' -implementation 'io.grpc:grpc-protobuf:1.44.1' -implementation 'io.grpc:grpc-stub:1.44.1' +runtimeOnly 'io.grpc:grpc-netty-shaded:1.53.0' +implementation 'io.grpc:grpc-protobuf:1.53.0' +implementation 'io.grpc:grpc-stub:1.53.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.44.1' -implementation 'io.grpc:grpc-protobuf-lite:1.44.1' -implementation 'io.grpc:grpc-stub:1.44.1' +implementation 'io.grpc:grpc-okhttp:1.53.0' +implementation 'io.grpc:grpc-protobuf-lite:1.53.0' +implementation 'io.grpc:grpc-stub:1.53.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.44.1 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.53.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -112,9 +125,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.19.2:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.21.7:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.44.1:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.53.0:exe:${os.detected.classifier} @@ -135,16 +148,16 @@ For non-Android protobuf-based codegen integrated with the Gradle build system, you can use [protobuf-gradle-plugin][]: ```gradle plugins { - id 'com.google.protobuf' version '0.8.17' + id 'com.google.protobuf' version '0.9.1' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.19.2" + artifact = "com.google.protobuf:protoc:3.21.7" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.44.1' + artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' } } generateProtoTasks { @@ -161,23 +174,23 @@ The prebuilt protoc-gen-grpc-java binary uses glibc on Linux. If you are compiling on Alpine Linux, you may want to use the [Alpine grpc-java package][] which uses musl instead. -[Alpine grpc-java package]: https://pkgs.alpinelinux.org/package/edge/testing/x86_64/grpc-java +[Alpine grpc-java package]: https://pkgs.alpinelinux.org/package/edge/community/x86_64/grpc-java For Android protobuf-based codegen integrated with the Gradle build system, also use protobuf-gradle-plugin but specify the 'lite' options: ```gradle plugins { - id 'com.google.protobuf' version '0.8.17' + id 'com.google.protobuf' version '0.9.1' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.19.2" + artifact = "com.google.protobuf:protoc:3.21.7" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.44.1' + artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' } } generateProtoTasks { @@ -243,12 +256,15 @@ wire. The interfaces to it are abstract just enough to allow plugging in of different implementations. Note the transport layer API is considered internal to gRPC and has weaker API guarantees than the core API under package `io.grpc`. -gRPC comes with three Transport implementations: +gRPC comes with multiple Transport implementations: -1. The Netty-based transport is the main transport implementation based on - [Netty](https://netty.io). It is for both the client and the server. -2. The OkHttp-based transport is a lightweight transport based on - [OkHttp](https://square.github.io/okhttp/). It is mainly for use on Android - and is for client only. +1. The Netty-based HTTP/2 transport is the main transport implementation based + on [Netty](https://netty.io). It is not officially supported on Android. +2. The OkHttp-based HTTP/2 transport is a lightweight transport based on + [Okio](https://square.github.io/okio/) and forked low-level parts of + [OkHttp](https://square.github.io/okhttp/). It is mainly for use on Android. 3. The in-process transport is for when a server is in the same process as the - client. It is useful for testing, while also being safe for production use. + client. It is used frequently for testing, while also being safe for + production use. +4. The Binder transport is for Android cross-process communication on a single + device. diff --git a/RELEASING.md b/RELEASING.md index b8682961234..989426f328a 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -33,9 +33,11 @@ $ VERSION_FILES=( examples/example-jwt-auth/pom.xml examples/example-hostname/build.gradle examples/example-hostname/pom.xml + examples/example-servlet/build.gradle examples/example-tls/build.gradle examples/example-tls/pom.xml examples/example-xds/build.gradle + examples/example-orca/build.gradle ) ``` @@ -50,7 +52,7 @@ would be used to create all `v1.7` tags (e.g. `v1.7.0`, `v1.7.1`). 1. Review the issues in the current release [milestone](https://github.com/grpc/grpc-java/milestones) for issues that won't make the cut. Check if any of them can be - closed. Be aware of the issues with the 'release blocker' label. + closed. Be aware of the issues with the [TODO:release blocker][] label. Consider reaching out to the assignee for the status update. 2. For `master`, change root build files to the next minor snapshot (e.g. ``1.8.0-SNAPSHOT``). @@ -88,11 +90,16 @@ would be used to create all `v1.7` tags (e.g. `v1.7.0`, `v1.7.1`). git log --oneline "$(git merge-base v$MAJOR.$((MINOR-1)).0 upstream/v$MAJOR.$MINOR.x)"..v$MAJOR.$((MINOR-1)).0^ ``` +[TODO:release blocker]: https://github.com/grpc/grpc-java/issues?q=label%3A%22TODO%3Arelease+blocker%22 +[TODO:backport]: https://github.com/grpc/grpc-java/issues?q=label%3ATODO%3Abackport + Tagging the Release ------------------- 1. Verify there are no open issues in the release milestone. Open issues should - either be deferred or resolved and the fix backported. + either be deferred or resolved and the fix backported. Verify there are no + [TODO:release blocker][] nor [TODO:backport][] issues (open or closed), or + that they are tracking an issue for a different branch. 2. Ensure that Google-internal steps completed at go/grpc/java/releasing#before-tagging-a-release. 3. For vMajor.Minor.x branch, change `README.md` to refer to the next release version. _Also_ update the version numbers for protoc if the protobuf library @@ -103,7 +110,7 @@ Tagging the Release $ git pull upstream v$MAJOR.$MINOR.x $ git checkout -b release # Bump documented gRPC versions. - # Also update protoc version to match protocVersion in build.gradle. + # Also update protoc version to match protobuf version in gradle/libs.versions.toml. $ ${EDITOR:-nano -w} README.md $ ${EDITOR:-nano -w} documentation/android-channel-builder.md $ ${EDITOR:-nano -w} cronet/README.md @@ -187,7 +194,7 @@ $ git cherry-pick v$MAJOR.$MINOR.$PATCH^ Update version referenced by tutorials -------------------------------------- -Update the `grpc_java_release_tag` in +Update `params.grpc_vers.java` in [config.yaml](https://github.com/grpc/grpc.io/blob/master/config.yaml) of the grpc.io repository. diff --git a/SECURITY.md b/SECURITY.md index 44efbe8d42e..1d36eb90103 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -410,7 +410,9 @@ grpc-netty version | netty-handler version | netty-tcnative-boringssl-static ver 1.32.x-1.34.x | 4.1.51.Final | 2.0.31.Final 1.35.x-1.41.x | 4.1.52.Final | 2.0.34.Final 1.42.x-1.43.x | 4.1.63.Final | 2.0.38.Final -1.44.x | 4.1.72.Final | 2.0.46.Final +1.44.x-1.47.x | 4.1.72.Final | 2.0.46.Final +1.48.x-1.49.x | 4.1.77.Final | 2.0.53.Final +1.50.x- | 4.1.79.Final | 2.0.54.Final _(grpc-netty-shaded avoids issues with keeping these versions in sync.)_ diff --git a/WORKSPACE b/WORKSPACE index b6198573f28..4727a962148 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,9 +4,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "rules_jvm_external", - sha256 = "cd1a77b7b02e8e008439ca76fd34f5b07aecb8c752961f9640dea15e9e5ba1ca", - strip_prefix = "rules_jvm_external-4.2", - url = "https://github.com/bazelbuild/rules_jvm_external/archive/4.2.zip", + sha256 = "c21ce8b8c4ccac87c809c317def87644cdc3a9dd650c74f41698d761c95175f3", + strip_prefix = "rules_jvm_external-1498ac6ccd3ea9cdb84afed65aa257c57abf3e0a", + url = "https://github.com/bazelbuild/rules_jvm_external/archive/1498ac6ccd3ea9cdb84afed65aa257c57abf3e0a.zip", ) load("@rules_jvm_external//:defs.bzl", "maven_install") @@ -21,6 +21,17 @@ load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") protobuf_deps() +load("@envoy_api//bazel:repositories.bzl", "api_dependencies") + +api_dependencies() + +load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") + +switched_rules_by_language( + name = "com_google_googleapis_imports", + java = True, +) + maven_install( artifacts = IO_GRPC_GRPC_JAVA_ARTIFACTS + PROTOBUF_MAVEN_ARTIFACTS, generate_compat_repositories = True, @@ -28,6 +39,7 @@ maven_install( repositories = [ "https://repo.maven.apache.org/maven2/", ], + strict_visibility = True, ) load("@maven//:compat.bzl", "compat_repositories") diff --git a/all/build.gradle b/all/build.gradle index ede0b03c27e..e4d7e9085b5 100644 --- a/all/build.gradle +++ b/all/build.gradle @@ -19,6 +19,8 @@ def subprojects = [ project(':grpc-protobuf-lite'), project(':grpc-rls'), project(':grpc-services'), + project(':grpc-servlet'), + project(':grpc-servlet-jakarta'), project(':grpc-stub'), project(':grpc-testing'), project(':grpc-xds'), @@ -36,7 +38,7 @@ dependencies { api subprojects.minus([project(':grpc-protobuf-lite')]) } -javadoc { +tasks.named("javadoc").configure { classpath = files(subprojects.collect { subproject -> subproject.javadoc.classpath }) @@ -49,7 +51,7 @@ javadoc { } } -task jacocoMerge(type: JacocoMerge) { +tasks.register("jacocoMerge", JacocoMerge) { dependsOn(subprojects.jacocoTestReport.dependsOn) dependsOn(project(':grpc-interop-testing').jacocoTestReport.dependsOn) mustRunAfter(subprojects.jacocoTestReport.mustRunAfter) @@ -60,7 +62,7 @@ task jacocoMerge(type: JacocoMerge) { .filter { f -> f.exists() } } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { dependsOn(jacocoMerge) reports { xml.required = true @@ -78,4 +80,4 @@ coveralls { sourceDirs = subprojects.sourceSets.main.allSource.srcDirs.flatten() } -tasks.coveralls { dependsOn(jacocoTestReport) } +tasks.named("coveralls").configure { dependsOn tasks.named("jacocoTestReport") } diff --git a/alts/build.gradle b/alts/build.gradle index a056799d4f0..926a3d4993b 100644 --- a/alts/build.gradle +++ b/alts/build.gradle @@ -17,12 +17,12 @@ dependencies { project(':grpc-grpclb'), project(':grpc-protobuf'), project(':grpc-stub'), - libraries.protobuf, + libraries.protobuf.java, libraries.conscrypt, libraries.guava, - libraries.google_auth_oauth2_http + libraries.google.auth.oauth2Http def nettyDependency = implementation project(':grpc-netty') - compileOnly libraries.javax_annotation + compileOnly libraries.javax.annotation shadow configurations.implementation.getDependencies().minus(nettyDependency) shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') @@ -32,43 +32,48 @@ dependencies { project(':grpc-testing-proto'), libraries.guava, libraries.junit, - libraries.mockito, + libraries.mockito.core, libraries.truth - testImplementation (libraries.guava_testlib) { + testImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } - testRuntimeOnly libraries.netty_tcnative, - libraries.netty_epoll - signature 'org.codehaus.mojo.signature:java17:1.0@signature' + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + signature libraries.signature.java } configureProtoCompilation() import net.ltgt.gradle.errorprone.CheckSeverity -[compileJava, compileTestJava].each() { +[tasks.named("compileJava"), tasks.named("compileTestJava")]*.configure { // protobuf calls valueof. Will be fixed in next release (google/protobuf#4046) - it.options.compilerArgs += [ + options.compilerArgs += [ "-Xlint:-deprecation" ] // ALTS returns a lot of futures that we mostly don't care about. - it.options.errorprone.check("FutureReturnValueIgnored", CheckSeverity.OFF) + options.errorprone.check("FutureReturnValueIgnored", CheckSeverity.OFF) } -javadoc { +tasks.named("javadoc").configure { exclude 'io/grpc/alts/internal/**' exclude 'io/grpc/alts/Internal*' } -jar { +tasks.named("jar").configure { // Must use a different archiveClassifier to avoid conflicting with shadowJar archiveClassifier = 'original' } // We want to use grpc-netty-shaded instead of grpc-netty. But we also want our // source to work with Bazel, so we rewrite the code as part of the build. -shadowJar { +tasks.named("shadowJar").configure { archiveClassifier = null dependencies { exclude(dependency {true}) diff --git a/alts/src/main/java/io/grpc/alts/AltsContext.java b/alts/src/main/java/io/grpc/alts/AltsContext.java index f264ad112d7..7680de4160e 100644 --- a/alts/src/main/java/io/grpc/alts/AltsContext.java +++ b/alts/src/main/java/io/grpc/alts/AltsContext.java @@ -20,7 +20,6 @@ import io.grpc.alts.internal.AltsInternalContext; import io.grpc.alts.internal.HandshakerResult; import io.grpc.alts.internal.Identity; -import io.grpc.alts.internal.SecurityLevel; /** {@code AltsContext} contains security-related information on the ALTS channel. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7864") diff --git a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java index a5d7c0e3ff9..91b06756dc3 100644 --- a/alts/src/main/java/io/grpc/alts/AltsContextUtil.java +++ b/alts/src/main/java/io/grpc/alts/AltsContextUtil.java @@ -26,7 +26,7 @@ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7864") public final class AltsContextUtil { - private AltsContextUtil(){} + private AltsContextUtil() {} /** * Creates a {@link AltsContext} from ALTS context information in the {@link ServerCall}. diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerStub.java b/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerStub.java index 61d9fd2f894..bb2ff9dbc4d 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerStub.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerStub.java @@ -64,12 +64,18 @@ public HandshakerResp send(HandshakerReq req) throws InterruptedException, IOExc if (!responseQueue.isEmpty()) { throw new IOException("Received an unexpected response."); } + writer.onNext(req); Optional result = responseQueue.take(); - if (!result.isPresent()) { - maybeThrowIoException(); + if (result.isPresent()) { + return result.get(); + } + + if (exceptionMessage.get() != null) { + throw new IOException(exceptionMessage.get()); + } else { + throw new IOException("No handshaker response received"); } - return result.get(); } /** Create a new writer if the writer is null. */ diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index edfff2b481f..8df9363f99f 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -39,6 +39,8 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.net.URI; +import java.net.URISyntaxException; import java.security.GeneralSecurityException; import java.util.List; import java.util.logging.Level; @@ -58,14 +60,19 @@ public final class AltsProtocolNegotiator { private static final AsyncSemaphore handshakeSemaphore = new AsyncSemaphore(32); @Grpc.TransportAttr - public static final Attributes.Key TSI_PEER_KEY = Attributes.Key.create("TSI_PEER"); + public static final Attributes.Key TSI_PEER_KEY = + Attributes.Key.create("internal:TSI_PEER"); @Grpc.TransportAttr public static final Attributes.Key AUTH_CONTEXT_KEY = - Attributes.Key.create("AUTH_CONTEXT_KEY"); + Attributes.Key.create("internal:AUTH_CONTEXT_KEY"); private static final AsciiString SCHEME = AsciiString.of("https"); private static final String DIRECT_PATH_SERVICE_CFE_CLUSTER_PREFIX = "google_cfe_"; + private static final String CFE_CLUSTER_RESOURCE_NAME_PREFIX = + "/envoy.config.cluster.v3.Cluster/google_cfe_"; + private static final String CFE_CLUSTER_AUTHORITY_NAME = + "traffic-director-c2p.xds.googleapis.com"; /** * ClientAltsProtocolNegotiatorFactory is a factory for doing client side negotiation of an ALTS @@ -287,11 +294,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler securityHandler; boolean isXdsDirectPath = false; if (clusterNameAttrKey != null) { - String clusterName = grpcHandler.getEagAttributes().get(clusterNameAttrKey); - if (clusterName != null - && !clusterName.startsWith(DIRECT_PATH_SERVICE_CFE_CLUSTER_PREFIX)) { - isXdsDirectPath = true; - } + isXdsDirectPath = isDirectPathCluster( + grpcHandler.getEagAttributes().get(clusterNameAttrKey)); } if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null @@ -311,6 +315,26 @@ gnh, nettyHandshaker, new AltsHandshakeValidator(), handshakeSemaphore, return wuah; } + private boolean isDirectPathCluster(String clusterName) { + if (clusterName == null) { + return false; + } + if (clusterName.startsWith(DIRECT_PATH_SERVICE_CFE_CLUSTER_PREFIX)) { + return false; + } + if (!clusterName.startsWith("xdstp:")) { + return true; + } + try { + URI uri = new URI(clusterName); + // If authority AND path match our CFE checks, use TLS; otherwise use ALTS. + return !CFE_CLUSTER_AUTHORITY_NAME.equals(uri.getHost()) + || !uri.getPath().startsWith(CFE_CLUSTER_RESOURCE_NAME_PREFIX); + } catch (URISyntaxException e) { + return true; // Shouldn't happen, but assume ALTS. + } + } + @Override public void close() { logger.finest("ALTS Server ProtocolNegotiator Closed"); diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java index 2269f0a0fa9..007db9e1eed 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java @@ -169,7 +169,7 @@ public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityExcepti } /** - * Returns true if and only if the handshake is still in progress + * Returns true if and only if the handshake is still in progress. * * @return true, if the handshake is still in progress, false otherwise. */ diff --git a/alts/src/main/java/io/grpc/alts/internal/NettyTsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/NettyTsiHandshaker.java index 5087123ab06..b91cfdad08c 100644 --- a/alts/src/main/java/io/grpc/alts/internal/NettyTsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/NettyTsiHandshaker.java @@ -99,7 +99,7 @@ boolean processBytesFromPeer(ByteBuf data) throws GeneralSecurityException { } /** - * Returns true if and only if the handshake is still in progress + * Returns true if and only if the handshake is still in progress. * * @return true, if the handshake is still in progress, false otherwise. */ diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshaker.java index 35b945770d2..6580a4433c7 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshaker.java @@ -68,7 +68,7 @@ public interface TsiHandshaker { boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException; /** - * Returns true if and only if the handshake is still in progress + * Returns true if and only if the handshake is still in progress. * * @return true, if the handshake is still in progress, false otherwise. */ @@ -86,7 +86,7 @@ public interface TsiHandshaker { * * @return the extracted peer. */ - public Object extractPeerObject() throws GeneralSecurityException; + Object extractPeerObject() throws GeneralSecurityException; /** * Creates a frame protector from a completed handshake. No other methods may be called after the diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index f149c4306c6..9a520720beb 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -185,5 +185,47 @@ public void tlsHandler_googleCfe() { XDS_CLUSTER_NAME_ATTR_KEY, "google_cfe_api.googleapis.com").build(); subtest_tlsHandler(attrs); } + + @Test + public void altsHandler_googleCfe_federation() { + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, "xdstp1://").build(); + subtest_altsHandler(attrs); + } + + @Test + public void tlsHanlder_googleCfe() { + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, + "xdstp://traffic-director-c2p.xds.googleapis.com/" + + "envoy.config.cluster.v3.Cluster/google_cfe_example/apis") + .build(); + subtest_tlsHandler(attrs); + } + + @Test + public void altsHanlder_nonGoogleCfe_authorityNotMatch() { + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, + "//example.com/envoy.config.cluster.v3.Cluster/google_cfe_") + .build(); + subtest_altsHandler(attrs); + } + + @Test + public void altsHanlder_nonGoogleCfe_pathNotMatch() { + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, + "//traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_gfe") + .build(); + subtest_altsHandler(attrs); + } + + @Test + public void altsHandler_googleCfe_invalidUri() { + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, "//").build(); + subtest_altsHandler(attrs); + } } } diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index e40507e6a3b..2db9f513d64 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -56,7 +56,7 @@ android { dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' implementation 'androidx.multidex:multidex:2.0.0' - implementation libraries.androidx_annotation + implementation libraries.androidx.annotation implementation 'com.google.android.gms:play-services-base:18.0.1' implementation project(':grpc-auth'), @@ -68,13 +68,17 @@ dependencies { libraries.hdrhistogram, libraries.junit, libraries.truth, - libraries.opencensus_contrib_grpc_metrics + libraries.opencensus.contrib.grpc.metrics - implementation (libraries.google_auth_oauth2_http) { + implementation (libraries.google.auth.oauth2Http) { exclude group: 'org.apache.httpcomponents' } - compileOnly libraries.javax_annotation + implementation (project(':grpc-services')) { + exclude group: 'com.google.protobuf' + } + + compileOnly libraries.javax.annotation androidTestImplementation 'androidx.test.ext:junit:1.1.3', 'androidx.test:runner:1.4.0' @@ -97,7 +101,7 @@ project.tasks['check'].dependsOn checkStyleMain, checkStyleTest import net.ltgt.gradle.errorprone.CheckSeverity -tasks.withType(JavaCompile) { +tasks.withType(JavaCompile).configureEach { options.compilerArgs += [ "-Xlint:-cast" ] diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8754f3819cc..476e9cc03a1 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -18,27 +18,27 @@ private ReconnectServiceGrpc() {} public static final String SERVICE_NAME = "grpc.testing.ReconnectService"; // Static method descriptors that strictly reflect the proto. - private static volatile io.grpc.MethodDescriptor getStartMethod; @io.grpc.stub.annotations.RpcMethod( fullMethodName = SERVICE_NAME + '/' + "Start", - requestType = io.grpc.testing.integration.EmptyProtos.Empty.class, + requestType = io.grpc.testing.integration.Messages.ReconnectParams.class, responseType = io.grpc.testing.integration.EmptyProtos.Empty.class, methodType = io.grpc.MethodDescriptor.MethodType.UNARY) - public static io.grpc.MethodDescriptor getStartMethod() { - io.grpc.MethodDescriptor getStartMethod; + io.grpc.MethodDescriptor getStartMethod; if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { synchronized (ReconnectServiceGrpc.class) { if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { ReconnectServiceGrpc.getStartMethod = getStartMethod = - io.grpc.MethodDescriptor.newBuilder() + io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Start")) .setSampledToLocalTracing(true) .setRequestMarshaller(io.grpc.protobuf.lite.ProtoLiteUtils.marshaller( - io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) + io.grpc.testing.integration.Messages.ReconnectParams.getDefaultInstance())) .setResponseMarshaller(io.grpc.protobuf.lite.ProtoLiteUtils.marshaller( io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) .build(); @@ -131,7 +131,7 @@ public static abstract class ReconnectServiceImplBase implements io.grpc.Bindabl /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getStartMethod(), responseObserver); } @@ -149,7 +149,7 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, getStartMethod(), io.grpc.stub.ServerCalls.asyncUnaryCall( new MethodHandlers< - io.grpc.testing.integration.EmptyProtos.Empty, + io.grpc.testing.integration.Messages.ReconnectParams, io.grpc.testing.integration.EmptyProtos.Empty>( this, METHODID_START))) .addMethod( @@ -182,7 +182,7 @@ protected ReconnectServiceStub build( /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ClientCalls.asyncUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request, responseObserver); @@ -216,7 +216,7 @@ protected ReconnectServiceBlockingStub build( /** */ - public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.EmptyProtos.Empty request) { + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.blockingUnaryCall( getChannel(), getStartMethod(), getCallOptions(), request); } @@ -249,7 +249,7 @@ protected ReconnectServiceFutureStub build( /** */ public com.google.common.util.concurrent.ListenableFuture start( - io.grpc.testing.integration.EmptyProtos.Empty request) { + io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.futureUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request); } @@ -284,7 +284,7 @@ private static final class MethodHandlers implements public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { switch (methodId) { case METHODID_START: - serviceImpl.start((io.grpc.testing.integration.EmptyProtos.Empty) request, + serviceImpl.start((io.grpc.testing.integration.Messages.ReconnectParams) request, (io.grpc.stub.StreamObserver) responseObserver); break; case METHODID_STOP: diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8754f3819cc..476e9cc03a1 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -18,27 +18,27 @@ private ReconnectServiceGrpc() {} public static final String SERVICE_NAME = "grpc.testing.ReconnectService"; // Static method descriptors that strictly reflect the proto. - private static volatile io.grpc.MethodDescriptor getStartMethod; @io.grpc.stub.annotations.RpcMethod( fullMethodName = SERVICE_NAME + '/' + "Start", - requestType = io.grpc.testing.integration.EmptyProtos.Empty.class, + requestType = io.grpc.testing.integration.Messages.ReconnectParams.class, responseType = io.grpc.testing.integration.EmptyProtos.Empty.class, methodType = io.grpc.MethodDescriptor.MethodType.UNARY) - public static io.grpc.MethodDescriptor getStartMethod() { - io.grpc.MethodDescriptor getStartMethod; + io.grpc.MethodDescriptor getStartMethod; if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { synchronized (ReconnectServiceGrpc.class) { if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { ReconnectServiceGrpc.getStartMethod = getStartMethod = - io.grpc.MethodDescriptor.newBuilder() + io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Start")) .setSampledToLocalTracing(true) .setRequestMarshaller(io.grpc.protobuf.lite.ProtoLiteUtils.marshaller( - io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) + io.grpc.testing.integration.Messages.ReconnectParams.getDefaultInstance())) .setResponseMarshaller(io.grpc.protobuf.lite.ProtoLiteUtils.marshaller( io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) .build(); @@ -131,7 +131,7 @@ public static abstract class ReconnectServiceImplBase implements io.grpc.Bindabl /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getStartMethod(), responseObserver); } @@ -149,7 +149,7 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, getStartMethod(), io.grpc.stub.ServerCalls.asyncUnaryCall( new MethodHandlers< - io.grpc.testing.integration.EmptyProtos.Empty, + io.grpc.testing.integration.Messages.ReconnectParams, io.grpc.testing.integration.EmptyProtos.Empty>( this, METHODID_START))) .addMethod( @@ -182,7 +182,7 @@ protected ReconnectServiceStub build( /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ClientCalls.asyncUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request, responseObserver); @@ -216,7 +216,7 @@ protected ReconnectServiceBlockingStub build( /** */ - public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.EmptyProtos.Empty request) { + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.blockingUnaryCall( getChannel(), getStartMethod(), getCallOptions(), request); } @@ -249,7 +249,7 @@ protected ReconnectServiceFutureStub build( /** */ public com.google.common.util.concurrent.ListenableFuture start( - io.grpc.testing.integration.EmptyProtos.Empty request) { + io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.futureUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request); } @@ -284,7 +284,7 @@ private static final class MethodHandlers implements public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { switch (methodId) { case METHODID_START: - serviceImpl.start((io.grpc.testing.integration.EmptyProtos.Empty) request, + serviceImpl.start((io.grpc.testing.integration.Messages.ReconnectParams) request, (io.grpc.stub.StreamObserver) responseObserver); break; case METHODID_STOP: diff --git a/android/build.gradle b/android/build.gradle index 7e32ac5ca2f..3bbff43869c 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -33,7 +33,7 @@ dependencies { api project(':grpc-core') implementation libraries.guava testImplementation project('::grpc-okhttp') - testImplementation libraries.androidx_test + testImplementation libraries.androidx.test.core testImplementation libraries.junit testImplementation (libraries.robolectric) { // Unreleased change: https://github.com/robolectric/robolectric/pull/5432 @@ -42,7 +42,7 @@ dependencies { testImplementation libraries.truth } -task javadocs(type: Javadoc) { +tasks.register("javadocs", Javadoc) { source = android.sourceSets.main.java.srcDirs classpath += files(android.getBootClasspath()) classpath += files({ @@ -58,12 +58,13 @@ task javadocs(type: Javadoc) { } } -task javadocJar(type: Jar, dependsOn: javadocs) { +tasks.register("javadocJar", Jar) { + dependsOn javadocs archiveClassifier = 'javadoc' from javadocs.destinationDir } -task sourcesJar(type: Jar) { +tasks.register("sourcesJar", Jar) { archiveClassifier = 'sources' from android.sourceSets.main.java.srcDirs } diff --git a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java index dadab1c830c..f23a73fbad9 100644 --- a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java +++ b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java @@ -33,8 +33,10 @@ import io.grpc.ConnectivityState; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; +import io.grpc.InternalManagedChannelProvider; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.ManagedChannelProvider; import io.grpc.MethodDescriptor; import io.grpc.internal.GrpcUtil; import java.util.concurrent.TimeUnit; @@ -55,14 +57,35 @@ public final class AndroidChannelBuilder extends ForwardingChannelBuilder OKHTTP_CHANNEL_BUILDER_CLASS = findOkHttp(); + @Nullable private static final ManagedChannelProvider OKHTTP_CHANNEL_PROVIDER = findOkHttp(); - private static Class findOkHttp() { + private static ManagedChannelProvider findOkHttp() { + Class klassRaw; try { - return Class.forName("io.grpc.okhttp.OkHttpChannelBuilder"); + klassRaw = Class.forName("io.grpc.okhttp.OkHttpChannelProvider"); } catch (ClassNotFoundException e) { + Log.w(LOG_TAG, "Failed to find OkHttpChannelProvider", e); return null; } + Class klass; + try { + klass = klassRaw.asSubclass(ManagedChannelProvider.class); + } catch (ClassCastException e) { + Log.w(LOG_TAG, "Couldn't cast OkHttpChannelProvider to ManagedChannelProvider", e); + return null; + } + ManagedChannelProvider provider; + try { + provider = klass.getConstructor().newInstance(); + } catch (Exception e) { + Log.w(LOG_TAG, "Failed to construct OkHttpChannelProvider", e); + return null; + } + if (!InternalManagedChannelProvider.isAvailable(provider)) { + Log.w(LOG_TAG, "OkHttpChannelProvider.isAvailable() returned false"); + return null; + } + return provider; } private final ManagedChannelBuilder delegateBuilder; @@ -113,18 +136,11 @@ public static AndroidChannelBuilder usingBuilder(ManagedChannelBuilder builde } private AndroidChannelBuilder(String target) { - if (OKHTTP_CHANNEL_BUILDER_CLASS == null) { - throw new UnsupportedOperationException("No ManagedChannelBuilder found on the classpath"); - } - try { - delegateBuilder = - (ManagedChannelBuilder) - OKHTTP_CHANNEL_BUILDER_CLASS - .getMethod("forTarget", String.class) - .invoke(null, target); - } catch (Exception e) { - throw new RuntimeException("Failed to create ManagedChannelBuilder", e); + if (OKHTTP_CHANNEL_PROVIDER == null) { + throw new UnsupportedOperationException("Unable to load OkHttpChannelProvider"); } + delegateBuilder = + InternalManagedChannelProvider.builderForTarget(OKHTTP_CHANNEL_PROVIDER, target); } private AndroidChannelBuilder(ManagedChannelBuilder delegateBuilder) { diff --git a/api/build.gradle b/api/build.gradle index 9f5e6163153..05cd80674c3 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -13,22 +13,22 @@ evaluationDependsOn(project(':grpc-context').path) dependencies { api project(':grpc-context'), libraries.jsr305, - libraries.errorprone + libraries.errorprone.annotations implementation libraries.guava testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-testing'), project(':grpc-grpclb') - testImplementation (libraries.guava_testlib) { + testImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } jmh project(':grpc-core') - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + signature libraries.signature.java + signature libraries.signature.android } -javadoc { +tasks.named("javadoc").configure { // We want io.grpc.Internal, but not io.grpc.Internal* exclude 'io/grpc/Internal?*.java' } diff --git a/api/src/main/java/io/grpc/Attributes.java b/api/src/main/java/io/grpc/Attributes.java index 26c2a72f909..461ed04d381 100644 --- a/api/src/main/java/io/grpc/Attributes.java +++ b/api/src/main/java/io/grpc/Attributes.java @@ -110,7 +110,8 @@ public Builder toBuilder() { } /** - * Key for an key-value pair. + * Key for an key-value pair. Uses reference equality. + * * @param type of the value in the key-value pair */ @Immutable diff --git a/api/src/main/java/io/grpc/CallCredentials.java b/api/src/main/java/io/grpc/CallCredentials.java index 3e588e24027..1d353fb9f6d 100644 --- a/api/src/main/java/io/grpc/CallCredentials.java +++ b/api/src/main/java/io/grpc/CallCredentials.java @@ -91,6 +91,13 @@ public abstract static class RequestInfo { */ public abstract MethodDescriptor getMethodDescriptor(); + /** + * The call options used to call this RPC. + */ + public CallOptions getCallOptions() { + throw new UnsupportedOperationException("Not implemented"); + } + /** * The security level on the transport. */ diff --git a/api/src/main/java/io/grpc/CallOptions.java b/api/src/main/java/io/grpc/CallOptions.java index 5c05d5b7bd7..4b180c56e07 100644 --- a/api/src/main/java/io/grpc/CallOptions.java +++ b/api/src/main/java/io/grpc/CallOptions.java @@ -41,42 +41,75 @@ public final class CallOptions { /** * A blank {@code CallOptions} that all fields are not set. */ - public static final CallOptions DEFAULT = new CallOptions(); + public static final CallOptions DEFAULT; + + static { + Builder b = new Builder(); + b.customOptions = new Object[0][2]; + b.streamTracerFactories = Collections.emptyList(); + DEFAULT = b.build(); + } - // Although {@code CallOptions} is immutable, its fields are not final, so that we can initialize - // them outside of constructor. Otherwise the constructor will have a potentially long list of - // unnamed arguments, which is undesirable. @Nullable - private Deadline deadline; - + private final Deadline deadline; + @Nullable - private Executor executor; + private final Executor executor; @Nullable - private String authority; + private final String authority; @Nullable - private CallCredentials credentials; + private final CallCredentials credentials; @Nullable - private String compressorName; + private final String compressorName; - private Object[][] customOptions; + private final Object[][] customOptions; - // Unmodifiable list - private List streamTracerFactories = Collections.emptyList(); + private final List streamTracerFactories; /** * Opposite to fail fast. */ @Nullable - private Boolean waitForReady; + private final Boolean waitForReady; @Nullable - private Integer maxInboundMessageSize; + private final Integer maxInboundMessageSize; @Nullable - private Integer maxOutboundMessageSize; + private final Integer maxOutboundMessageSize; + + private CallOptions(Builder builder) { + this.deadline = builder.deadline; + this.executor = builder.executor; + this.authority = builder.authority; + this.credentials = builder.credentials; + this.compressorName = builder.compressorName; + this.customOptions = builder.customOptions; + this.streamTracerFactories = builder.streamTracerFactories; + this.waitForReady = builder.waitForReady; + this.maxInboundMessageSize = builder.maxInboundMessageSize; + this.maxOutboundMessageSize = builder.maxOutboundMessageSize; + } + static class Builder { + Deadline deadline; + Executor executor; + String authority; + CallCredentials credentials; + String compressorName; + Object[][] customOptions; + // Unmodifiable list + List streamTracerFactories; + Boolean waitForReady; + Integer maxInboundMessageSize; + Integer maxOutboundMessageSize; + + private CallOptions build() { + return new CallOptions(this); + } + } /** * Override the HTTP/2 authority the channel claims to be connecting to. This is not @@ -89,18 +122,18 @@ public final class CallOptions { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1767") public CallOptions withAuthority(@Nullable String authority) { - CallOptions newOptions = new CallOptions(this); - newOptions.authority = authority; - return newOptions; + Builder builder = toBuilder(this); + builder.authority = authority; + return builder.build(); } /** * Returns a new {@code CallOptions} with the given call credentials. */ public CallOptions withCallCredentials(@Nullable CallCredentials credentials) { - CallOptions newOptions = new CallOptions(this); - newOptions.credentials = credentials; - return newOptions; + Builder builder = toBuilder(this); + builder.credentials = credentials; + return builder.build(); } /** @@ -113,9 +146,9 @@ public CallOptions withCallCredentials(@Nullable CallCredentials credentials) { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704") public CallOptions withCompression(@Nullable String compressorName) { - CallOptions newOptions = new CallOptions(this); - newOptions.compressorName = compressorName; - return newOptions; + Builder builder = toBuilder(this); + builder.compressorName = compressorName; + return builder.build(); } /** @@ -127,9 +160,9 @@ public CallOptions withCompression(@Nullable String compressorName) { * @param deadline the deadline or {@code null} for unsetting the deadline. */ public CallOptions withDeadline(@Nullable Deadline deadline) { - CallOptions newOptions = new CallOptions(this); - newOptions.deadline = deadline; - return newOptions; + Builder builder = toBuilder(this); + builder.deadline = deadline; + return builder.build(); } /** @@ -156,9 +189,9 @@ public Deadline getDeadline() { * fails RPCs without sending them if unable to connect. */ public CallOptions withWaitForReady() { - CallOptions newOptions = new CallOptions(this); - newOptions.waitForReady = Boolean.TRUE; - return newOptions; + Builder builder = toBuilder(this); + builder.waitForReady = Boolean.TRUE; + return builder.build(); } /** @@ -166,9 +199,9 @@ public CallOptions withWaitForReady() { * This method should be rarely used because the default is without 'wait for ready'. */ public CallOptions withoutWaitForReady() { - CallOptions newOptions = new CallOptions(this); - newOptions.waitForReady = Boolean.FALSE; - return newOptions; + Builder builder = toBuilder(this); + builder.waitForReady = Boolean.FALSE; + return builder.build(); } /** @@ -208,9 +241,9 @@ public CallCredentials getCredentials() { * executor specified with {@link ManagedChannelBuilder#executor}. */ public CallOptions withExecutor(@Nullable Executor executor) { - CallOptions newOptions = new CallOptions(this); - newOptions.executor = executor; - return newOptions; + Builder builder = toBuilder(this); + builder.executor = executor; + return builder.build(); } /** @@ -221,13 +254,13 @@ public CallOptions withExecutor(@Nullable Executor executor) { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") public CallOptions withStreamTracerFactory(ClientStreamTracer.Factory factory) { - CallOptions newOptions = new CallOptions(this); ArrayList newList = new ArrayList<>(streamTracerFactories.size() + 1); newList.addAll(streamTracerFactories); newList.add(factory); - newOptions.streamTracerFactories = Collections.unmodifiableList(newList); - return newOptions; + Builder builder = toBuilder(this); + builder.streamTracerFactories = Collections.unmodifiableList(newList); + return builder.build(); } /** @@ -319,7 +352,7 @@ public CallOptions withOption(Key key, T value) { Preconditions.checkNotNull(key, "key"); Preconditions.checkNotNull(value, "value"); - CallOptions newOptions = new CallOptions(this); + Builder builder = toBuilder(this); int existingIdx = -1; for (int i = 0; i < customOptions.length; i++) { if (key.equals(customOptions[i][0])) { @@ -328,18 +361,18 @@ public CallOptions withOption(Key key, T value) { } } - newOptions.customOptions = new Object[customOptions.length + (existingIdx == -1 ? 1 : 0)][2]; - System.arraycopy(customOptions, 0, newOptions.customOptions, 0, customOptions.length); + builder.customOptions = new Object[customOptions.length + (existingIdx == -1 ? 1 : 0)][2]; + System.arraycopy(customOptions, 0, builder.customOptions, 0, customOptions.length); if (existingIdx == -1) { // Add a new option - newOptions.customOptions[customOptions.length] = new Object[] {key, value}; + builder.customOptions[customOptions.length] = new Object[] {key, value}; } else { // Replace an existing option - newOptions.customOptions[existingIdx] = new Object[] {key, value}; + builder.customOptions[existingIdx] = new Object[] {key, value}; } - return newOptions; + return builder.build(); } /** @@ -368,10 +401,6 @@ public Executor getExecutor() { return executor; } - private CallOptions() { - customOptions = new Object[0][2]; - } - /** * Returns whether * 'wait for ready' option is enabled for the call. 'Fail fast' is the default option for gRPC @@ -392,9 +421,9 @@ Boolean getWaitForReady() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2563") public CallOptions withMaxInboundMessageSize(int maxSize) { checkArgument(maxSize >= 0, "invalid maxsize %s", maxSize); - CallOptions newOptions = new CallOptions(this); - newOptions.maxInboundMessageSize = maxSize; - return newOptions; + Builder builder = toBuilder(this); + builder.maxInboundMessageSize = maxSize; + return builder.build(); } /** @@ -403,9 +432,9 @@ public CallOptions withMaxInboundMessageSize(int maxSize) { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2563") public CallOptions withMaxOutboundMessageSize(int maxSize) { checkArgument(maxSize >= 0, "invalid maxsize %s", maxSize); - CallOptions newOptions = new CallOptions(this); - newOptions.maxOutboundMessageSize = maxSize; - return newOptions; + Builder builder = toBuilder(this); + builder.maxOutboundMessageSize = maxSize; + return builder.build(); } /** @@ -427,19 +456,21 @@ public Integer getMaxOutboundMessageSize() { } /** - * Copy constructor. + * Copy CallOptions. */ - private CallOptions(CallOptions other) { - deadline = other.deadline; - authority = other.authority; - credentials = other.credentials; - executor = other.executor; - compressorName = other.compressorName; - customOptions = other.customOptions; - waitForReady = other.waitForReady; - maxInboundMessageSize = other.maxInboundMessageSize; - maxOutboundMessageSize = other.maxOutboundMessageSize; - streamTracerFactories = other.streamTracerFactories; + private static Builder toBuilder(CallOptions other) { + Builder builder = new Builder(); + builder.deadline = other.deadline; + builder.executor = other.executor; + builder.authority = other.authority; + builder.credentials = other.credentials; + builder.compressorName = other.compressorName; + builder.customOptions = other.customOptions; + builder.streamTracerFactories = other.streamTracerFactories; + builder.waitForReady = other.waitForReady; + builder.maxInboundMessageSize = other.maxInboundMessageSize; + builder.maxOutboundMessageSize = other.maxOutboundMessageSize; + return builder; } @Override diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index 34b2957d836..969c82e010b 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -40,11 +40,16 @@ public final class EquivalentAddressGroup { * However, if the channel has overridden authority via * {@link ManagedChannelBuilder#overrideAuthority(String)}, the transport will use the channel's * authority override. + * + *

The authority must be from a trusted source, because if the authority is + * tampered with, RPCs may be sent to attackers which may leak sensitive user data. If the + * authority was acquired by doing I/O, the communication must be authenticated (e.g., via TLS). + * Recognize that the server that provided the authority can trivially impersonate the service. */ @Attr @ExperimentalApi("https://github.com/grpc/grpc-java/issues/6138") public static final Attributes.Key ATTR_AUTHORITY_OVERRIDE = - Attributes.Key.create("io.grpc.EquivalentAddressGroup.authorityOverride"); + Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE"); private final List addrs; private final Attributes attrs; diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java index 8ea355de6e9..0be219bfdb3 100644 --- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java @@ -133,6 +133,48 @@ public T handshakeTimeout(long timeout, TimeUnit unit) { return thisT(); } + @Override + public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + delegate().keepAliveTime(keepAliveTime, timeUnit); + return thisT(); + } + + @Override + public T keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { + delegate().keepAliveTimeout(keepAliveTimeout, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionIdle(long maxConnectionIdle, TimeUnit timeUnit) { + delegate().maxConnectionIdle(maxConnectionIdle, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { + delegate().maxConnectionAge(maxConnectionAge, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { + delegate().maxConnectionAgeGrace(maxConnectionAgeGrace, timeUnit); + return thisT(); + } + + @Override + public T permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + delegate().permitKeepAliveTime(keepAliveTime, timeUnit); + return thisT(); + } + + @Override + public T permitKeepAliveWithoutCalls(boolean permit) { + delegate().permitKeepAliveWithoutCalls(permit); + return thisT(); + } + @Override public T maxInboundMessageSize(int bytes) { delegate().maxInboundMessageSize(bytes); diff --git a/api/src/main/java/io/grpc/GlobalInterceptors.java b/api/src/main/java/io/grpc/GlobalInterceptors.java new file mode 100644 index 00000000000..e5fd86170f0 --- /dev/null +++ b/api/src/main/java/io/grpc/GlobalInterceptors.java @@ -0,0 +1,93 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** The collection of global interceptors and global server stream tracers. */ +@Internal +final class GlobalInterceptors { + private static List clientInterceptors = null; + private static List serverInterceptors = null; + private static List serverStreamTracerFactories = + null; + private static boolean isGlobalInterceptorsTracersSet; + private static boolean isGlobalInterceptorsTracersGet; + + // Prevent instantiation + private GlobalInterceptors() {} + + /** + * Sets the list of global interceptors and global server stream tracers. + * + *

If {@code setInterceptorsTracers()} is called again, this method will throw {@link + * IllegalStateException}. + * + *

It is only safe to call early. This method throws {@link IllegalStateException} after any of + * the get calls [{@link #getClientInterceptors()}, {@link #getServerInterceptors()} or {@link + * #getServerStreamTracerFactories()}] has been called, in order to limit changes to the result of + * {@code setInterceptorsTracers()}. + * + * @param clientInterceptorList list of {@link ClientInterceptor} that make up global Client + * Interceptors. + * @param serverInterceptorList list of {@link ServerInterceptor} that make up global Server + * Interceptors. + * @param serverStreamTracerFactoryList list of {@link ServerStreamTracer.Factory} that make up + * global ServerStreamTracer factories. + */ + static synchronized void setInterceptorsTracers( + List clientInterceptorList, + List serverInterceptorList, + List serverStreamTracerFactoryList) { + if (isGlobalInterceptorsTracersGet) { + throw new IllegalStateException("Set cannot be called after any get call"); + } + if (isGlobalInterceptorsTracersSet) { + throw new IllegalStateException("Global interceptors and tracers are already set"); + } + checkNotNull(clientInterceptorList); + checkNotNull(serverInterceptorList); + checkNotNull(serverStreamTracerFactoryList); + clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList)); + serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList)); + serverStreamTracerFactories = + Collections.unmodifiableList(new ArrayList<>(serverStreamTracerFactoryList)); + isGlobalInterceptorsTracersSet = true; + } + + /** Returns the list of global {@link ClientInterceptor}. If not set, this returns null. */ + static synchronized List getClientInterceptors() { + isGlobalInterceptorsTracersGet = true; + return clientInterceptors; + } + + /** Returns list of global {@link ServerInterceptor}. If not set, this returns null. */ + static synchronized List getServerInterceptors() { + isGlobalInterceptorsTracersGet = true; + return serverInterceptors; + } + + /** Returns list of global {@link ServerStreamTracer.Factory}. If not set, this returns null. */ + static synchronized List getServerStreamTracerFactories() { + isGlobalInterceptorsTracersGet = true; + return serverStreamTracerFactories; + } +} diff --git a/api/src/main/java/io/grpc/Grpc.java b/api/src/main/java/io/grpc/Grpc.java index 7855c7db2a4..baa9f5f0ab6 100644 --- a/api/src/main/java/io/grpc/Grpc.java +++ b/api/src/main/java/io/grpc/Grpc.java @@ -38,7 +38,7 @@ private Grpc() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1710") @TransportAttr public static final Attributes.Key TRANSPORT_ATTR_REMOTE_ADDR = - Attributes.Key.create("remote-addr"); + Attributes.Key.create("io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR"); /** * Attribute key for the local address of a transport. @@ -46,7 +46,7 @@ private Grpc() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1710") @TransportAttr public static final Attributes.Key TRANSPORT_ATTR_LOCAL_ADDR = - Attributes.Key.create("local-addr"); + Attributes.Key.create("io.grpc.Grpc.TRANSPORT_ATTR_LOCAL_ADDR"); /** * Attribute key for SSL session of a transport. @@ -54,7 +54,7 @@ private Grpc() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1710") @TransportAttr public static final Attributes.Key TRANSPORT_ATTR_SSL_SESSION = - Attributes.Key.create("ssl-session"); + Attributes.Key.create("io.grpc.Grpc.TRANSPORT_ATTR_SSL_SESSION"); /** * Annotation for transport attributes. It follows the annotation semantics defined diff --git a/api/src/main/java/io/grpc/InternalConfigSelector.java b/api/src/main/java/io/grpc/InternalConfigSelector.java index fa8ac5c9c60..38856f440b4 100644 --- a/api/src/main/java/io/grpc/InternalConfigSelector.java +++ b/api/src/main/java/io/grpc/InternalConfigSelector.java @@ -32,7 +32,7 @@ public abstract class InternalConfigSelector { @NameResolver.ResolutionResultAttr public static final Attributes.Key KEY - = Attributes.Key.create("io.grpc.config-selector"); + = Attributes.Key.create("internal:io.grpc.config-selector"); // Use PickSubchannelArgs for SelectConfigArgs for now. May change over time. /** Selects the config for an PRC. */ diff --git a/api/src/main/java/io/grpc/InternalGlobalInterceptors.java b/api/src/main/java/io/grpc/InternalGlobalInterceptors.java new file mode 100644 index 00000000000..db0ff6e2ce9 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalGlobalInterceptors.java @@ -0,0 +1,46 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** Accessor to internal methods of {@link GlobalInterceptors}. */ +@Internal +public final class InternalGlobalInterceptors { + + public static void setInterceptorsTracers( + List clientInterceptorList, + List serverInterceptorList, + List serverStreamTracerFactoryList) { + GlobalInterceptors.setInterceptorsTracers( + clientInterceptorList, serverInterceptorList, serverStreamTracerFactoryList); + } + + public static List getClientInterceptors() { + return GlobalInterceptors.getClientInterceptors(); + } + + public static List getServerInterceptors() { + return GlobalInterceptors.getServerInterceptors(); + } + + public static List getServerStreamTracerFactories() { + return GlobalInterceptors.getServerStreamTracerFactories(); + } + + private InternalGlobalInterceptors() {} +} diff --git a/api/src/main/java/io/grpc/InternalManagedChannelProvider.java b/api/src/main/java/io/grpc/InternalManagedChannelProvider.java index 076b5464b7e..2b22e6013ed 100644 --- a/api/src/main/java/io/grpc/InternalManagedChannelProvider.java +++ b/api/src/main/java/io/grpc/InternalManagedChannelProvider.java @@ -25,6 +25,10 @@ public final class InternalManagedChannelProvider { private InternalManagedChannelProvider() { } + public static boolean isAvailable(ManagedChannelProvider provider) { + return provider.isAvailable(); + } + public static ManagedChannelBuilder builderForAddress(ManagedChannelProvider provider, String name, int port) { return provider.builderForAddress(name, port); diff --git a/api/src/main/java/io/grpc/InternalMayRequireSpecificExecutor.java b/api/src/main/java/io/grpc/InternalMayRequireSpecificExecutor.java new file mode 100644 index 00000000000..0be807ce9dc --- /dev/null +++ b/api/src/main/java/io/grpc/InternalMayRequireSpecificExecutor.java @@ -0,0 +1,22 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +@Internal +public interface InternalMayRequireSpecificExecutor { + boolean isSpecificExecutorRequired(); +} diff --git a/api/src/main/java/io/grpc/InternalMetadata.java b/api/src/main/java/io/grpc/InternalMetadata.java index 2823882952f..cbf6b72aaf0 100644 --- a/api/src/main/java/io/grpc/InternalMetadata.java +++ b/api/src/main/java/io/grpc/InternalMetadata.java @@ -19,6 +19,7 @@ import com.google.common.io.BaseEncoding; import io.grpc.Metadata.AsciiMarshaller; import io.grpc.Metadata.BinaryStreamMarshaller; +import java.io.InputStream; import java.nio.charset.Charset; /** @@ -100,7 +101,7 @@ public static Object[] serializePartial(Metadata md) { /** * Creates a holder for a pre-parsed value read by the transport. * - * @param marshaller The {@link Metadata#BinaryStreamMarshaller} associated with this value. + * @param marshaller The {@link Metadata.BinaryStreamMarshaller} associated with this value. * @param value The value to store. * @return an object holding the pre-parsed value for this key. */ diff --git a/api/src/main/java/io/grpc/InternalStatus.java b/api/src/main/java/io/grpc/InternalStatus.java index 9f6854a2de2..b6549bb435f 100644 --- a/api/src/main/java/io/grpc/InternalStatus.java +++ b/api/src/main/java/io/grpc/InternalStatus.java @@ -16,6 +16,8 @@ package io.grpc; +import javax.annotation.Nullable; + /** * Accesses internal data. Do not use this. */ @@ -34,4 +36,14 @@ private InternalStatus() {} */ @Internal public static final Metadata.Key CODE_KEY = Status.CODE_KEY; + + /** + * Create a new {@link StatusRuntimeException} with the internal option of skipping the filling + * of the stack trace. + */ + @Internal + public static final StatusRuntimeException asRuntimeException(Status status, + @Nullable Metadata trailers, boolean fillInStackTrace) { + return new StatusRuntimeException(status, trailers, fillInStackTrace); + } } diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 4a39ce9d40f..6469e33b907 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -114,7 +114,7 @@ public abstract class LoadBalancer { @Internal @NameResolver.ResolutionResultAttr public static final Attributes.Key> ATTR_HEALTH_CHECKING_CONFIG = - Attributes.Key.create("health-checking-config"); + Attributes.Key.create("internal:health-checking-config"); private int recursionCount; /** @@ -124,39 +124,46 @@ public abstract class LoadBalancer { * *

Implementations should not modify the given {@code servers}. * - * @param servers the resolved server addresses, never empty. - * @param attributes extra information from naming system. - * @deprecated override {@link #handleResolvedAddresses(ResolvedAddresses) instead} - * @since 1.2.0 + * @param resolvedAddresses the resolved server addresses, attributes, and config. + * @since 1.21.0 */ - @Deprecated - public void handleResolvedAddressGroups( - List servers, - @NameResolver.ResolutionResultAttr Attributes attributes) { + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (recursionCount++ == 0) { - handleResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attributes).build()); + // Note that the information about the addresses actually being accepted will be lost + // if you rely on this method for backward compatibility. + acceptResolvedAddresses(resolvedAddresses); } recursionCount = 0; } /** - * Handles newly resolved server groups and metadata attributes from name resolution system. - * {@code servers} contained in {@link EquivalentAddressGroup} should be considered equivalent - * but may be flattened into a single list if needed. + * Accepts newly resolved addresses from the name resolution system. The {@link + * EquivalentAddressGroup} addresses should be considered equivalent but may be flattened into a + * single list if needed. * - *

Implementations should not modify the given {@code servers}. + *

Implementations can choose to reject the given addresses by returning {@code false}. + * + *

Implementations should not modify the given {@code addresses}. * * @param resolvedAddresses the resolved server addresses, attributes, and config. - * @since 1.21.0 + * @return {@code true} if the resolved addresses were accepted. {@code false} if rejected. + * @since 1.49.0 */ - @SuppressWarnings("deprecation") - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (recursionCount++ == 0) { - handleResolvedAddressGroups( - resolvedAddresses.getAddresses(), resolvedAddresses.getAttributes()); + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty() + && !canHandleEmptyAddressListFromNameResolution()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } else { + if (recursionCount++ == 0) { + handleResolvedAddresses(resolvedAddresses); + } + recursionCount = 0; + + return true; } - recursionCount = 0; } /** @@ -1073,8 +1080,10 @@ public void refreshNameResolution() { * that need to be updated for the new expected behavior. * * @since 1.38.0 + * @deprecated Warning has been removed */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8088") + @Deprecated public void ignoreRefreshNameResolutionCheck() { // no-op } @@ -1356,13 +1365,19 @@ public interface SubchannelStateListener { * unnecessary delays of RPCs. Please refer to {@link PickResult#withSubchannel * PickResult.withSubchannel()}'s javadoc for more information. * + *

When a subchannel's state is IDLE or TRANSIENT_FAILURE and the address for the subchannel + * was received in {@link LoadBalancer#handleResolvedAddresses}, load balancers should call + * {@link Helper#refreshNameResolution} to inform polling name resolvers that it is an + * appropriate time to refresh the addresses. Without the refresh, changes to the addresses may + * never be detected. + * *

SHUTDOWN can only happen in two cases. One is that LoadBalancer called {@link * Subchannel#shutdown} earlier, thus it should have already discarded this Subchannel. The * other is that Channel is doing a {@link ManagedChannel#shutdownNow forced shutdown} or has * already terminated, thus there won't be further requests to LoadBalancer. Therefore, the * LoadBalancer usually don't need to react to a SHUTDOWN state. - * @param newState the new state * + * @param newState the new state * @since 1.22.0 */ void onSubchannelState(ConnectivityStateInfo newState); diff --git a/api/src/main/java/io/grpc/LoadBalancerRegistry.java b/api/src/main/java/io/grpc/LoadBalancerRegistry.java index a215c4108d1..f6b69f978b8 100644 --- a/api/src/main/java/io/grpc/LoadBalancerRegistry.java +++ b/api/src/main/java/io/grpc/LoadBalancerRegistry.java @@ -107,9 +107,7 @@ public static synchronized LoadBalancerRegistry getDefaultRegistry() { instance = new LoadBalancerRegistry(); for (LoadBalancerProvider provider : providerList) { logger.fine("Service loader found " + provider); - if (provider.isAvailable()) { - instance.addProvider(provider); - } + instance.addProvider(provider); } instance.refreshProviderMap(); } diff --git a/api/src/main/java/io/grpc/ManagedChannelProvider.java b/api/src/main/java/io/grpc/ManagedChannelProvider.java index f57340d9ba9..42941dfc809 100644 --- a/api/src/main/java/io/grpc/ManagedChannelProvider.java +++ b/api/src/main/java/io/grpc/ManagedChannelProvider.java @@ -17,6 +17,8 @@ package io.grpc; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.util.Collection; /** * Provider of managed channels for transport agnostic consumption. @@ -79,6 +81,11 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden return NewChannelBuilderResult.error("ChannelCredentials are unsupported"); } + /** + * Returns the {@link SocketAddress} types this ManagedChannelProvider supports. + */ + protected abstract Collection> getSupportedSocketAddressTypes(); + public static final class NewChannelBuilderResult { private final ManagedChannelBuilder channelBuilder; private final String error; diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 8eb1cce14ac..04bdc6b0d57 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -18,7 +18,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashSet; @@ -101,9 +106,7 @@ public static synchronized ManagedChannelRegistry getDefaultRegistry() { instance = new ManagedChannelRegistry(); for (ManagedChannelProvider provider : providerList) { logger.fine("Service loader found " + provider); - if (provider.isAvailable()) { - instance.addProvider(provider); - } + instance.addProvider(provider); } instance.refreshProviders(); } @@ -140,10 +143,37 @@ static List> getHardCodedClasses() { } catch (ClassNotFoundException e) { logger.log(Level.FINE, "Unable to find NettyChannelProvider", e); } + try { + list.add(Class.forName("io.grpc.netty.UdsNettyChannelProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.FINE, "Unable to find UdsNettyChannelProvider", e); + } return Collections.unmodifiableList(list); } ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials creds) { + return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds); + } + + @VisibleForTesting + ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegistry, + String target, ChannelCredentials creds) { + NameResolverProvider nameResolverProvider = null; + try { + URI uri = new URI(target); + nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme()); + } catch (URISyntaxException ignore) { + // bad URI found, just ignore and continue + } + if (nameResolverProvider == null) { + nameResolverProvider = nameResolverRegistry.providers().get( + nameResolverRegistry.asFactory().getDefaultScheme()); + } + Collection> nameResolverSocketAddressTypes + = (nameResolverProvider != null) + ? nameResolverProvider.getProducedSocketAddressTypes() : + Collections.emptySet(); + List providers = providers(); if (providers.isEmpty()) { throw new ProviderNotFoundException("No functional channel service provider found. " @@ -152,6 +182,15 @@ ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials cre } StringBuilder error = new StringBuilder(); for (ManagedChannelProvider provider : providers()) { + Collection> channelProviderSocketAddressTypes + = provider.getSupportedSocketAddressTypes(); + if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + error.append("; "); + error.append(provider.getClass().getName()); + error.append(": does not support 1 or more of "); + error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray())); + continue; + } ManagedChannelProvider.NewChannelBuilderResult result = provider.newChannelBuilder(target, creds); if (result.getChannelBuilder() != null) { diff --git a/api/src/main/java/io/grpc/Metadata.java b/api/src/main/java/io/grpc/Metadata.java index 9c2a2227f8c..58fcefe1373 100644 --- a/api/src/main/java/io/grpc/Metadata.java +++ b/api/src/main/java/io/grpc/Metadata.java @@ -211,8 +211,8 @@ private int len() { return size * 2; } + /** checks when {@link #namesAndValues} is null or has no elements. */ private boolean isEmpty() { - /** checks when {@link #namesAndValues} is null or has no elements */ return size == 0; } @@ -834,7 +834,7 @@ public String toString() { abstract T parseBytes(byte[] serialized); /** - * @return whether this key will be serialized to bytes lazily. + * Returns whether this key will be serialized to bytes lazily. */ boolean serializesToStreams() { return false; @@ -873,7 +873,7 @@ private BinaryKey(String name, BinaryMarshaller marshaller) { @Override byte[] toBytes(T value) { - return marshaller.toBytes(value); + return Preconditions.checkNotNull(marshaller.toBytes(value), "null marshaller.toBytes()"); } @Override @@ -901,7 +901,7 @@ private LazyStreamBinaryKey(String name, BinaryStreamMarshaller marshaller) { @Override byte[] toBytes(T value) { - return streamToBytes(marshaller.toStream(value)); + return streamToBytes(checkNotNull(marshaller.toStream(value), "null marshaller.toStream()")); } @Override @@ -932,7 +932,7 @@ static LazyValue create(Key key, T value) { } InputStream toStream() { - return marshaller.toStream(value); + return checkNotNull(marshaller.toStream(value), "null marshaller.toStream()"); } byte[] toBytes() { @@ -979,7 +979,9 @@ private AsciiKey(String name, boolean pseudo, AsciiMarshaller marshaller) { @Override byte[] toBytes(T value) { - return marshaller.toAsciiString(value).getBytes(US_ASCII); + String encoded = Preconditions.checkNotNull( + marshaller.toAsciiString(value), "null marshaller.toAsciiString()"); + return encoded.getBytes(US_ASCII); } @Override @@ -1004,7 +1006,8 @@ private TrustedAsciiKey(String name, boolean pseudo, TrustedAsciiMarshaller m @Override byte[] toBytes(T value) { - return marshaller.toAsciiString(value); + return Preconditions.checkNotNull( + marshaller.toAsciiString(value), "null marshaller.toAsciiString()"); } @Override diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index 78b35a14e38..48f835f9943 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -73,7 +73,9 @@ public abstract class NameResolver { public abstract String getServiceAuthority(); /** - * Starts the resolution. + * Starts the resolution. The method is not supposed to throw any exceptions. That might cause the + * Channel that the name resolver is serving to crash. Errors should be propagated + * through {@link Listener#onError}. * * @param listener used to receive updates on the target * @since 1.0.0 @@ -97,7 +99,9 @@ public void onResult(ResolutionResult resolutionResult) { } /** - * Starts the resolution. + * Starts the resolution. The method is not supposed to throw any exceptions. That might cause the + * Channel that the name resolver is serving to crash. Errors should be propagated + * through {@link Listener2#onError}. * * @param listener used to receive updates on the target * @since 1.21.0 @@ -201,6 +205,8 @@ void onAddresses( @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") public abstract static class Listener2 implements Listener { /** + * Handles updates on resolved addresses and attributes. + * * @deprecated This will be removed in 1.22.0 */ @Override @@ -261,6 +267,7 @@ public static final class Args { @Nullable private final ScheduledExecutorService scheduledExecutorService; @Nullable private final ChannelLogger channelLogger; @Nullable private final Executor executor; + @Nullable private final String overrideAuthority; private Args( Integer defaultPort, @@ -269,7 +276,8 @@ private Args( ServiceConfigParser serviceConfigParser, @Nullable ScheduledExecutorService scheduledExecutorService, @Nullable ChannelLogger channelLogger, - @Nullable Executor executor) { + @Nullable Executor executor, + @Nullable String overrideAuthority) { this.defaultPort = checkNotNull(defaultPort, "defaultPort not set"); this.proxyDetector = checkNotNull(proxyDetector, "proxyDetector not set"); this.syncContext = checkNotNull(syncContext, "syncContext not set"); @@ -277,6 +285,7 @@ private Args( this.scheduledExecutorService = scheduledExecutorService; this.channelLogger = channelLogger; this.executor = executor; + this.overrideAuthority = overrideAuthority; } /** @@ -362,6 +371,20 @@ public Executor getOffloadExecutor() { return executor; } + /** + * Returns the overrideAuthority from channel {@link ManagedChannelBuilder#overrideAuthority}. + * Overrides the host name for L7 HTTP virtual host matching. Almost all name resolvers should + * not use this. + * + * @since 1.49.0 + */ + @Nullable + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9406") + public String getOverrideAuthority() { + return overrideAuthority; + } + + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -372,6 +395,7 @@ public String toString() { .add("scheduledExecutorService", scheduledExecutorService) .add("channelLogger", channelLogger) .add("executor", executor) + .add("overrideAuthority", overrideAuthority) .toString(); } @@ -389,6 +413,7 @@ public Builder toBuilder() { builder.setScheduledExecutorService(scheduledExecutorService); builder.setChannelLogger(channelLogger); builder.setOffloadExecutor(executor); + builder.setOverrideAuthority(overrideAuthority); return builder; } @@ -414,6 +439,7 @@ public static final class Builder { private ScheduledExecutorService scheduledExecutorService; private ChannelLogger channelLogger; private Executor executor; + private String overrideAuthority; Builder() { } @@ -490,6 +516,17 @@ public Builder setOffloadExecutor(Executor executor) { return this; } + /** + * See {@link Args#getOverrideAuthority()}. This is an optional field. + * + * @since 1.49.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9406") + public Builder setOverrideAuthority(String authority) { + this.overrideAuthority = authority; + return this; + } + /** * Builds an {@link Args}. * @@ -499,7 +536,7 @@ public Args build() { return new Args( defaultPort, proxyDetector, syncContext, serviceConfigParser, - scheduledExecutorService, channelLogger, executor); + scheduledExecutorService, channelLogger, executor, overrideAuthority); } } } diff --git a/api/src/main/java/io/grpc/NameResolverProvider.java b/api/src/main/java/io/grpc/NameResolverProvider.java index 2c337cd5052..e7cddfc36d0 100644 --- a/api/src/main/java/io/grpc/NameResolverProvider.java +++ b/api/src/main/java/io/grpc/NameResolverProvider.java @@ -17,6 +17,10 @@ package io.grpc; import io.grpc.NameResolver.Factory; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider of name resolvers for name agnostic consumption. @@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory { protected String getScheme() { return getDefaultScheme(); } + + /** + * Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing. + * This enables selection of the appropriate {@link ManagedChannelProvider} for a channel. + * + * @return the {@link SocketAddress} types this provider's name-resolver is capable of producing. + */ + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/api/src/main/java/io/grpc/NameResolverRegistry.java b/api/src/main/java/io/grpc/NameResolverRegistry.java index 2e12bb77483..ab8a1e803eb 100644 --- a/api/src/main/java/io/grpc/NameResolverRegistry.java +++ b/api/src/main/java/io/grpc/NameResolverRegistry.java @@ -124,9 +124,7 @@ public static synchronized NameResolverRegistry getDefaultRegistry() { instance = new NameResolverRegistry(); for (NameResolverProvider provider : providerList) { logger.fine("Service loader found " + provider); - if (provider.isAvailable()) { - instance.addProvider(provider); - } + instance.addProvider(provider); } instance.refreshProviders(); } diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index 6a8954b1a20..e5f0ae62702 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -243,6 +243,117 @@ public T handshakeTimeout(long timeout, TimeUnit unit) { throw new UnsupportedOperationException(); } + /** + * Sets the time without read activity before sending a keepalive ping. An unreasonably small + * value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large + * value will disable keepalive. The typical default is two hours when supported. + * + * @throws IllegalArgumentException if time is not positive + * @throws UnsupportedOperationException if unsupported + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Sets a time waiting for read activity after sending a keepalive ping. If the time expires + * without any read activity on the connection, the connection is considered dead. An unreasonably + * small value might be increased. Defaults to 20 seconds when supported. + * + *

This value should be at least multiple times the RTT to allow for lost packets. + * + * @throws IllegalArgumentException if timeout is not positive + * @throws UnsupportedOperationException if unsupported + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Sets the maximum connection idle time, connections being idle for longer than which will be + * gracefully terminated. Idleness duration is defined since the most recent time the number of + * outstanding RPCs became zero or the connection establishment. An unreasonably small value might + * be increased. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value will disable + * max connection idle. + * + * @throws IllegalArgumentException if idle is not positive + * @throws UnsupportedOperationException if unsupported + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T maxConnectionIdle(long maxConnectionIdle, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Sets the maximum connection age, connections lasting longer than which will be gracefully + * terminated. An unreasonably small value might be increased. A random jitter of +/-10% will be + * added to it. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value will disable + * max connection age. + * + * @throws IllegalArgumentException if age is not positive + * @throws UnsupportedOperationException if unsupported + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Sets the grace time for the graceful connection termination. Once the max connection age + * is reached, RPCs have the grace time to complete. RPCs that do not complete in time will be + * cancelled, allowing the connection to terminate. {@code Long.MAX_VALUE} nano seconds or an + * unreasonably large value are considered infinite. + * + * @throws IllegalArgumentException if grace is negative + * @throws UnsupportedOperationException if unsupported + * @see #maxConnectionAge(long, TimeUnit) + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Specify the most aggressive keep-alive time clients are permitted to configure. The server will + * try to detect clients exceeding this rate and when detected will forcefully close the + * connection. The typical default is 5 minutes when supported. + * + *

Even though a default is defined that allows some keep-alives, clients must not use + * keep-alive without approval from the service owner. Otherwise, they may experience failures in + * the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a + * significant amount of traffic and CPU usage, so clients and servers should be conservative in + * what they use and accept. + * + * @throws IllegalArgumentException if time is negative + * @throws UnsupportedOperationException if unsupported + * @see #permitKeepAliveWithoutCalls(boolean) + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + /** + * Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding + * RPCs on the connection. Defaults to {@code false} when supported. + * + * @throws UnsupportedOperationException if unsupported + * @see #permitKeepAliveTime(long, TimeUnit) + * @since 1.47.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9009") + public T permitKeepAliveWithoutCalls(boolean permit) { + throw new UnsupportedOperationException(); + } + /** * Sets the maximum message size allowed to be received on the server. If not called, * defaults to 4 MiB. The default provides protection to servers who haven't considered the diff --git a/api/src/main/java/io/grpc/ServerCall.java b/api/src/main/java/io/grpc/ServerCall.java index d391cb5c79a..40bcd2f3718 100644 --- a/api/src/main/java/io/grpc/ServerCall.java +++ b/api/src/main/java/io/grpc/ServerCall.java @@ -211,6 +211,21 @@ public void setCompression(String compressor) { // noop } + /** + * Returns the level of security guarantee in communications + * + *

Determining the level of security offered by the transport for RPCs on server-side. + * This can be approximated by looking for the SSLSession, but that doesn't work for ALTS and + * maybe some future TLS approaches. May return a lower security level when it cannot be + * determined precisely. + * + * @return non-{@code null} SecurityLevel enum + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4692") + public SecurityLevel getSecurityLevel() { + return SecurityLevel.NONE; + } + /** * Returns properties of a single call. * diff --git a/api/src/main/java/io/grpc/ServerRegistry.java b/api/src/main/java/io/grpc/ServerRegistry.java index e40039fb34c..e6a067ce87f 100644 --- a/api/src/main/java/io/grpc/ServerRegistry.java +++ b/api/src/main/java/io/grpc/ServerRegistry.java @@ -98,9 +98,7 @@ public static synchronized ServerRegistry getDefaultRegistry() { instance = new ServerRegistry(); for (ServerProvider provider : providerList) { logger.fine("Service loader found " + provider); - if (provider.isAvailable()) { - instance.addProvider(provider); - } + instance.addProvider(provider); } instance.refreshProviders(); } diff --git a/api/src/main/java/io/grpc/ServiceProviders.java b/api/src/main/java/io/grpc/ServiceProviders.java index 08837bcf5cc..ac4b27d8783 100644 --- a/api/src/main/java/io/grpc/ServiceProviders.java +++ b/api/src/main/java/io/grpc/ServiceProviders.java @@ -122,15 +122,26 @@ public static Iterable getCandidatesViaServiceLoader(Class klass, Clas static Iterable getCandidatesViaHardCoded(Class klass, Iterable> hardcoded) { List list = new ArrayList<>(); for (Class candidate : hardcoded) { - list.add(create(klass, candidate)); + T t = createForHardCoded(klass, candidate); + if (t == null) { + continue; + } + list.add(t); } return list; } - @VisibleForTesting - static T create(Class klass, Class rawClass) { + private static T createForHardCoded(Class klass, Class rawClass) { try { return rawClass.asSubclass(klass).getConstructor().newInstance(); + } catch (ClassCastException ex) { + // Tools like Proguard that perform obfuscation rewrite strings only when the class they + // reference is known, as otherwise they wouldn't know its new name. This means some + // hard-coded Class.forNames() won't be rewritten. This can cause ClassCastException at + // runtime if the class ends up appearing on the classpath but that class is part of a + // separate copy of grpc. With tools like Maven Shade Plugin the class wouldn't be found at + // all and so would be skipped. We want to skip in this case as well. + return null; } catch (Throwable t) { throw new ServiceConfigurationError( String.format("Provider %s could not be instantiated %s", rawClass.getName(), t), t); diff --git a/api/src/main/java/io/grpc/Status.java b/api/src/main/java/io/grpc/Status.java index 1ad5abc0539..7382cd03ee1 100644 --- a/api/src/main/java/io/grpc/Status.java +++ b/api/src/main/java/io/grpc/Status.java @@ -50,6 +50,10 @@ * *

Utility functions are provided to convert a status to an exception and to extract them * back out. + * + *

Extended descriptions, including a list of codes that should not be generated by the library, + * can be found at + * doc/statuscodes.md */ @Immutable @CheckReturnValue @@ -595,6 +599,8 @@ private static boolean isEscapingChar(byte b) { } /** + * Percent encode bytes to make them ASCII. + * * @param valueBytes the UTF-8 bytes * @param ri The reader index, pointed at the first byte that needs escaping. */ diff --git a/api/src/main/java/io/grpc/StatusRuntimeException.java b/api/src/main/java/io/grpc/StatusRuntimeException.java index e3e0555fb09..68b816fc7fa 100644 --- a/api/src/main/java/io/grpc/StatusRuntimeException.java +++ b/api/src/main/java/io/grpc/StatusRuntimeException.java @@ -32,7 +32,7 @@ public class StatusRuntimeException extends RuntimeException { private final boolean fillInStackTrace; /** - * Constructs the exception with both a status. See also {@link Status#asException()}. + * Constructs the exception with both a status. See also {@link Status#asRuntimeException()}. * * @since 1.0.0 */ @@ -41,8 +41,8 @@ public StatusRuntimeException(Status status) { } /** - * Constructs the exception with both a status and trailers. See also - * {@link Status#asException(Metadata)}. + * Constructs the exception with both a status and trailers. See also {@link + * Status#asRuntimeException(Metadata)}. * * @since 1.0.0 */ diff --git a/api/src/main/java/io/grpc/SynchronizationContext.java b/api/src/main/java/io/grpc/SynchronizationContext.java index 03d26b55f0a..fe4243ec227 100644 --- a/api/src/main/java/io/grpc/SynchronizationContext.java +++ b/api/src/main/java/io/grpc/SynchronizationContext.java @@ -163,6 +163,38 @@ public String toString() { return new ScheduledHandle(runnable, future); } + /** + * Schedules a task to be added and run via {@link #execute} after an inital delay and then + * repeated after the delay until cancelled. + * + * @param task the task being scheduled + * @param initialDelay the delay before the first run + * @param delay the delay after the first run. + * @param unit the time unit for the delay + * @param timerService the {@code ScheduledExecutorService} that provides delayed execution + * + * @return an object for checking the status and/or cancel the scheduled task + */ + public final ScheduledHandle scheduleWithFixedDelay( + final Runnable task, long initialDelay, long delay, TimeUnit unit, + ScheduledExecutorService timerService) { + final ManagedRunnable runnable = new ManagedRunnable(task); + ScheduledFuture future = timerService.scheduleWithFixedDelay(new Runnable() { + @Override + public void run() { + execute(runnable); + } + + @Override + public String toString() { + return task.toString() + "(scheduled in SynchronizationContext with delay of " + delay + + ")"; + } + }, initialDelay, delay, unit); + return new ScheduledHandle(runnable, future); + } + + private static class ManagedRunnable implements Runnable { final Runnable task; boolean isCancelled; diff --git a/api/src/test/java/io/grpc/GlobalInterceptorsTest.java b/api/src/test/java/io/grpc/GlobalInterceptorsTest.java new file mode 100644 index 00000000000..7315186f1ee --- /dev/null +++ b/api/src/test/java/io/grpc/GlobalInterceptorsTest.java @@ -0,0 +1,188 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GlobalInterceptorsTest { + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader( + getClass().getClassLoader(), Pattern.compile("io\\.grpc\\.[^.]+")); + + @Test + public void setInterceptorsTracers() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderSet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void setGlobalInterceptorsTracers_twice() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderSetTwice.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void getBeforeSet_clientInterceptors() throws Exception { + Class runnable = + classLoader.loadClass( + StaticTestingClassLoaderGetBeforeSetClientInterceptor.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void getBeforeSet_serverInterceptors() throws Exception { + Class runnable = + classLoader.loadClass( + StaticTestingClassLoaderGetBeforeSetServerInterceptor.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void getBeforeSet_serverStreamTracerFactories() throws Exception { + Class runnable = + classLoader.loadClass( + StaticTestingClassLoaderGetBeforeSetServerStreamTracerFactory.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderSet implements Runnable { + @Override + public void run() { + List clientInterceptorList = + new ArrayList<>(Arrays.asList(new NoopClientInterceptor())); + List serverInterceptorList = + new ArrayList<>(Arrays.asList(new NoopServerInterceptor())); + List serverStreamTracerFactoryList = + new ArrayList<>( + Arrays.asList( + new NoopServerStreamTracerFactory(), new NoopServerStreamTracerFactory())); + + GlobalInterceptors.setInterceptorsTracers( + clientInterceptorList, serverInterceptorList, serverStreamTracerFactoryList); + + assertThat(GlobalInterceptors.getClientInterceptors()).isEqualTo(clientInterceptorList); + assertThat(GlobalInterceptors.getServerInterceptors()).isEqualTo(serverInterceptorList); + assertThat(GlobalInterceptors.getServerStreamTracerFactories()) + .isEqualTo(serverStreamTracerFactoryList); + } + } + + public static final class StaticTestingClassLoaderSetTwice implements Runnable { + @Override + public void run() { + GlobalInterceptors.setInterceptorsTracers( + new ArrayList<>(Arrays.asList(new NoopClientInterceptor())), + Collections.emptyList(), + new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory()))); + try { + GlobalInterceptors.setInterceptorsTracers( + null, new ArrayList<>(Arrays.asList(new NoopServerInterceptor())), null); + fail("should have failed for calling setGlobalInterceptorsTracers() again"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Global interceptors and tracers are already set"); + } + } + } + + public static final class StaticTestingClassLoaderGetBeforeSetClientInterceptor + implements Runnable { + @Override + public void run() { + List clientInterceptors = GlobalInterceptors.getClientInterceptors(); + assertThat(clientInterceptors).isNull(); + + try { + GlobalInterceptors.setInterceptorsTracers( + new ArrayList<>(Arrays.asList(new NoopClientInterceptor())), null, null); + fail("should have failed for invoking set call after get is already called"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); + } + } + } + + public static final class StaticTestingClassLoaderGetBeforeSetServerInterceptor + implements Runnable { + @Override + public void run() { + List serverInterceptors = GlobalInterceptors.getServerInterceptors(); + assertThat(serverInterceptors).isNull(); + + try { + GlobalInterceptors.setInterceptorsTracers( + null, new ArrayList<>(Arrays.asList(new NoopServerInterceptor())), null); + fail("should have failed for invoking set call after get is already called"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); + } + } + } + + public static final class StaticTestingClassLoaderGetBeforeSetServerStreamTracerFactory + implements Runnable { + @Override + public void run() { + List serverStreamTracerFactories = + GlobalInterceptors.getServerStreamTracerFactories(); + assertThat(serverStreamTracerFactories).isNull(); + + try { + GlobalInterceptors.setInterceptorsTracers( + null, null, new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory()))); + fail("should have failed for invoking set call after get is already called"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); + } + } + } + + private static class NoopClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + } + + private static class NoopServerInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + } + + private static class NoopServerStreamTracerFactory extends ServerStreamTracer.Factory { + @Override + public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java index e8588e5e8d8..3debc871121 100644 --- a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java @@ -41,7 +41,7 @@ public void getClassesViaHardcoded_classesPresent() throws Exception { @Test public void stockProviders() { LoadBalancerRegistry defaultRegistry = LoadBalancerRegistry.getDefaultRegistry(); - assertThat(defaultRegistry.providers()).hasSize(3); + assertThat(defaultRegistry.providers()).hasSize(4); LoadBalancerProvider pickFirst = defaultRegistry.getProvider("pick_first"); assertThat(pickFirst).isInstanceOf(PickFirstLoadBalancerProvider.class); @@ -52,6 +52,12 @@ public void stockProviders() { "io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider"); assertThat(roundRobin.getPriority()).isEqualTo(5); + LoadBalancerProvider outlierDetection = defaultRegistry.getProvider( + "outlier_detection_experimental"); + assertThat(outlierDetection.getClass().getName()).isEqualTo( + "io.grpc.util.OutlierDetectionLoadBalancerProvider"); + assertThat(roundRobin.getPriority()).isEqualTo(5); + LoadBalancerProvider grpclb = defaultRegistry.getProvider("grpclb"); assertThat(grpclb).isInstanceOf(GrpclbLoadBalancerProvider.class); assertThat(grpclb.getPriority()).isEqualTo(5); diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index be3d10ba2ae..beaf3335e2c 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -235,13 +235,14 @@ public void createSubchannelArgs_toString() { @Deprecated @Test - public void handleResolvedAddressGroups_delegatesToHandleResolvedAddresses() { + public void handleResolvedAddresses_delegatesToAcceptResolvedAddresses() { final AtomicReference resultCapture = new AtomicReference<>(); LoadBalancer balancer = new LoadBalancer() { @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { resultCapture.set(resolvedAddresses); + return true; } @Override @@ -260,23 +261,22 @@ public void shutdown() { List servers = Arrays.asList( new EquivalentAddressGroup(new SocketAddress(){}), new EquivalentAddressGroup(new SocketAddress(){})); - balancer.handleResolvedAddressGroups(servers, attrs); + ResolvedAddresses addresses = ResolvedAddresses.newBuilder().setAddresses(servers) + .setAttributes(attrs).build(); + balancer.handleResolvedAddresses(addresses); assertThat(resultCapture.get()).isEqualTo( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attrs).build()); } @Deprecated @Test - public void handleResolvedAddresses_delegatesToHandleResolvedAddressGroups() { - final AtomicReference> serversCapture = new AtomicReference<>(); - final AtomicReference attrsCapture = new AtomicReference<>(); + public void acceptResolvedAddresses_delegatesToHandleResolvedAddressGroups() { + final AtomicReference addressesCapture = new AtomicReference<>(); LoadBalancer balancer = new LoadBalancer() { @Override - public void handleResolvedAddressGroups( - List servers, Attributes attrs) { - serversCapture.set(servers); - attrsCapture.set(attrs); + public void handleResolvedAddresses(ResolvedAddresses addresses) { + addressesCapture.set(addresses); } @Override @@ -295,25 +295,23 @@ public void shutdown() { List servers = Arrays.asList( new EquivalentAddressGroup(new SocketAddress(){}), new EquivalentAddressGroup(new SocketAddress(){})); - balancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attrs).build()); - assertThat(serversCapture.get()).isEqualTo(servers); - assertThat(attrsCapture.get()).isEqualTo(attrs); + ResolvedAddresses addresses = ResolvedAddresses.newBuilder().setAddresses(servers) + .setAttributes(attrs).build(); + balancer.handleResolvedAddresses(addresses); + assertThat(addressesCapture.get().getAddresses()).isEqualTo(servers); + assertThat(addressesCapture.get().getAttributes()).isEqualTo(attrs); } @Deprecated @Test - public void handleResolvedAddresses_noInfiniteLoop() { - final List> serversCapture = new ArrayList<>(); - final List attrsCapture = new ArrayList<>(); + public void acceptResolvedAddresses_noInfiniteLoop() { + final List addressesCapture = new ArrayList<>(); LoadBalancer balancer = new LoadBalancer() { @Override - public void handleResolvedAddressGroups( - List servers, Attributes attrs) { - serversCapture.add(servers); - attrsCapture.add(attrs); - super.handleResolvedAddressGroups(servers, attrs); + public void handleResolvedAddresses(ResolvedAddresses addresses) { + addressesCapture.add(addresses); + super.handleResolvedAddresses(addresses); } @Override @@ -328,12 +326,12 @@ public void shutdown() { List servers = Arrays.asList( new EquivalentAddressGroup(new SocketAddress(){}), new EquivalentAddressGroup(new SocketAddress(){})); - balancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(attrs).build()); - assertThat(serversCapture).hasSize(1); - assertThat(attrsCapture).hasSize(1); - assertThat(serversCapture.get(0)).isEqualTo(servers); - assertThat(attrsCapture.get(0)).isEqualTo(attrs); + ResolvedAddresses addresses = ResolvedAddresses.newBuilder().setAddresses(servers) + .setAttributes(attrs).build(); + balancer.handleResolvedAddresses(addresses); + assertThat(addressesCapture).hasSize(1); + assertThat(addressesCapture.get(0).getAddresses()).isEqualTo(servers); + assertThat(addressesCapture.get(0).getAttributes()).isEqualTo(attrs); } private static class NoopHelper extends LoadBalancer.Helper { diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 6f25f620576..b968fbd7ec2 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -19,6 +19,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.common.collect.ImmutableSet; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -156,6 +162,251 @@ public void newChannelBuilder_noProvider() { } } + @Test + public void newChannelBuilder_usesScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") { + @Override + protected Collection> getProducedSocketAddressTypes() { + fail("Should not be called"); + throw new AssertionError(); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_unsupportedSocketAddressTypes() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + try { + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds); + fail("expected exception"); + } catch (ManagedChannelRegistry.ProviderNotFoundException ex) { + assertThat(ex).hasMessageThat().contains("does not support 1 or more of"); + assertThat(ex).hasMessageThat().contains("SocketAddress1"); + assertThat(ex).hasMessageThat().contains("SocketAddress2"); + } + } + + @Test + public void newChannelBuilder_emptySet_asDefault() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.emptySet(); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_noSchemeUsesDefaultScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_badUri() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs( + mcb); + } + + private static class BaseNameResolverProvider extends NameResolverProvider { + private final boolean isAvailable; + private final int priority; + private final String defaultScheme; + + public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) { + this.isAvailable = isAvailable; + this.priority = priority; + this.defaultScheme = defaultScheme; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return defaultScheme; + } + + @Override + protected boolean isAvailable() { + return isAvailable; + } + + @Override + protected int priority() { + return priority; + } + } + private static class BaseProvider extends ManagedChannelProvider { private final boolean isAvailable; private final int priority; @@ -184,5 +435,10 @@ protected ManagedChannelBuilder builderForAddress(String name, int port) { protected ManagedChannelBuilder builderForTarget(String target) { throw new UnsupportedOperationException(); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/api/src/test/java/io/grpc/NameResolverTest.java b/api/src/test/java/io/grpc/NameResolverTest.java index 0b5b198713b..f825de354af 100644 --- a/api/src/test/java/io/grpc/NameResolverTest.java +++ b/api/src/test/java/io/grpc/NameResolverTest.java @@ -40,6 +40,7 @@ public class NameResolverTest { mock(ScheduledExecutorService.class); private final ChannelLogger channelLogger = mock(ChannelLogger.class); private final Executor executor = Executors.newSingleThreadExecutor(); + private final String overrideAuthority = "grpc.io"; @Test public void args() { @@ -51,6 +52,7 @@ public void args() { assertThat(args.getScheduledExecutorService()).isSameInstanceAs(scheduledExecutorService); assertThat(args.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args.getOffloadExecutor()).isSameInstanceAs(executor); + assertThat(args.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); NameResolver.Args args2 = args.toBuilder().build(); assertThat(args2.getDefaultPort()).isEqualTo(defaultPort); @@ -60,6 +62,7 @@ public void args() { assertThat(args2.getScheduledExecutorService()).isSameInstanceAs(scheduledExecutorService); assertThat(args2.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args2.getOffloadExecutor()).isSameInstanceAs(executor); + assertThat(args2.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); assertThat(args2).isNotSameInstanceAs(args); assertThat(args2).isNotEqualTo(args); @@ -74,6 +77,7 @@ private NameResolver.Args createArgs() { .setScheduledExecutorService(scheduledExecutorService) .setChannelLogger(channelLogger) .setOffloadExecutor(executor) + .setOverrideAuthority(overrideAuthority) .build(); } } diff --git a/api/src/test/java/io/grpc/ServiceProvidersTest.java b/api/src/test/java/io/grpc/ServiceProvidersTest.java index 4bbfe2117be..7d4388a5bb9 100644 --- a/api/src/test/java/io/grpc/ServiceProvidersTest.java +++ b/api/src/test/java/io/grpc/ServiceProvidersTest.java @@ -215,19 +215,35 @@ public void getCandidatesViaHardCoded_failAtInit_moreCandidates() throws Excepti } @Test - public void create_throwsErrorOnMisconfiguration() throws Exception { - class PrivateClass {} + public void getCandidatesViaHardCoded_throwsErrorOnMisconfiguration() throws Exception { + class PrivateClass extends BaseProvider { + private PrivateClass() { + super(true, 5); + } + } try { - ServiceProviders.create( - ServiceProvidersTestAbstractProvider.class, PrivateClass.class); + ServiceProviders.getCandidatesViaHardCoded( + ServiceProvidersTestAbstractProvider.class, + Collections.>singletonList(PrivateClass.class)); fail("Expected exception"); } catch (ServiceConfigurationError expected) { - assertTrue("Expected ClassCastException cause: " + expected.getCause(), - expected.getCause() instanceof ClassCastException); + assertTrue("Expected NoSuchMethodException cause: " + expected.getCause(), + expected.getCause() instanceof NoSuchMethodException); } } + @Test + public void getCandidatesViaHardCoded_skipsWrongClassType() throws Exception { + class RandomClass {} + + Iterable candidates = + ServiceProviders.getCandidatesViaHardCoded( + ServiceProvidersTestAbstractProvider.class, + Collections.>singletonList(RandomClass.class)); + assertFalse(candidates.iterator().hasNext()); + } + private static class BaseProvider extends ServiceProvidersTestAbstractProvider { private final boolean isAvailable; private final int priority; diff --git a/auth/build.gradle b/auth/build.gradle index 233de359b49..73c108d5203 100644 --- a/auth/build.gradle +++ b/auth/build.gradle @@ -9,10 +9,10 @@ plugins { description = "gRPC: Auth" dependencies { api project(':grpc-api'), - libraries.google_auth_credentials + libraries.google.auth.credentials implementation libraries.guava testImplementation project(':grpc-testing'), - libraries.google_auth_oauth2_http - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + libraries.google.auth.oauth2Http + signature libraries.signature.java + signature libraries.signature.android } diff --git a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java index 4b95a6c7f4d..2a414e792d9 100644 --- a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java +++ b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java @@ -22,6 +22,7 @@ import com.google.auth.RequestMetadataCallback; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.BaseEncoding; +import io.grpc.InternalMayRequireSpecificExecutor; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -44,13 +45,16 @@ /** * Wraps {@link Credentials} as a {@link io.grpc.CallCredentials}. */ -final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials { +final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials + implements InternalMayRequireSpecificExecutor { private static final Logger log = Logger.getLogger(GoogleAuthLibraryCallCredentials.class.getName()); private static final JwtHelper jwtHelper = createJwtHelperOrNull(GoogleAuthLibraryCallCredentials.class.getClassLoader()); - private static final Class googleCredentialsClass + private static final Class GOOGLE_CREDENTIALS_CLASS = loadGoogleCredentialsClass(); + private static final Class APP_ENGINE_CREDENTIALS_CLASS + = loadAppEngineCredentials(); private final boolean requirePrivacy; @VisibleForTesting @@ -59,6 +63,8 @@ final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials { private Metadata lastHeaders; private Map> lastMetadata; + private Boolean requiresSpecificExecutor; + public GoogleAuthLibraryCallCredentials(Credentials creds) { this(creds, jwtHelper); } @@ -67,12 +73,12 @@ public GoogleAuthLibraryCallCredentials(Credentials creds) { GoogleAuthLibraryCallCredentials(Credentials creds, JwtHelper jwtHelper) { checkNotNull(creds, "creds"); boolean requirePrivacy = false; - if (googleCredentialsClass != null) { + if (GOOGLE_CREDENTIALS_CLASS != null) { // All GoogleCredentials instances are bearer tokens and should only be used on private // channels. This catches all return values from GoogleCredentials.getApplicationDefault(). // This should be checked before upgrading the Service Account to JWT, as JWT is also a bearer // token. - requirePrivacy = googleCredentialsClass.isInstance(creds); + requirePrivacy = GOOGLE_CREDENTIALS_CLASS.isInstance(creds); } if (jwtHelper != null) { creds = jwtHelper.tryServiceAccountToJwt(creds); @@ -242,6 +248,16 @@ private static Class loadGoogleCredentialsClass() { return rawGoogleCredentialsClass.asSubclass(Credentials.class); } + @Nullable + private static Class loadAppEngineCredentials() { + try { + return Class.forName("com.google.auth.appengine.AppEngineCredentials"); + } catch (ClassNotFoundException ex) { + log.log(Level.FINE, "AppEngineCredentials not available in classloader", ex); + return null; + } + } + private static class MethodPair { private final Method getter; private final Method builderSetter; @@ -298,6 +314,11 @@ public JwtHelper(Class rawServiceAccountClass, ClassLoader loader) Method setter = builderClass.getMethod("setPrivateKeyId", getter.getReturnType()); methodPairs.add(new MethodPair(getter, setter)); } + { + Method getter = serviceAccountClass.getMethod("getQuotaProjectId"); + Method setter = builderClass.getMethod("setQuotaProjectId", getter.getReturnType()); + methodPairs.add(new MethodPair(getter, setter)); + } } /** @@ -348,4 +369,24 @@ public Credentials tryServiceAccountToJwt(Credentials creds) { return creds; } } + + /** + * This method is to support the hack for AppEngineCredentials which need to run on a + * specific thread. + * @return Whether a specific executor is needed or if any executor can be used + */ + @Override + public boolean isSpecificExecutorRequired() { + // Cache the value so we only need to try to load the class once + if (requiresSpecificExecutor == null) { + if (APP_ENGINE_CREDENTIALS_CLASS == null) { + requiresSpecificExecutor = Boolean.FALSE; + } else { + requiresSpecificExecutor = APP_ENGINE_CREDENTIALS_CLASS.isInstance(creds); + } + } + + return requiresSpecificExecutor; + } + } diff --git a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java index ee5713bfd27..cbb2afdc3d0 100644 --- a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java +++ b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java @@ -379,6 +379,7 @@ public void jwtAccessCredentialsInRequestMetadata() throws Exception { .setClientEmail("test-email@example.com") .setPrivateKey(pair.getPrivate()) .setPrivateKeyId("test-private-key-id") + .setQuotaProjectId("test-quota-project-id") .build(); GoogleAuthLibraryCallCredentials callCredentials = new GoogleAuthLibraryCallCredentials(credentials); @@ -401,6 +402,10 @@ public void jwtAccessCredentialsInRequestMetadata() throws Exception { || "https://example.com:123/a.service".equals(payload.get("aud"))); assertEquals("test-email@example.com", payload.get("iss")); assertEquals("test-email@example.com", payload.get("sub")); + + Metadata.Key quotaProject = Metadata.Key + .of("X-Goog-User-Project", Metadata.ASCII_STRING_MARSHALLER); + assertEquals("test-quota-project-id", Iterables.getOnlyElement(headers.getAll(quotaProject))); } private int runPendingRunnables() { diff --git a/authz/build.gradle b/authz/build.gradle index f6110a5850c..50084752d0e 100644 --- a/authz/build.gradle +++ b/authz/build.gradle @@ -13,12 +13,12 @@ dependencies { implementation project(':grpc-protobuf'), project(':grpc-core') - annotationProcessor libraries.autovalue - compileOnly libraries.javax_annotation + annotationProcessor libraries.auto.value + compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), project(':grpc-testing-proto') - testImplementation (libraries.guava_testlib) { + testImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } @@ -26,20 +26,14 @@ dependencies { shadow configurations.implementation.getDependencies().minus([xdsDependency]) shadow project(path: ':grpc-xds', configuration: 'shadow') - signature "org.codehaus.mojo.signature:java17:1.0@signature" + signature libraries.signature.java } -jar { +tasks.named("jar").configure { classifier = 'original' } -// TODO(ashithasantosh): Remove javadoc exclusion on adding authorization -// interceptor implementations. -javadoc { - exclude "io/grpc/authz/*" -} - -shadowJar { +tasks.named("shadowJar").configure { classifier = null dependencies { exclude(dependency {true}) @@ -52,6 +46,12 @@ shadowJar { relocate 'com.google.api.expr', 'io.grpc.xds.shaded.com.google.api.expr' } +tasks.named("compileJava").configure { + it.options.compilerArgs += [ + "-Xlint:-processing", + ] +} + publishing { publications { maven(MavenPublication) { @@ -74,4 +74,6 @@ publishing { } } -[publishMavenPublicationToMavenRepository]*.onlyIf {false} +tasks.named("publishMavenPublicationToMavenRepository").configure { + enabled = false +} diff --git a/authz/src/main/java/io/grpc/authz/AuthorizationServerInterceptor.java b/authz/src/main/java/io/grpc/authz/AuthorizationServerInterceptor.java new file mode 100644 index 00000000000..2e398666093 --- /dev/null +++ b/authz/src/main/java/io/grpc/authz/AuthorizationServerInterceptor.java @@ -0,0 +1,74 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.authz; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.envoyproxy.envoy.config.rbac.v3.RBAC; +import io.grpc.ExperimentalApi; +import io.grpc.InternalServerInterceptors; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.xds.InternalRbacFilter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Authorization server interceptor for static policy. The class will get + * + * gRPC Authorization policy as a JSON string during initialization. + * This policy will be translated to Envoy RBAC policies to make + * authorization decisions. The policy cannot be changed once created. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9746") +public final class AuthorizationServerInterceptor implements ServerInterceptor { + private final List interceptors = new ArrayList<>(); + + private AuthorizationServerInterceptor(String authorizationPolicy) + throws IOException { + List rbacs = AuthorizationPolicyTranslator.translate(authorizationPolicy); + if (rbacs == null || rbacs.isEmpty() || rbacs.size() > 2) { + throw new IllegalArgumentException("Failed to translate authorization policy"); + } + for (RBAC rbac: rbacs) { + interceptors.add( + InternalRbacFilter.createInterceptor( + io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() + .setRules(rbac).build())); + } + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, + ServerCallHandler next) { + for (ServerInterceptor interceptor: interceptors) { + next = InternalServerInterceptors.interceptCallHandlerCreate(interceptor, next); + } + return next.startCall(call, headers); + } + + // Static method that creates an AuthorizationServerInterceptor. + public static AuthorizationServerInterceptor create(String authorizationPolicy) + throws IOException { + checkNotNull(authorizationPolicy, "authorizationPolicy"); + return new AuthorizationServerInterceptor(authorizationPolicy); + } +} diff --git a/authz/src/test/java/io/grpc/authz/AuthorizationEnd2EndTest.java b/authz/src/test/java/io/grpc/authz/AuthorizationEnd2EndTest.java new file mode 100644 index 00000000000..28c17718d11 --- /dev/null +++ b/authz/src/test/java/io/grpc/authz/AuthorizationEnd2EndTest.java @@ -0,0 +1,374 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.authz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.TlsServerCredentials.ClientAuth; +import io.grpc.internal.testing.TestUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.File; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AuthorizationEnd2EndTest { + public static final String SERVER_0_KEY_FILE = "server0.key"; + public static final String SERVER_0_PEM_FILE = "server0.pem"; + public static final String CLIENT_0_KEY_FILE = "client.key"; + public static final String CLIENT_0_PEM_FILE = "client.pem"; + public static final String CA_PEM_FILE = "ca.pem"; + + private Server server; + private ManagedChannel channel; + + private void initServerWithStaticAuthz( + String authorizationPolicy, ServerCredentials serverCredentials) throws Exception { + AuthorizationServerInterceptor authzInterceptor = + AuthorizationServerInterceptor.create(authorizationPolicy); + server = Grpc.newServerBuilderForPort(0, serverCredentials) + .addService(new SimpleServiceImpl()) + .intercept(authzInterceptor) + .build() + .start(); + } + + private SimpleServiceGrpc.SimpleServiceBlockingStub getStub() { + channel = + Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), InsecureChannelCredentials.create()) + .build(); + return SimpleServiceGrpc.newBlockingStub(channel); + } + + private SimpleServiceGrpc.SimpleServiceBlockingStub getStub( + ChannelCredentials channelCredentials) { + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au") + .build(); + return SimpleServiceGrpc.newBlockingStub(channel); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdown(); + } + if (channel != null) { + channel.shutdown(); + } + } + + @Test + public void staticAuthzAllowsRpcNoMatchInDenyMatchInAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"deny_rules\": [" + + " {" + + " \"name\": \"deny_UnaryRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/UnaryRpc\"" + + " ]," + + " \"headers\": [" + + " {" + + " \"key\": \"dev-path\"," + + " \"values\": [\"/dev/path/*\"]" + + " }" + + " ]" + + " }" + + " }" + + " ]," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_all\"" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void staticAuthzDeniesRpcNoMatchInDenyAndAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"deny_rules\": [" + + " {" + + " \"name\": \"deny_foo\"," + + " \"source\": {" + + " \"principals\": [" + + " \"foo\"" + + " ]" + + " }" + + " }" + + " ]," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_ClientStreamingRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/ClientStreamingRpc\"" + + " ]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + try { + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasMessageThat().isEqualTo( + "PERMISSION_DENIED: Access Denied"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void staticAuthzDeniesRpcMatchInDenyAndAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"deny_rules\": [" + + " {" + + " \"name\": \"deny_UnaryRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/UnaryRpc\"" + + " ]" + + " }" + + " }" + + " ]," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_UnaryRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/UnaryRpc\"" + + " ]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + try { + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasMessageThat().isEqualTo( + "PERMISSION_DENIED: Access Denied"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void staticAuthzDeniesRpcMatchInDenyNoMatchInAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"deny_rules\": [" + + " {" + + " \"name\": \"deny_UnaryRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/UnaryRpc\"" + + " ]" + + " }" + + " }" + + " ]," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_ClientStreamingRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/ClientStreamingRpc\"" + + " ]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + try { + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasMessageThat().isEqualTo( + "PERMISSION_DENIED: Access Denied"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void staticAuthzAllowsRpcEmptyDenyMatchInAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_UnaryRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/UnaryRpc\"" + + " ]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void staticAuthzDeniesRpcEmptyDenyNoMatchInAllowTest() throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_ClientStreamingRpc\"," + + " \"request\": {" + + " \"paths\": [" + + " \"*/ClientStreamingRpc\"" + + " ]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + try { + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasMessageThat().isEqualTo( + "PERMISSION_DENIED: Access Denied"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void staticAuthzDeniesRpcWithPrincipalsFieldOnUnauthenticatedConnectionTest() + throws Exception { + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_authenticated\"," + + " \"source\": {" + + " \"principals\": [\"*\", \"\"]" + + " }" + + " }" + + " ]" + + "}"; + initServerWithStaticAuthz(policy, InsecureServerCredentials.create()); + try { + getStub().unaryRpc(SimpleRequest.getDefaultInstance()); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre).hasMessageThat().isEqualTo( + "PERMISSION_DENIED: Access Denied"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void staticAuthzAllowsRpcWithPrincipalsFieldOnMtlsAuthenticatedConnectionTest() + throws Exception { + File caCertFile = TestUtils.loadCert(CA_PEM_FILE); + File serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + File serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + File clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE); + File clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE); + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_mtls\"," + + " \"source\": {" + + " \"principals\": [\"*\"]" + + " }" + + " }" + + " ]" + + "}"; + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverCert0File, serverKey0File) + .trustManager(caCertFile) + .clientAuth(ClientAuth.REQUIRE) + .build(); + initServerWithStaticAuthz(policy, serverCredentials); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientCert0File, clientKey0File) + .trustManager(caCertFile) + .build(); + getStub(channelCredentials).unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void staticAuthzAllowsRpcWithPrincipalsFieldOnTlsAuthenticatedConnectionTest() + throws Exception { + File caCertFile = TestUtils.loadCert(CA_PEM_FILE); + File serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + File serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + String policy = "{" + + " \"name\" : \"authz\" ," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_tls\"," + + " \"source\": {" + + " \"principals\": [\"\"]" + + " }" + + " }" + + " ]" + + "}"; + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverCert0File, serverKey0File) + .trustManager(caCertFile) + .clientAuth(ClientAuth.OPTIONAL) + .build(); + initServerWithStaticAuthz(policy, serverCredentials); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .trustManager(caCertFile) + .build(); + getStub(channelCredentials).unaryRpc(SimpleRequest.getDefaultInstance()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest req, StreamObserver respOb) { + respOb.onNext(SimpleResponse.getDefaultInstance()); + respOb.onCompleted(); + } + } +} diff --git a/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java b/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java new file mode 100644 index 00000000000..990228e2c96 --- /dev/null +++ b/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.authz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +import java.io.IOException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + +@RunWith(JUnit4.class) +public class AuthorizationServerInterceptorTest { + @Test + public void invalidPolicyFailsStaticAuthzInterceptorCreation() throws Exception { + String policy = "{ \"name\": \"abc\",, }"; + try { + AuthorizationServerInterceptor.create(policy); + fail("exception expected"); + } catch (IOException ioe) { + assertThat(ioe).hasMessageThat().isEqualTo( + "Use JsonReader.setLenient(true) to accept malformed JSON" + + " at line 1 column 18 path $.name"); + } catch (Exception e) { + throw new AssertionError("the test failed ", e); + } + } + + @Test + public void validPolicyCreatesStaticAuthzInterceptor() throws Exception { + String policy = "{" + + " \"name\" : \"authz\"," + + " \"deny_rules\": [" + + " {" + + " \"name\": \"deny_foo\"," + + " \"source\": {" + + " \"principals\": [" + + " \"spiffe://foo.com\"" + + " ]" + + " }" + + " }" + + " ]," + + " \"allow_rules\": [" + + " {" + + " \"name\": \"allow_all\"" + + " }" + + " ]" + + "}"; + assertNotNull(AuthorizationServerInterceptor.create(policy)); + } +} diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index c8a8669dd1f..90251e6692f 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -5,14 +5,12 @@ plugins { id "com.google.protobuf" id "me.champeau.jmh" + id "ru.vyarus.animalsniffer" } description = "grpc Benchmarks" -startScripts.enabled = false -run.enabled = false - -jmh { +tasks.named("jmh").configure { jvmArgs = ["-server", "-Xms2g", "-Xmx2g"] } @@ -29,19 +27,26 @@ dependencies { project(':grpc-testing'), project(path: ':grpc-xds', configuration: 'shadow'), libraries.hdrhistogram, - libraries.netty_tcnative, - libraries.netty_epoll, - libraries.math - compileOnly libraries.javax_annotation - alpnagent libraries.jetty_alpn_agent + libraries.netty.tcnative, + libraries.netty.tcnative.classes, + libraries.commons.math3 + implementation (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + compileOnly libraries.javax.annotation + alpnagent libraries.jetty.alpn.agent testImplementation libraries.junit, - libraries.mockito + libraries.mockito.core + + signature libraries.signature.java } import net.ltgt.gradle.errorprone.CheckSeverity -compileJava { +tasks.named("compileJava").configure { // The Control.Void protobuf clashes options.errorprone.check("JavaLangClash", CheckSeverity.OFF) } @@ -55,7 +60,15 @@ def vmArgs = [ "-XX:+PrintGCDetails" ] -task qps_client(type: CreateStartScripts) { +tasks.named("startScripts").configure { + enabled = false +} + +tasks.named("run").configure { + enabled = false +} + +def qps_client = tasks.register("qps_client", CreateStartScripts) { mainClass = "io.grpc.benchmarks.qps.AsyncClient" applicationName = "qps_client" defaultJvmOpts = vmArgs @@ -63,7 +76,7 @@ task qps_client(type: CreateStartScripts) { classpath = startScripts.classpath } -task openloop_client(type: CreateStartScripts) { +def openloop_client = tasks.register("openloop_client", CreateStartScripts) { mainClass = "io.grpc.benchmarks.qps.OpenLoopClient" applicationName = "openloop_client" defaultJvmOpts = vmArgs @@ -71,14 +84,14 @@ task openloop_client(type: CreateStartScripts) { classpath = startScripts.classpath } -task qps_server(type: CreateStartScripts) { +def qps_server = tasks.register("qps_server", CreateStartScripts) { mainClass = "io.grpc.benchmarks.qps.AsyncServer" applicationName = "qps_server" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) classpath = startScripts.classpath } -task benchmark_worker(type: CreateStartScripts) { +def benchmark_worker = tasks.register("benchmark_worker", CreateStartScripts) { mainClass = "io.grpc.benchmarks.driver.LoadWorker" applicationName = "benchmark_worker" defaultJvmOpts = vmArgs diff --git a/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java b/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java index 1af821be837..21f6195075c 100644 --- a/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java +++ b/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java @@ -87,23 +87,20 @@ public void setUp() throws Exception { ServerBuilder serverBuilder; ManagedChannelBuilder channelBuilder; switch (transport) { - case INPROCESS: - { + case INPROCESS: { String name = "bench" + Math.random(); serverBuilder = InProcessServerBuilder.forName(name); channelBuilder = InProcessChannelBuilder.forName(name); break; } - case NETTY: - { + case NETTY: { InetSocketAddress address = new InetSocketAddress("localhost", pickUnusedPort()); serverBuilder = NettyServerBuilder.forAddress(address, serverCreds); channelBuilder = NettyChannelBuilder.forAddress(address) .negotiationType(NegotiationType.PLAINTEXT); break; } - case NETTY_LOCAL: - { + case NETTY_LOCAL: { String name = "bench" + Math.random(); LocalAddress address = new LocalAddress(name); EventLoopGroup group = new DefaultEventLoopGroup(); @@ -118,8 +115,7 @@ public void setUp() throws Exception { groupToShutdown = group; break; } - case NETTY_EPOLL: - { + case NETTY_EPOLL: { InetSocketAddress address = new InetSocketAddress("localhost", pickUnusedPort()); // Reflection used since they are only available on linux. @@ -143,8 +139,7 @@ public void setUp() throws Exception { groupToShutdown = group; break; } - case OKHTTP: - { + case OKHTTP: { int port = pickUnusedPort(); InetSocketAddress address = new InetSocketAddress("localhost", port); serverBuilder = NettyServerBuilder.forAddress(address, serverCreds); diff --git a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadClientTest.java b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadClientTest.java index e8c6b929909..2ad8668b9fa 100644 --- a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadClientTest.java +++ b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadClientTest.java @@ -21,6 +21,7 @@ import io.grpc.benchmarks.proto.Control; import io.grpc.benchmarks.proto.Stats; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -30,6 +31,14 @@ */ @RunWith(JUnit4.class) public class LoadClientTest { + private LoadClient loadClient; + + @After + public void tearDown() { + if (loadClient != null) { + loadClient.shutdownNow(); + } + } @Test public void testHistogramToStatsConversion() throws Exception { @@ -48,7 +57,7 @@ public void testHistogramToStatsConversion() throws Exception { config.getLoadParamsBuilder().getClosedLoopBuilder(); config.addServerTargets("localhost:9999"); - LoadClient loadClient = new LoadClient(config.build()); + loadClient = new LoadClient(config.build()); loadClient.delay(1); loadClient.delay(10); loadClient.delay(10); diff --git a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java index 2f8b7a0ea27..64e2f2694ec 100644 --- a/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java +++ b/benchmarks/src/test/java/io/grpc/benchmarks/driver/LoadWorkerTest.java @@ -16,9 +16,10 @@ package io.grpc.benchmarks.driver; -import static org.junit.Assert.assertTrue; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.ManagedChannel; import io.grpc.benchmarks.Utils; import io.grpc.benchmarks.proto.Control; @@ -26,23 +27,22 @@ import io.grpc.benchmarks.proto.WorkerServiceGrpc; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.StreamObserver; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** - * Basic tests for {@link io.grpc.benchmarks.driver.LoadWorker} + * Basic tests for {@link io.grpc.benchmarks.driver.LoadWorker}. */ @RunWith(JUnit4.class) public class LoadWorkerTest { - private static final int TIMEOUT = 5; + private static final int TIMEOUT = 10; private static final Control.ClientArgs MARK = Control.ClientArgs.newBuilder() .setMark(Control.Mark.newBuilder().setReset(true).build()) .build(); @@ -51,6 +51,7 @@ public class LoadWorkerTest { private ManagedChannel channel; private WorkerServiceGrpc.WorkerServiceStub workerServiceStub; private LinkedBlockingQueue marksQueue; + private StreamObserver serverLifetime; @Before public void setup() throws Exception { @@ -62,6 +63,18 @@ public void setup() throws Exception { marksQueue = new LinkedBlockingQueue<>(); } + @After + public void tearDown() { + if (serverLifetime != null) { + serverLifetime.onCompleted(); + } + try { + WorkerServiceGrpc.newBlockingStub(channel).quitWorker(Control.Void.getDefaultInstance()); + } finally { + channel.shutdownNow(); + } + } + @Test public void runUnaryBlockingClosedLoop() throws Exception { Control.ServerArgs.Builder serverArgsBuilder = Control.ServerArgs.newBuilder(); @@ -181,7 +194,7 @@ private void assertWorkOccurred(StreamObserver clientObserve throws InterruptedException { Stats.ClientStats stat = null; - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 30; i++) { // Poll until we get some stats Thread.sleep(300); clientObserver.onNext(MARK); @@ -194,22 +207,22 @@ private void assertWorkOccurred(StreamObserver clientObserve } } clientObserver.onCompleted(); - assertTrue(stat.hasLatencies()); - assertTrue(stat.getLatencies().getCount() < stat.getLatencies().getSum()); + assertThat(stat.hasLatencies()).isTrue(); + assertThat(stat.getLatencies().getCount()).isLessThan(stat.getLatencies().getSum()); double mean = stat.getLatencies().getSum() / stat.getLatencies().getCount(); - System.out.println("Mean " + mean + " us"); - assertTrue(mean > stat.getLatencies().getMinSeen()); - assertTrue(mean < stat.getLatencies().getMaxSeen()); + System.out.println("Mean " + mean + " ns"); + assertThat(stat.getLatencies().getMinSeen()).isLessThan(mean); + assertThat(stat.getLatencies().getMaxSeen()).isGreaterThan(mean); } private StreamObserver startClient(Control.ClientArgs clientArgs) - throws InterruptedException { - final CountDownLatch clientReady = new CountDownLatch(1); + throws Exception { + final SettableFuture clientReady = SettableFuture.create(); StreamObserver clientObserver = workerServiceStub.runClient( new StreamObserver() { @Override public void onNext(Control.ClientStatus value) { - clientReady.countDown(); + clientReady.set(null); if (value.hasStats()) { marksQueue.add(value.getStats()); } @@ -217,45 +230,43 @@ public void onNext(Control.ClientStatus value) { @Override public void onError(Throwable t) { + clientReady.setException(t); } @Override public void onCompleted() { + clientReady.setException( + new RuntimeException("onCompleted() before receiving response")); } }); // Start the client clientObserver.onNext(clientArgs); - if (!clientReady.await(TIMEOUT, TimeUnit.SECONDS)) { - fail("Client failed to start"); - } + clientReady.get(TIMEOUT, TimeUnit.SECONDS); return clientObserver; } - private int startServer(Control.ServerArgs serverArgs) throws InterruptedException { - final AtomicInteger serverPort = new AtomicInteger(); - final CountDownLatch serverReady = new CountDownLatch(1); - StreamObserver serverObserver = + private int startServer(Control.ServerArgs serverArgs) throws Exception { + final SettableFuture port = SettableFuture.create(); + serverLifetime = workerServiceStub.runServer(new StreamObserver() { @Override public void onNext(Control.ServerStatus value) { - serverPort.set(value.getPort()); - serverReady.countDown(); + port.set(value.getPort()); } @Override public void onError(Throwable t) { + port.setException(t); } @Override public void onCompleted() { + port.setException(new RuntimeException("onCompleted() before receiving response")); } }); // trigger server startup - serverObserver.onNext(serverArgs); - if (!serverReady.await(TIMEOUT, TimeUnit.SECONDS)) { - fail("Server failed to start"); - } - return serverPort.get(); + serverLifetime.onNext(serverArgs); + return port.get(TIMEOUT, TimeUnit.SECONDS); } } diff --git a/binder/build.gradle b/binder/build.gradle index dc3da4c9a72..41c4bd4cc22 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -23,14 +23,14 @@ android { } } } - compileSdkVersion 29 + compileSdkVersion 30 compileOptions { sourceCompatibility 1.8 targetCompatibility 1.8 } defaultConfig { minSdkVersion 19 - targetSdkVersion 29 + targetSdkVersion 30 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" @@ -47,45 +47,45 @@ repositories { dependencies { api project(':grpc-core') - implementation libraries.androidx_annotation - implementation libraries.androidx_core - implementation libraries.androidx_lifecycle_common + implementation libraries.androidx.annotation + implementation libraries.androidx.core + implementation libraries.androidx.lifecycle.common implementation libraries.guava - testImplementation libraries.androidx_core - testImplementation libraries.androidx_test - testImplementation libraries.androidx_lifecycle_common - testImplementation libraries.androidx_lifecycle_service + testImplementation libraries.androidx.core + testImplementation libraries.androidx.test.core + testImplementation libraries.androidx.lifecycle.common + testImplementation libraries.androidx.lifecycle.service testImplementation libraries.junit - testImplementation libraries.mockito + testImplementation libraries.mockito.core testImplementation (libraries.robolectric) { // Unreleased change: https://github.com/robolectric/robolectric/pull/5432 exclude group: 'com.google.auto.service', module: 'auto-service' } - testImplementation (libraries.guava_testlib) { + testImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } testImplementation libraries.truth - androidTestAnnotationProcessor libraries.autovalue + androidTestAnnotationProcessor libraries.auto.value androidTestImplementation project(':grpc-testing') androidTestImplementation project(':grpc-protobuf-lite') - androidTestImplementation libraries.autovalue_annotation + androidTestImplementation libraries.auto.value.annotations androidTestImplementation libraries.junit - androidTestImplementation libraries.androidx_core - androidTestImplementation libraries.androidx_test - androidTestImplementation libraries.androidx_test_rules - androidTestImplementation libraries.androidx_test_ext_junit + androidTestImplementation libraries.androidx.core + androidTestImplementation libraries.androidx.test.core + androidTestImplementation libraries.androidx.test.rules + androidTestImplementation libraries.androidx.test.ext.junit androidTestImplementation libraries.truth - androidTestImplementation libraries.mockito_android - androidTestImplementation libraries.androidx_lifecycle_service - androidTestImplementation (libraries.guava_testlib) { + androidTestImplementation libraries.mockito.android + androidTestImplementation libraries.androidx.lifecycle.service + androidTestImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } } import net.ltgt.gradle.errorprone.CheckSeverity -tasks.withType(JavaCompile) { +tasks.withType(JavaCompile).configureEach { options.compilerArgs += [ "-Xlint:-cast" ] @@ -94,8 +94,10 @@ tasks.withType(JavaCompile) { options.errorprone.check("UnnecessaryAnonymousClass", CheckSeverity.OFF) } -task javadocs(type: Javadoc) { +tasks.register("javadocs", Javadoc) { source = android.sourceSets.main.java.srcDirs + exclude 'io/grpc/binder/internal/**' + exclude 'io/grpc/binder/Internal*' classpath += files(android.getBootClasspath()) classpath += files({ android.libraryVariants.collect { variant -> @@ -110,12 +112,13 @@ task javadocs(type: Javadoc) { } } -task javadocJar(type: Jar, dependsOn: javadocs) { +tasks.register("javadocJar", Jar) { + dependsOn javadocs archiveClassifier = 'javadoc' from javadocs.destinationDir } -task sourcesJar(type: Jar) { +tasks.register("sourcesJar", Jar) { archiveClassifier = 'sources' from android.sourceSets.main.java.srcDirs } diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java index 56e8f93c140..e3b9978fb36 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java @@ -19,26 +19,25 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; -import android.app.Service; import android.content.Context; -import android.content.Intent; -import android.os.IBinder; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.base.Function; import com.google.protobuf.Empty; import io.grpc.CallOptions; import io.grpc.ManagedChannel; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; +import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.protobuf.lite.ProtoLiteUtils; import io.grpc.stub.ClientCalls; import io.grpc.stub.ServerCalls; -import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -59,6 +58,7 @@ public final class BinderSecurityTest { @Nullable ManagedChannel channel; Map> methods = new HashMap<>(); List> calls = new ArrayList<>(); + CountingServerInterceptor countingServerInterceptor; @Before public void setupServiceDefinitionsAndMethods() { @@ -86,6 +86,7 @@ public void setupServiceDefinitionsAndMethods() { } serviceDefinitions.add(builder.build()); } + countingServerInterceptor = new CountingServerInterceptor(); } @After @@ -120,6 +121,7 @@ private Server buildServer( ServerSecurityPolicy serverPolicy) { BinderServerBuilder serverBuilder = BinderServerBuilder.forAddress(listenAddr, receiver); serverBuilder.securityPolicy(serverPolicy); + serverBuilder.intercept(countingServerInterceptor); for (ServerServiceDefinition serviceDefinition : serviceDefinitions) { serverBuilder.addService(serviceDefinition); @@ -195,6 +197,27 @@ public void testPerServicePolicy() throws Exception { } } + @Test + public void testSecurityInterceptorIsClosestToTransport() throws Exception { + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", policy((uid) -> true)) + .servicePolicy("bar", policy((uid) -> false)) + .servicePolicy("baz", policy((uid) -> false)) + .build(), + SecurityPolicies.internalOnly()); + assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(0); + for (MethodDescriptor method : methods.values()) { + try { + ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null); + } catch (StatusRuntimeException sre) { + // Ignore. + } + } + // Only the foo calls should have made it to the user interceptor. + assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(2); + } + private static SecurityPolicy policy(Function func) { return new SecurityPolicy() { @Override @@ -203,4 +226,17 @@ public Status checkAuthorization(int uid) { } }; } + + private final class CountingServerInterceptor implements ServerInterceptor { + int numInterceptedCalls; + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + numInterceptedCalls += 1; + return next.startCall(call, headers); + } + } } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index b99114bb501..dbacf351780 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -18,42 +18,39 @@ import static com.google.common.truth.Truth.assertThat; -import android.app.Service; import android.content.Context; -import android.content.Intent; -import android.os.IBinder; import android.os.Parcel; import androidx.core.content.ContextCompat; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; -import com.google.common.util.concurrent.testing.TestingExecutors; import com.google.protobuf.Empty; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.Server; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.HostServices; -import io.grpc.binder.IBinderReceiver; import io.grpc.binder.InboundParcelablePolicy; import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.SecurityPolicy; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; -import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; import io.grpc.internal.StreamListener; import io.grpc.protobuf.lite.ProtoLiteUtils; import io.grpc.stub.ServerCalls; -import java.io.IOException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; import org.junit.After; @@ -94,7 +91,7 @@ public final class BinderClientTransportTest { BinderTransport.BinderClientTransport transport; private final ObjectPool executorServicePool = - new FixedObjectPool<>(TestingExecutors.sameThreadScheduledExecutor()); + new FixedObjectPool<>(Executors.newScheduledThreadPool(1)); private final TestTransportListener transportListener = new TestTransportListener(); private final TestStreamListener streamListener = new TestStreamListener(); @@ -127,39 +124,50 @@ public void setUp() throws Exception { .build(); serverAddress = HostServices.allocateService(appContext); - HostServices.configureService(serverAddress, + HostServices.configureService( + serverAddress, HostServices.serviceParamsBuilder() - .setServerFactory((service, receiver) -> - BinderServerBuilder.forAddress(serverAddress, receiver) - .addService(serviceDef) - .build()) - .build()); - - transport = - new BinderTransport.BinderClientTransport( - appContext, - serverAddress, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - executorServicePool, - SecurityPolicies.internalOnly(), - InboundParcelablePolicy.DEFAULT, - Attributes.EMPTY); - - Runnable r = transport.start(transportListener); - r.run(); - transportListener.awaitReady(); + .setServerFactory( + (service, receiver) -> + BinderServerBuilder.forAddress(serverAddress, receiver) + .addService(serviceDef) + .build()) + .build()); + } + + private class BinderClientTransportBuilder { + private SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); + + public BinderClientTransportBuilder setSecurityPolicy(SecurityPolicy securityPolicy) { + this.securityPolicy = securityPolicy; + return this; + } + + public BinderTransport.BinderClientTransport build() { + return new BinderTransport.BinderClientTransport( + appContext, + serverAddress, + BindServiceFlags.DEFAULTS, + ContextCompat.getMainExecutor(appContext), + executorServicePool, + executorServicePool, + securityPolicy, + InboundParcelablePolicy.DEFAULT, + Attributes.EMPTY); + } } @After public void tearDown() throws Exception { transport.shutdownNow(Status.OK); HostServices.awaitServiceShutdown(); + executorServicePool.getObject().shutdownNow(); } @Test public void testShutdownBeforeStreamStart_b153326034() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream( methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); @@ -170,6 +178,8 @@ public void testShutdownBeforeStreamStart_b153326034() throws Exception { @Test public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -188,6 +198,8 @@ public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception @Test public void testTransactionForDiscardedCall_b155244043() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -206,6 +218,8 @@ public void testTransactionForDiscardedCall_b155244043() throws Exception { @Test public void testBadTransactionStreamThroughput_b163053382() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -225,6 +239,8 @@ public void testBadTransactionStreamThroughput_b163053382() throws Exception { @Test public void testMessageProducerClosedAfterStream_b169313545() { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -243,6 +259,22 @@ public void testMessageProducerClosedAfterStream_b169313545() { streamListener.drainMessages(); } + @Test + public void testNewStreamBeforeTransportReadyFails() throws InterruptedException { + // Use a special SecurityPolicy that lets us act before the transport is setup/ready. + BlockingSecurityPolicy bsp = new BlockingSecurityPolicy(); + transport = new BinderClientTransportBuilder().setSecurityPolicy(bsp).build(); + transport.start(transportListener).run(); + ClientStream stream = + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); + assertThat(streamListener.awaitClose().getCode()).isEqualTo(Code.INTERNAL); + + // Unblock the SETUP_TRANSPORT handshake and make sure it becomes ready in the usual way. + bsp.provideNextCheckAuthorizationResult(Status.OK); + transportListener.awaitReady(); + } + private synchronized void awaitServerCallsCompleted(int calls) { while (serverCallsCompleted < calls) { try { @@ -253,6 +285,12 @@ private synchronized void awaitServerCallsCompleted(int calls) { } } + private static void startAndAwaitReady( + BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) { + transport.start(transportListener).run(); + transportListener.awaitReady(); + } + private static final class TestTransportListener implements ManagedClientTransport.Listener { public boolean ready; public boolean inUse; @@ -313,6 +351,17 @@ public synchronized void awaitMessages() { } } + public synchronized Status awaitClose() { + while (closedStatus == null) { + try { + wait(100); + } catch (InterruptedException inte) { + throw new AssertionError("Interrupted waiting for close"); + } + } + return closedStatus; + } + public int drainMessages() { int n = 0; while (messageProducer.next() != null) { @@ -336,4 +385,24 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { this.closedStatus = status; } } + + /** + * A SecurityPolicy that blocks the transport authorization check until a test sets the outcome. + */ + static class BlockingSecurityPolicy extends SecurityPolicy { + private final BlockingQueue results = new LinkedBlockingQueue<>(); + + public void provideNextCheckAuthorizationResult(Status status) { + results.add(status); + } + + @Override + public Status checkAuthorization(int uid) { + try { + return results.take(); + } catch (InterruptedException e) { + return Status.fromThrowable(e); + } + } + } } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java index 24af04d4d61..5a1b302f768 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java @@ -115,6 +115,11 @@ public void socketStats() throws Exception {} @Override public void flowControlPushBack() throws Exception {} + @Test + @Ignore("Not yet implemented. See https://github.com/grpc/grpc-java/issues/8931") + @Override + public void serverNotListening() throws Exception {} + @Test @Ignore("This test isn't appropriate for BinderTransport.") @Override diff --git a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java index 4809a2db43f..90197ee8382 100644 --- a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java +++ b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java @@ -22,7 +22,6 @@ import android.content.ComponentName; import android.content.Context; import android.content.Intent; -import io.grpc.ExperimentalApi; import java.net.SocketAddress; /** @@ -41,8 +40,7 @@ * fields, namely, an action of {@link ApiConstants#ACTION_BIND}, an empty category set and null * type and data URI. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") -public class AndroidComponentAddress extends SocketAddress { // NOTE: Only temporarily non-final. +public final class AndroidComponentAddress extends SocketAddress { private static final long serialVersionUID = 0L; private final Intent bindIntent; // An "explicit" Intent. In other words, getComponent() != null. @@ -103,6 +101,11 @@ public static AndroidComponentAddress forComponent(ComponentName component) { new Intent(ApiConstants.ACTION_BIND).setComponent(component)); } + /** + * Returns the Authority which is the package name of the target app. + * + *

See {@link android.content.ComponentName}. + */ public String getAuthority() { return getComponent().getPackageName(); } @@ -121,14 +124,29 @@ public Intent asBindIntent() { /** * Returns this address as an "android-app://" uri. + * + *

See {@link Intent#URI_ANDROID_APP_SCHEME} for details. */ public String asAndroidAppUri() { - return bindIntent.toUri(URI_ANDROID_APP_SCHEME); + Intent intentForUri = bindIntent; + if (intentForUri.getPackage() == null) { + // URI_ANDROID_APP_SCHEME requires an "explicit package name" which isn't set by any of our + // factory methods. Oddly, our explicit ComponentName is not enough. + intentForUri = intentForUri.cloneFilter().setPackage(getComponent().getPackageName()); + } + return intentForUri.toUri(URI_ANDROID_APP_SCHEME); } @Override public int hashCode() { - return bindIntent.filterHashCode(); + Intent intentForHashCode = bindIntent; + // Clear a (usually redundant) package filter to work around an Android >= 31 bug where certain + // Intents compare filterEquals() but have different filterHashCode() values. It's always safe + // to include fewer fields in the hashCode() computation. + if (intentForHashCode.getPackage() != null) { + intentForHashCode = intentForHashCode.cloneFilter().setPackage(null); + } + return intentForHashCode.filterHashCode(); } @Override diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java index 91e4e8f1c76..214eb6dc4c5 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -17,16 +17,13 @@ package io.grpc.binder; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; -import android.app.Application; -import android.content.ComponentName; import android.content.Context; import androidx.core.content.ContextCompat; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannel; @@ -124,6 +121,7 @@ public static BinderChannelBuilder forTarget(String target) { private SecurityPolicy securityPolicy; private InboundParcelablePolicy inboundParcelablePolicy; private BindServiceFlags bindServiceFlags; + private boolean strictLifecycleManagement; private BinderChannelBuilder( @Nullable AndroidComponentAddress directAddress, @@ -164,6 +162,7 @@ public ClientTransportFactory buildClientTransportFactory() { new BinderChannelTransportFactoryBuilder(), null); } + idleTimeout(60, TimeUnit.SECONDS); } @Override @@ -224,6 +223,25 @@ public BinderChannelBuilder inboundParcelablePolicy( return this; } + /** + * Disables the channel idle timeout and prevents it from being enabled. This + * allows a centralized application method to configure the channel builder + * and return it, without worrying about another part of the application + * accidentally enabling the idle timeout. + */ + public BinderChannelBuilder strictLifecycleManagement() { + strictLifecycleManagement = true; + super.idleTimeout(1000, TimeUnit.DAYS); // >30 days disables timeouts entirely. + return this; + } + + @Override + public BinderChannelBuilder idleTimeout(long value, TimeUnit unit) { + checkState(!strictLifecycleManagement, "Idle timeouts are not supported when strict lifecycle management is enabled"); + super.idleTimeout(value, unit); + return this; + } + /** Creates new binder transports. */ private static final class TransportFactory implements ClientTransportFactory { private final Context sourceContext; diff --git a/binder/src/main/java/io/grpc/binder/BinderInternal.java b/binder/src/main/java/io/grpc/binder/BinderInternal.java new file mode 100644 index 00000000000..34f7793714f --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/BinderInternal.java @@ -0,0 +1,34 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.os.IBinder; +import io.grpc.Internal; + +/** + * Helper class to expose IBinderReceiver methods for legacy internal builders. + */ +@Internal +public class BinderInternal { + + /** + * Sets the receiver's {@link IBinder} using {@link IBinderReceiver#set(IBinder)}. + */ + public static void setIBinder(IBinderReceiver receiver, IBinder binder) { + receiver.set(binder); + } +} diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java index 383bd3f8e49..eaa94bffc45 100644 --- a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -21,34 +21,24 @@ import android.app.Service; import android.os.IBinder; -import com.google.common.base.Supplier; import com.google.errorprone.annotations.DoNotCall; -import io.grpc.CompressorRegistry; -import io.grpc.DecompressorRegistry; import io.grpc.ExperimentalApi; import io.grpc.Server; import io.grpc.ServerBuilder; -import io.grpc.ServerStreamTracer; import io.grpc.binder.internal.BinderServer; import io.grpc.binder.internal.BinderTransportSecurity; import io.grpc.ForwardingServerBuilder; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.InternalServer; import io.grpc.internal.ServerImplBuilder; -import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import java.io.File; -import java.io.IOException; -import java.util.List; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; /** * Builder for a server that services requests from an Android Service. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class BinderServerBuilder extends ForwardingServerBuilder { @@ -81,6 +71,7 @@ public static BinderServerBuilder forPort(int port) { SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); private ServerSecurityPolicy securityPolicy; private InboundParcelablePolicy inboundParcelablePolicy; + private boolean isBuilt; private BinderServerBuilder( AndroidComponentAddress listenAddress, @@ -95,20 +86,13 @@ private BinderServerBuilder( streamTracerFactories, securityPolicy, inboundParcelablePolicy); - binderReceiver.set(server.getHostBinder()); + BinderInternal.setIBinder(binderReceiver, server.getHostBinder()); return server; }); - // Disable compression by default, since there's little benefit when all communication is - // on-device, and it means sending supported-encoding headers with every call. - decompressorRegistry(DecompressorRegistry.emptyInstance()); - compressorRegistry(CompressorRegistry.newEmptyInstance()); - // Disable stats and tracing by default. serverImplBuilder.setStatsEnabled(false); serverImplBuilder.setTracingEnabled(false); - - BinderTransportSecurity.installAuthInterceptor(this); } @Override @@ -117,12 +101,14 @@ protected ServerBuilder delegate() { } /** Enable stats collection using census. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public BinderServerBuilder enableStats() { serverImplBuilder.setStatsEnabled(true); return this; } /** Enable tracing using census. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public BinderServerBuilder enableTracing() { serverImplBuilder.setTracingEnabled(true); return this; @@ -157,12 +143,16 @@ public BinderServerBuilder securityPolicy(ServerSecurityPolicy securityPolicy) { } /** Sets the policy for inbound parcelable objects. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public BinderServerBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); return this; } + /** + * Always fails. TLS is not supported in BinderServer. + */ @Override public BinderServerBuilder useTransportSecurity(File certChain, File privateKey) { throw new UnsupportedOperationException("TLS not supported in BinderServer"); @@ -177,6 +167,11 @@ public BinderServerBuilder useTransportSecurity(File certChain, File privateKey) */ @Override // For javadoc refinement only. public Server build() { + // Since we install a final interceptor here, we need to ensure we're only built once. + checkState(!isBuilt, "BinderServerBuilder can only be used to build one server instance."); + isBuilt = true; + // We install the security interceptor last, so it's closest to the transport. + BinderTransportSecurity.installAuthInterceptor(this); return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/IBinderReceiver.java b/binder/src/main/java/io/grpc/binder/IBinderReceiver.java index bd8e1f50af9..adf4a0d3d8e 100644 --- a/binder/src/main/java/io/grpc/binder/IBinderReceiver.java +++ b/binder/src/main/java/io/grpc/binder/IBinderReceiver.java @@ -17,24 +17,22 @@ package io.grpc.binder; import android.os.IBinder; -import io.grpc.ExperimentalApi; import javax.annotation.Nullable; /** A container for at most one instance of {@link IBinder}, useful as an "out parameter". */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class IBinderReceiver { - @Nullable private IBinder value; + @Nullable private volatile IBinder value; /** Constructs a new, initially empty, container. */ public IBinderReceiver() {} /** Returns the contents of this container or null if it is empty. */ @Nullable - public synchronized IBinder get() { + public IBinder get() { return value; } - public synchronized void set(IBinder value) { + protected void set(IBinder value) { this.value = value; } } diff --git a/binder/src/main/java/io/grpc/binder/ParcelableUtils.java b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java index 164de7de8b8..969344ea68d 100644 --- a/binder/src/main/java/io/grpc/binder/ParcelableUtils.java +++ b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java @@ -17,7 +17,6 @@ package io.grpc.binder; import android.os.Parcelable; -import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.binder.internal.MetadataHelper; @@ -26,7 +25,6 @@ * *

This class models the same pattern as the {@code ProtoLiteUtils} class. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class ParcelableUtils { private ParcelableUtils() {} diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java index dcf36be00ca..1a31ef823d3 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -17,34 +17,47 @@ package io.grpc.binder; import android.annotation.SuppressLint; +import android.app.admin.DevicePolicyManager; +import android.content.Context; import android.content.pm.PackageInfo; import android.content.pm.PackageManager; import android.content.pm.PackageManager.NameNotFoundException; import android.content.pm.Signature; import android.os.Build; +import android.os.Build.VERSION; import android.os.Process; import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.hash.Hashing; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Iterator; import java.util.List; -import javax.annotation.CheckReturnValue; /** Static factory methods for creating standard security policies. */ @CheckReturnValue -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class SecurityPolicies { private static final int MY_UID = Process.myUid(); + private static final int SHA_256_BYTES_LENGTH = 32; private SecurityPolicies() {} + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static ServerSecurityPolicy serverInternalOnly() { return new ServerSecurityPolicy(); } + /** + * Creates a default {@link SecurityPolicy} that allows access only to callers with the same UID + * as the current process. + */ public static SecurityPolicy internalOnly() { return new SecurityPolicy() { @Override @@ -57,6 +70,7 @@ public Status checkAuthorization(int uid) { }; } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static SecurityPolicy permissionDenied(String description) { Status denied = Status.PERMISSION_DENIED.withDescription(description); return new SecurityPolicy() { @@ -75,12 +89,29 @@ public Status checkAuthorization(int uid) { * @param requiredSignature the allowed signature of the allowed package. * @throws NullPointerException if any of the inputs are {@code null}. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static SecurityPolicy hasSignature( PackageManager packageManager, String packageName, Signature requiredSignature) { return oneOfSignatures( packageManager, packageName, ImmutableList.of(requiredSignature)); } + /** + * Creates {@link SecurityPolicy} which checks if the SHA-256 hash of the package signature + * matches {@code requiredSignatureSha256Hash}. + * + * @param packageName the package name of the allowed package. + * @param requiredSignatureSha256Hash the SHA-256 digest of the signature of the allowed package. + * @throws NullPointerException if any of the inputs are {@code null}. + * @throws IllegalArgumentException if {@code requiredSignatureSha256Hash} is not of length 32. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") + public static SecurityPolicy hasSignatureSha256Hash( + PackageManager packageManager, String packageName, byte[] requiredSignatureSha256Hash) { + return oneOfSignatureSha256Hash( + packageManager, packageName, ImmutableList.of(requiredSignatureSha256Hash)); + } + /** * Creates a {@link SecurityPolicy} which checks if the package signature * matches any of {@code requiredSignatures}. @@ -90,6 +121,7 @@ public static SecurityPolicy hasSignature( * @throws NullPointerException if any of the inputs are {@code null}. * @throws IllegalArgumentException if {@code requiredSignatures} is empty. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static SecurityPolicy oneOfSignatures( PackageManager packageManager, String packageName, @@ -114,6 +146,88 @@ public Status checkAuthorization(int uid) { }; } + /** + * Creates {@link SecurityPolicy} which checks if the SHA-256 hash of the package signature + * matches any of {@code requiredSignatureSha256Hashes}. + * + * @param packageName the package name of the allowed package. + * @param requiredSignatureSha256Hashes the SHA-256 digests of the signatures of the allowed + * package. + * @throws NullPointerException if any of the inputs are {@code null}. + * @throws IllegalArgumentException if {@code requiredSignatureSha256Hashes} is empty, or if any + * of the {@code requiredSignatureSha256Hashes} are not of length 32. + */ + public static SecurityPolicy oneOfSignatureSha256Hash( + PackageManager packageManager, + String packageName, + List requiredSignatureSha256Hashes) { + Preconditions.checkNotNull(packageManager); + Preconditions.checkNotNull(packageName); + Preconditions.checkNotNull(requiredSignatureSha256Hashes); + Preconditions.checkArgument(!requiredSignatureSha256Hashes.isEmpty()); + + ImmutableList.Builder immutableListBuilder = ImmutableList.builder(); + for (byte[] requiredSignatureSha256Hash : requiredSignatureSha256Hashes) { + Preconditions.checkNotNull(requiredSignatureSha256Hash); + Preconditions.checkArgument(requiredSignatureSha256Hash.length == SHA_256_BYTES_LENGTH); + immutableListBuilder.add( + Arrays.copyOf(requiredSignatureSha256Hash, requiredSignatureSha256Hash.length)); + } + ImmutableList requiredSignaturesHashesImmutable = immutableListBuilder.build(); + + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return checkUidSha256Signature( + packageManager, uid, packageName, requiredSignaturesHashesImmutable); + } + }; + } + + /** + * Creates {@link SecurityPolicy} which checks if the app is a device owner app. See + * {@link DevicePolicyManager}. + */ + public static SecurityPolicy isDeviceOwner(Context applicationContext) { + DevicePolicyManager devicePolicyManager = + (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); + return anyPackageWithUidSatisfies( + applicationContext, + pkg -> VERSION.SDK_INT >= 18 && devicePolicyManager.isDeviceOwnerApp(pkg), + "Rejected by device owner policy. No packages found for UID.", + "Rejected by device owner policy"); + } + + /** + * Creates {@link SecurityPolicy} which checks if the app is a profile owner app. See + * {@link DevicePolicyManager}. + */ + public static SecurityPolicy isProfileOwner(Context applicationContext) { + DevicePolicyManager devicePolicyManager = + (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); + return anyPackageWithUidSatisfies( + applicationContext, + pkg -> VERSION.SDK_INT >= 21 && devicePolicyManager.isProfileOwnerApp(pkg), + "Rejected by profile owner policy. No packages found for UID.", + "Rejected by profile owner policy"); + } + + /** + * Creates {@link SecurityPolicy} which checks if the app is a profile owner app on an + * organization-owned device. See {@link DevicePolicyManager}. + */ + public static SecurityPolicy isProfileOwnerOnOrganizationOwnedDevice(Context applicationContext) { + DevicePolicyManager devicePolicyManager = + (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); + return anyPackageWithUidSatisfies( + applicationContext, + pkg -> VERSION.SDK_INT >= 30 + && devicePolicyManager.isProfileOwnerApp(pkg) + && devicePolicyManager.isOrganizationOwnedDeviceWithManagedProfile(), + "Rejected by profile owner on organization-owned device policy. No packages found for UID.", + "Rejected by profile owner on organization-owned device policy"); + } + private static Status checkUidSignature( PackageManager packageManager, int uid, @@ -130,7 +244,7 @@ private static Status checkUidSignature( continue; } packageNameMatched = true; - if (checkPackageSignature(packageManager, pkg, requiredSignatures)) { + if (checkPackageSignature(packageManager, pkg, requiredSignatures::contains)) { return Status.OK; } } @@ -139,19 +253,50 @@ private static Status checkUidSignature( + packageNameMatched); } + private static Status checkUidSha256Signature( + PackageManager packageManager, + int uid, + String packageName, + ImmutableList requiredSignatureSha256Hashes) { + String[] packages = packageManager.getPackagesForUid(uid); + if (packages == null) { + return Status.UNAUTHENTICATED.withDescription( + "Rejected by (SHA-256 hash signature check) security policy"); + } + boolean packageNameMatched = false; + for (String pkg : packages) { + if (!packageName.equals(pkg)) { + continue; + } + packageNameMatched = true; + if (checkPackageSignature( + packageManager, + pkg, + (signature) -> + checkSignatureSha256HashesMatch(signature, requiredSignatureSha256Hashes))) { + return Status.OK; + } + } + return Status.PERMISSION_DENIED.withDescription( + "Rejected by (SHA-256 hash signature check) security policy. Package name matched: " + + packageNameMatched); + } + /** * Checks if the signature of {@code packageName} matches one of the given signatures. * * @param packageName the package to be checked - * @param requiredSignatures list of signatures. - * @return {@code true} if {@code packageName} has a matching signature. + * @param signatureCheckFunction {@link Predicate} that takes a signature and verifies if it + * satisfies any signature constraints + * return {@code true} if {@code packageName} has a signature that satisfies {@code + * signatureCheckFunction}. */ @SuppressWarnings("deprecation") // For PackageInfo.signatures @SuppressLint("PackageManagerGetSignatures") // We only allow 1 signature. private static boolean checkPackageSignature( PackageManager packageManager, String packageName, - ImmutableList requiredSignatures) { + Predicate signatureCheckFunction) { PackageInfo packageInfo; try { if (Build.VERSION.SDK_INT >= 28) { @@ -166,7 +311,7 @@ private static boolean checkPackageSignature( : packageInfo.signingInfo.getSigningCertificateHistory(); for (Signature signature : signatures) { - if (requiredSignatures.contains(signature)) { + if (signatureCheckFunction.apply(signature)) { return true; } } @@ -178,7 +323,7 @@ private static boolean checkPackageSignature( return false; } - if (requiredSignatures.contains(packageInfo.signatures[0])) { + if (signatureCheckFunction.apply(packageInfo.signatures[0])) { return true; } } @@ -187,4 +332,175 @@ private static boolean checkPackageSignature( } return false; } + + /** + * Creates a {@link SecurityPolicy} that allows access if and only if *all* of the specified + * {@code securityPolicies} allow access. + * + * @param securityPolicies the security policies that all must allow access. + * @throws NullPointerException if any of the inputs are {@code null}. + * @throws IllegalArgumentException if {@code securityPolicies} is empty. + */ + public static SecurityPolicy allOf(SecurityPolicy... securityPolicies) { + Preconditions.checkNotNull(securityPolicies, "securityPolicies"); + Preconditions.checkArgument(securityPolicies.length > 0, "securityPolicies must not be empty"); + + return allOfSecurityPolicy(securityPolicies); + } + + private static SecurityPolicy allOfSecurityPolicy(SecurityPolicy... securityPolicies) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + for (SecurityPolicy policy : securityPolicies) { + Status checkAuth = policy.checkAuthorization(uid); + if (!checkAuth.isOk()) { + return checkAuth; + } + } + + return Status.OK; + } + }; + } + + /** + * Creates a {@link SecurityPolicy} that allows access if *any* of the specified {@code + * securityPolicies} allow access. + * + *

Policies will be checked in the order that they are passed. If a policy allows access, + * subsequent policies will not be checked. + * + *

If all policies deny access, the {@link io.grpc.Status} returned by {@code + * checkAuthorization} will included the concatenated descriptions of the failed policies and + * attach any additional causes as suppressed throwables. The status code will be that of the + * first failed policy. + * + * @param securityPolicies the security policies that will be checked. + * @throws NullPointerException if any of the inputs are {@code null}. + * @throws IllegalArgumentException if {@code securityPolicies} is empty. + */ + public static SecurityPolicy anyOf(SecurityPolicy... securityPolicies) { + Preconditions.checkNotNull(securityPolicies, "securityPolicies"); + Preconditions.checkArgument(securityPolicies.length > 0, "securityPolicies must not be empty"); + + return anyOfSecurityPolicy(securityPolicies); + } + + private static SecurityPolicy anyOfSecurityPolicy(SecurityPolicy... securityPolicies) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + List failed = new ArrayList<>(); + for (SecurityPolicy policy : securityPolicies) { + Status checkAuth = policy.checkAuthorization(uid); + if (checkAuth.isOk()) { + return checkAuth; + } + failed.add(checkAuth); + } + + Iterator iter = failed.iterator(); + Status toReturn = iter.next(); + while (iter.hasNext()) { + Status append = iter.next(); + toReturn = toReturn.augmentDescription(append.getDescription()); + if (append.getCause() != null) { + if (toReturn.getCause() != null) { + toReturn.getCause().addSuppressed(append.getCause()); + } else { + toReturn = toReturn.withCause(append.getCause()); + } + } + } + return toReturn; + } + }; + } + + /** + * Creates a {@link SecurityPolicy} which checks if the caller has all of the given permissions + * from {@code permissions}. + * + * @param permissions all permissions that the calling package needs to have + * @throws NullPointerException if any of the inputs are {@code null} + * @throws IllegalArgumentException if {@code permissions} is empty + */ + public static SecurityPolicy hasPermissions( + PackageManager packageManager, ImmutableSet permissions) { + Preconditions.checkNotNull(packageManager, "packageManager"); + Preconditions.checkNotNull(permissions, "permissions"); + Preconditions.checkArgument(!permissions.isEmpty(), "permissions"); + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return checkPermissions(uid, packageManager, permissions); + } + }; + } + + private static Status checkPermissions( + int uid, PackageManager packageManager, ImmutableSet permissions) { + String[] packages = packageManager.getPackagesForUid(uid); + if (packages == null || packages.length == 0) { + return Status.UNAUTHENTICATED.withDescription( + "Rejected by permission check security policy. No packages found for uid"); + } + for (String pkg : packages) { + for (String permission : permissions) { + if (packageManager.checkPermission(permission, pkg) != PackageManager.PERMISSION_GRANTED) { + return Status.PERMISSION_DENIED.withDescription( + "Rejected by permission check security policy. " + + pkg + + " does not have permission " + + permission); + } + } + } + + return Status.OK; + } + + private static SecurityPolicy anyPackageWithUidSatisfies( + Context applicationContext, + Predicate condition, + String errorMessageForNoPackages, + String errorMessageForDenied) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + String[] packages = applicationContext.getPackageManager().getPackagesForUid(uid); + if (packages == null || packages.length == 0) { + return Status.UNAUTHENTICATED.withDescription(errorMessageForNoPackages); + } + + for (String pkg : packages) { + if (condition.apply(pkg)) { + return Status.OK; + } + } + return Status.PERMISSION_DENIED.withDescription(errorMessageForDenied); + } + }; + } + + /** + * Checks if the SHA-256 hash of the {@code signature} matches one of the {@code + * expectedSignatureSha256Hashes}. + */ + private static boolean checkSignatureSha256HashesMatch( + Signature signature, List expectedSignatureSha256Hashes) { + byte[] signatureHash = getSha256Hash(signature); + for (byte[] hash : expectedSignatureSha256Hashes) { + if (Arrays.equals(hash, signatureHash)) { + return true; + } + } + return false; + } + + /** Returns SHA-256 hash of the provided signature. */ + private static byte[] getSha256Hash(Signature signature) { + return Hashing.sha256().hashBytes(signature.toByteArray()).asBytes(); + } } diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java index d13f3a863fd..6b0fb40310a 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java @@ -16,7 +16,6 @@ package io.grpc.binder; -import io.grpc.ExperimentalApi; import io.grpc.Status; import javax.annotation.CheckReturnValue; @@ -37,7 +36,6 @@ * re-installation of the applications involved. */ @CheckReturnValue -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public abstract class SecurityPolicy { protected SecurityPolicy() {} diff --git a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java index 46a124e1f47..d91a487a57c 100644 --- a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java @@ -17,7 +17,6 @@ package io.grpc.binder; import com.google.common.collect.ImmutableMap; -import io.grpc.ExperimentalApi; import io.grpc.Status; import java.util.HashMap; import java.util.Map; @@ -28,7 +27,6 @@ * * Contains a default policy, and optional policies for each server. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class ServerSecurityPolicy { private final SecurityPolicy defaultPolicy; diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 219651a8b69..bdcd53a9ea6 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -29,6 +29,7 @@ import android.os.RemoteException; import android.os.TransactionTooLargeException; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; @@ -58,7 +59,6 @@ import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; -import io.grpc.internal.TimeProvider; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashSet; @@ -109,17 +109,18 @@ public abstract class BinderTransport * active transport. */ @Internal - public static final Attributes.Key REMOTE_UID = Attributes.Key.create("remote-uid"); + public static final Attributes.Key REMOTE_UID = + Attributes.Key.create("internal:remote-uid"); /** The authority of the server. */ @Internal public static final Attributes.Key SERVER_AUTHORITY = - Attributes.Key.create("server-authority"); + Attributes.Key.create("internal:server-authority"); /** A transport attribute to hold the {@link InboundParcelablePolicy}. */ @Internal public static final Attributes.Key INBOUND_PARCELABLE_POLICY = - Attributes.Key.create("inbound-parcelable-policy"); + Attributes.Key.create("internal:inbound-parcelable-policy"); /** * Version code for this wire format. @@ -201,7 +202,7 @@ protected enum TransportState { @Nullable protected Status shutdownStatus; - @Nullable private IBinder outgoingBinder; + @Nullable private OneWayBinderProxy outgoingBinder; private final FlowController flowController; @@ -278,10 +279,10 @@ final void setState(TransportState newState) { } @GuardedBy("this") - protected boolean setOutgoingBinder(IBinder binder) { + protected boolean setOutgoingBinder(OneWayBinderProxy binder) { this.outgoingBinder = binder; try { - binder.linkToDeath(this, 0); + binder.getDelegate().linkToDeath(this, 0); return true; } catch (RemoteException re) { return false; @@ -326,19 +327,13 @@ final void sendSetupTransaction() { } @GuardedBy("this") - final void sendSetupTransaction(IBinder iBinder) { - Parcel parcel = Parcel.obtain(); - try { - parcel.writeInt(WIRE_FORMAT_VERSION); - parcel.writeStrongBinder(incomingBinder); - if (!iBinder.transact(SETUP_TRANSPORT, parcel, null, IBinder.FLAG_ONEWAY)) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed sending SETUP_TRANSPORT transaction"), true); - } + final void sendSetupTransaction(OneWayBinderProxy iBinder) { + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(WIRE_FORMAT_VERSION); + parcel.get().writeStrongBinder(incomingBinder); + iBinder.transact(SETUP_TRANSPORT, parcel); } catch (RemoteException re) { shutdownInternal(statusFromRemoteException(re), true); - } finally { - parcel.recycle(); } } @@ -346,19 +341,16 @@ final void sendSetupTransaction(IBinder iBinder) { private final void sendShutdownTransaction() { if (outgoingBinder != null) { try { - outgoingBinder.unlinkToDeath(this, 0); + outgoingBinder.getDelegate().unlinkToDeath(this, 0); } catch (NoSuchElementException e) { // Ignore. } - Parcel parcel = Parcel.obtain(); - try { + try (ParcelHolder parcel = ParcelHolder.obtain()) { // Send empty flags to avoid a memory leak linked to empty parcels (b/207778694). - parcel.writeInt(0); - outgoingBinder.transact(SHUTDOWN_TRANSPORT, parcel, null, IBinder.FLAG_ONEWAY); + parcel.get().writeInt(0); + outgoingBinder.transact(SHUTDOWN_TRANSPORT, parcel); } catch (RemoteException re) { // Ignore. - } finally { - parcel.recycle(); } } } @@ -369,14 +361,11 @@ protected synchronized void sendPing(int id) throws StatusException { } else if (outgoingBinder == null) { throw Status.FAILED_PRECONDITION.withDescription("Transport not ready.").asException(); } else { - Parcel parcel = Parcel.obtain(); - try { - parcel.writeInt(id); - outgoingBinder.transact(PING, parcel, null, IBinder.FLAG_ONEWAY); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(id); + outgoingBinder.transact(PING, parcel); } catch (RemoteException re) { throw statusFromRemoteException(re).asException(); - } finally { - parcel.recycle(); } } } @@ -401,12 +390,10 @@ final void unregisterCall(int callId) { } } - final void sendTransaction(int callId, Parcel parcel) throws StatusException { - int dataSize = parcel.dataSize(); + final void sendTransaction(int callId, ParcelHolder parcel) throws StatusException { + int dataSize = parcel.get().dataSize(); try { - if (!outgoingBinder.transact(callId, parcel, null, IBinder.FLAG_ONEWAY)) { - throw Status.UNAVAILABLE.withDescription("Failed sending transaction").asException(); - } + outgoingBinder.transact(callId, parcel); } catch (RemoteException re) { throw statusFromRemoteException(re).asException(); } @@ -416,16 +403,13 @@ final void sendTransaction(int callId, Parcel parcel) throws StatusException { } final void sendOutOfBandClose(int callId, Status status) { - Parcel parcel = Parcel.obtain(); - try { - parcel.writeInt(0); // Placeholder for flags. Will be filled in below. - int flags = TransactionUtils.writeStatus(parcel, status); - TransactionUtils.fillInFlags(parcel, flags | TransactionUtils.FLAG_OUT_OF_BAND_CLOSE); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(0); // Placeholder for flags. Will be filled in below. + int flags = TransactionUtils.writeStatus(parcel.get(), status); + TransactionUtils.fillInFlags(parcel.get(), flags | TransactionUtils.FLAG_OUT_OF_BAND_CLOSE); sendTransaction(callId, parcel); } catch (StatusException e) { logger.log(Level.WARNING, "Failed sending oob close transaction", e); - } finally { - parcel.recycle(); } } @@ -496,10 +480,12 @@ protected Inbound createInbound(int callId) { protected void handleSetupTransport(Parcel parcel) {} @GuardedBy("this") - private final void handlePing(Parcel parcel) { + private final void handlePing(Parcel requestParcel) { + int id = requestParcel.readInt(); if (transportState == TransportState.READY) { - try { - outgoingBinder.transact(PING_RESPONSE, parcel, null, IBinder.FLAG_ONEWAY); + try (ParcelHolder replyParcel = ParcelHolder.obtain()) { + replyParcel.get().writeInt(id); + outgoingBinder.transact(PING_RESPONSE, replyParcel); } catch (RemoteException re) { // Ignore. } @@ -510,21 +496,15 @@ private final void handlePing(Parcel parcel) { protected void handlePingResponse(Parcel parcel) {} @GuardedBy("this") - private void sendAcknowledgeBytes(IBinder iBinder) { + private void sendAcknowledgeBytes(OneWayBinderProxy iBinder) { // Send a transaction to acknowledge reception of incoming data. long n = numIncomingBytes.get(); acknowledgedIncomingBytes = n; - Parcel parcel = Parcel.obtain(); - try { - parcel.writeLong(n); - if (!iBinder.transact(ACKNOWLEDGE_BYTES, parcel, null, IBinder.FLAG_ONEWAY)) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed sending ack bytes transaction"), true); - } + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeLong(n); + iBinder.transact(ACKNOWLEDGE_BYTES, parcel); } catch (RemoteException re) { shutdownInternal(statusFromRemoteException(re), true); - } finally { - parcel.recycle(); } } @@ -588,7 +568,7 @@ public BinderClientTransport( this.securityPolicy = securityPolicy; this.offloadExecutor = offloadExecutorPool.getObject(); numInUseStreams = new AtomicInteger(); - pingTracker = new PingTracker(TimeProvider.SYSTEM_TIME_PROVIDER, (id) -> sendPing(id)); + pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); serviceBinding = new ServiceBinding( @@ -607,7 +587,7 @@ void releaseExecutors() { @Override public synchronized void onBound(IBinder binder) { - sendSetupTransaction(binder); + sendSetupTransaction(OneWayBinderProxy.wrap(binder, offloadExecutor)); } @Override @@ -635,33 +615,39 @@ public synchronized ClientStream newStream( final Metadata headers, final CallOptions callOptions, ClientStreamTracer[] tracers) { - if (isShutdown()) { - return newFailingClientStream(shutdownStatus, attributes, headers, tracers); + if (!inState(TransportState.READY)) { + return newFailingClientStream( + isShutdown() + ? shutdownStatus + : Status.INTERNAL.withDescription("newStream() before transportReady()"), + attributes, + headers, + tracers); + } + + int callId = latestCallId++; + if (latestCallId == LAST_CALL_ID) { + latestCallId = FIRST_CALL_ID; + } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + Inbound.ClientInbound inbound = + new Inbound.ClientInbound( + this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); + if (ongoingCalls.putIfAbsent(callId, inbound) != null) { + Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); + shutdownInternal(failure, true); + return newFailingClientStream(failure, attributes, headers, tracers); } else { - int callId = latestCallId++; - if (latestCallId == LAST_CALL_ID) { - latestCallId = FIRST_CALL_ID; + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - Inbound.ClientInbound inbound = - new Inbound.ClientInbound( - this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); - if (ongoingCalls.putIfAbsent(callId, inbound) != null) { - Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); - shutdownInternal(failure, true); - return newFailingClientStream(failure, attributes, headers, tracers); + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); } else { - if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { - clientTransportListener.transportInUse(true); - } - Outbound.ClientOutbound outbound = - new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); - if (method.getType().clientSendsOneMessage()) { - return new SingleMessageClientStream(inbound, outbound, attributes); - } else { - return new MultiMessageClientStream(inbound, outbound, attributes); - } + return new MultiMessageClientStream(inbound, outbound, attributes); } } } @@ -742,7 +728,7 @@ private void checkSecurityPolicy(IBinder binder) { if (inState(TransportState.SETUP)) { if (!authorization.isOk()) { shutdownInternal(authorization, true); - } else if (!setOutgoingBinder(binder)) { + } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { shutdownInternal( Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); } else { @@ -821,7 +807,8 @@ public BinderServerTransport( IBinder callbackBinder) { super(executorServicePool, attributes, buildLogId(attributes)); this.streamTracerFactories = streamTracerFactories; - setOutgoingBinder(callbackBinder); + // TODO(jdcormie): Plumb in the Server's executor() and use it here instead. + setOutgoingBinder(OneWayBinderProxy.wrap(callbackBinder, getScheduledExecutorService())); } public synchronized void setServerTransportListener(ServerTransportListener serverTransportListener) { diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java index d4c8e48f953..b968e744685 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -40,7 +40,7 @@ public final class BinderTransportSecurity { private static final Attributes.Key TRANSPORT_AUTHORIZATION_STATE = - Attributes.Key.create("transport-authorization-state"); + Attributes.Key.create("internal:transport-authorization-state"); private BinderTransportSecurity() {} diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index da1a2961546..5ab96085a41 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -468,7 +468,7 @@ public final synchronized InputStream next() { if (firstMessage != null) { stream = firstMessage; firstMessage = null; - } else if (messageAvailable()) { + } else if (numRequestedMessages > 0 && messageAvailable()) { stream = assembleNextMessage(); } if (stream != null) { diff --git a/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java b/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java new file mode 100644 index 00000000000..09b45dd9936 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java @@ -0,0 +1,142 @@ +package io.grpc.binder.internal; + +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import android.os.RemoteException; +import io.grpc.internal.SerializingExecutor; +import java.util.concurrent.Executor; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Wraps an {@link IBinder} with a safe and uniformly asynchronous transaction API. + * + *

When the target of your bindService() call is hosted in a different process, Android supplies + * you with an {@link IBinder} that proxies your transactions to the remote {@link + * android.os.Binder} instance. But when the target Service is hosted in the same process, Android + * supplies you with that local instance of {@link android.os.Binder} directly. This in-process + * implementation of {@link IBinder} is problematic for clients that want "oneway" transaction + * semantics because its transact() method simply invokes onTransact() on the caller's thread, even + * when the {@link IBinder#FLAG_ONEWAY} flag is set. Even though this behavior is documented, its + * consequences with respect to reentrancy, locking, and transaction dispatch order can be + * surprising and dangerous. + * + *

Wrap your {@link IBinder}s with an instance of this class to ensure the following + * out-of-process "oneway" semantics are always in effect: + * + *

    + *
  • transact() merely enqueues the transaction for processing. It doesn't wait for onTransact() + * to complete. + *
  • transact() may fail for programming errors or transport-layer errors that are immediately + * obvious on the caller's side, but never for an Exception or false return value from + * onTransact(). + *
  • onTransact() runs without holding any of the locks held by the thread calling transact(). + *
  • onTransact() calls are dispatched one at a time in the same happens-before order as the + * corresponding calls to transact(). + *
+ * + *

NB: One difference that this class can't conceal is that calls to onTransact() are serialized + * per {@link OneWayBinderProxy} instance, not per instance of the wrapped {@link IBinder}. An + * android.os.Binder with in-process callers could still receive concurrent calls to onTransact() on + * different threads if callers used different {@link OneWayBinderProxy} instances or if that Binder + * also had out-of-process callers. + */ +public abstract class OneWayBinderProxy { + private static final Logger logger = Logger.getLogger(OneWayBinderProxy.class.getName()); + protected final IBinder delegate; + + private OneWayBinderProxy(IBinder iBinder) { + this.delegate = iBinder; + } + + /** + * Returns a new instance of {@link OneWayBinderProxy} that wraps {@code iBinder}. + * + * @param iBinder the binder to wrap + * @param inProcessThreadHopExecutor a non-direct Executor used to dispatch calls to onTransact(), + * if necessary + * @return a new instance of {@link OneWayBinderProxy} + */ + public static OneWayBinderProxy wrap(IBinder iBinder, Executor inProcessThreadHopExecutor) { + return (iBinder instanceof Binder) + ? new InProcessImpl(iBinder, inProcessThreadHopExecutor) + : new OutOfProcessImpl(iBinder); + } + + /** + * Enqueues a transaction for the wrapped {@link IBinder} with guaranteed "oneway" semantics. + * + *

NB: Unlike {@link IBinder#transact}, implementations of this method take ownership of the + * {@code data} Parcel. When this method returns, {@code data} will normally be empty, but callers + * should still unconditionally {@link ParcelHolder#close()} it to avoid a leak in case they or + * the implementation throws before ownership is transferred. + * + * @param code identifies the type of this transaction + * @param data a non-empty container of the Parcel to be sent + * @throws RemoteException if the transaction could not even be queued for dispatch on the server. + * Failures from {@link Binder#onTransact} are *never* reported this way. + */ + public abstract void transact(int code, ParcelHolder data) throws RemoteException; + + /** + * Returns the wrapped {@link IBinder} for the purpose of calling methods other than {@link + * IBinder#transact(int, Parcel, Parcel, int)}. + */ + public IBinder getDelegate() { + return delegate; + } + + static class OutOfProcessImpl extends OneWayBinderProxy { + OutOfProcessImpl(IBinder iBinder) { + super(iBinder); + } + + @Override + public void transact(int code, ParcelHolder data) throws RemoteException { + if (!transactAndRecycleParcel(code, data.release())) { + // This cannot happen (see g/android-binder/c/jM4NvS234Rw) but, just in case, let the caller + // handle it along with all the other possible transport-layer errors. + throw new RemoteException("BinderProxy#transact(" + code + ", FLAG_ONEWAY) returned false"); + } + } + } + + protected boolean transactAndRecycleParcel(int code, Parcel data) throws RemoteException { + try { + return delegate.transact(code, data, null, IBinder.FLAG_ONEWAY); + } finally { + data.recycle(); + } + } + + static class InProcessImpl extends OneWayBinderProxy { + private final SerializingExecutor executor; + + InProcessImpl(IBinder binder, Executor executor) { + super(binder); + this.executor = new SerializingExecutor(executor); + } + + @Override + public void transact(int code, ParcelHolder wrappedParcel) { + // Transfer ownership, taking care to handle any RuntimeException from execute(). + Parcel parcel = wrappedParcel.get(); + executor.execute( + () -> { + try { + if (!transactAndRecycleParcel(code, parcel)) { + // onTransact() in our same process returned this. Ignore it, just like Android + // would have if the android.os.Binder was in another process. + logger.log(Level.FINEST, "A oneway transaction was not understood - ignoring"); + } + } catch (Exception e) { + // onTransact() in our same process threw this. Ignore it, just like Android would + // have if the android.os.Binder was in another process. + logger.log(Level.FINEST, "A oneway transaction threw - ignoring", e); + } + }); + wrappedParcel.release(); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/Outbound.java b/binder/src/main/java/io/grpc/binder/internal/Outbound.java index 2a4e968b1e8..e2896be02a1 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Outbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Outbound.java @@ -221,15 +221,14 @@ final void send() throws StatusException { @GuardedBy("this") @SuppressWarnings("fallthrough") protected final void sendInternal() throws StatusException { - Parcel parcel = Parcel.obtain(); - int flags = 0; - parcel.writeInt(0); // Placeholder for flags. Will be filled in below. - parcel.writeInt(transactionIndex++); - try { + try (ParcelHolder parcel = ParcelHolder.obtain()) { + int flags = 0; + parcel.get().writeInt(0); // Placeholder for flags. Will be filled in below. + parcel.get().writeInt(transactionIndex++); switch (outboundState) { case INITIAL: flags |= TransactionUtils.FLAG_PREFIX; - flags |= writePrefix(parcel); + flags |= writePrefix(parcel.get()); onOutboundState(State.PREFIX_SENT); if (!messageAvailable() && !suffixReady) { break; @@ -239,7 +238,7 @@ protected final void sendInternal() throws StatusException { InputStream messageStream = peekNextMessage(); if (messageStream != null) { flags |= TransactionUtils.FLAG_MESSAGE_DATA; - flags |= writeMessageData(parcel, messageStream); + flags |= writeMessageData(parcel.get(), messageStream); } else { checkState(suffixReady); } @@ -252,20 +251,19 @@ protected final void sendInternal() throws StatusException { // Fall-through. case ALL_MESSAGES_SENT: flags |= TransactionUtils.FLAG_SUFFIX; - flags |= writeSuffix(parcel); + flags |= writeSuffix(parcel.get()); onOutboundState(State.SUFFIX_SENT); break; default: throw new AssertionError(); } - TransactionUtils.fillInFlags(parcel, flags); + TransactionUtils.fillInFlags(parcel.get(), flags); + int dataSize = parcel.get().dataSize(); transport.sendTransaction(callId, parcel); - statsTraceContext.outboundWireSize(parcel.dataSize()); - statsTraceContext.outboundUncompressedSize(parcel.dataSize()); + statsTraceContext.outboundWireSize(dataSize); + statsTraceContext.outboundUncompressedSize(dataSize); } catch (IOException e) { throw Status.INTERNAL.withCause(e).asException(); - } finally { - parcel.recycle(); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/ParcelHolder.java b/binder/src/main/java/io/grpc/binder/internal/ParcelHolder.java new file mode 100644 index 00000000000..cad8da9712f --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/ParcelHolder.java @@ -0,0 +1,76 @@ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; + +import android.os.Parcel; +import com.google.common.annotations.VisibleForTesting; +import java.io.Closeable; +import javax.annotation.Nullable; + +/** + * Wraps a {@link Parcel} from the static {@link Parcel#obtain()} pool with methods that make it + * easy to eventually {@link Parcel#recycle()} it. + */ +class ParcelHolder implements Closeable { + + @Nullable private Parcel parcel; + + /** + * Creates a new instance that owns a {@link Parcel} newly obtained from Android's object pool. + */ + public static ParcelHolder obtain() { + return new ParcelHolder(Parcel.obtain()); + } + + /** Creates a new instance taking ownership of the specified {@code parcel}. */ + public ParcelHolder(Parcel parcel) { + this.parcel = parcel; + } + + /** + * Returns the wrapped {@link Parcel} if we still own it. + * + * @throws IllegalStateException if ownership has already been given up by {@link #release()} + */ + public Parcel get() { + checkState(parcel != null, "get() after close()/release()"); + return parcel; + } + + /** + * Returns the wrapped {@link Parcel} and releases ownership of it. + * + * @throws IllegalStateException if ownership has already been given up by {@link #release()} + */ + public Parcel release() { + Parcel result = get(); + this.parcel = null; + return result; + } + + /** + * Recycles the wrapped {@link Parcel} to Android's object pool, if we still own it. + * + *

Otherwise, this method has no effect. + */ + @Override + public void close() { + if (parcel != null) { + parcel.recycle(); + parcel = null; + } + } + + /** + * Returns true iff this container no longer owns a {@link Parcel}. + * + *

{@link #isEmpty()} is true after all call to {@link #close()} or {@link #release()}. + * + *

Typically only used for debugging or testing since Parcel-owning code should be calling + * {@link #close()} unconditionally. + */ + @VisibleForTesting + public boolean isEmpty() { + return parcel == null; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java index 911412d0b66..640d6006824 100644 --- a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java +++ b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java @@ -16,10 +16,10 @@ package io.grpc.binder.internal; +import com.google.common.base.Ticker; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientTransport.PingCallback; -import io.grpc.internal.TimeProvider; import java.util.concurrent.Executor; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -38,7 +38,7 @@ interface PingSender { void sendPing(int id) throws StatusException; } - private final TimeProvider timeProvider; + private final Ticker ticker; private final PingSender pingSender; @GuardedBy("this") @@ -48,8 +48,8 @@ interface PingSender { @GuardedBy("this") private int nextPingId; - PingTracker(TimeProvider timeProvider, PingSender pingSender) { - this.timeProvider = timeProvider; + PingTracker(Ticker ticker, PingSender pingSender) { + this.ticker = ticker; this.pingSender = pingSender; } @@ -93,7 +93,7 @@ private final class Ping { this.callback = callback; this.executor = executor; this.id = id; - this.startTimeNanos = timeProvider.currentTimeNanos(); + this.startTimeNanos = ticker.read(); } private synchronized void fail(Status status) { @@ -107,7 +107,7 @@ private synchronized void success() { if (!done) { done = true; executor.execute( - () -> callback.onSuccess(timeProvider.currentTimeNanos() - startTimeNanos)); + () -> callback.onSuccess(ticker.read() - startTimeNanos)); } } } diff --git a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java index 722af081bbc..8c7bc83d214 100644 --- a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java +++ b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java @@ -16,6 +16,7 @@ package io.grpc.binder; +import static android.content.Intent.URI_ANDROID_APP_SCHEME; import static com.google.common.truth.Truth.assertThat; import android.content.ComponentName; @@ -24,9 +25,11 @@ import android.net.Uri; import androidx.test.core.app.ApplicationProvider; import com.google.common.testing.EqualsTester; +import java.net.URISyntaxException; import org.junit.Test; import org.junit.runner.RunWith; import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.Config; @RunWith(RobolectricTestRunner.class) public final class AndroidComponentAddressTest { @@ -60,6 +63,30 @@ public void testAsBindIntent() { assertThat(addr.asBindIntent().filterEquals(bindIntent)).isTrue(); } + @Test + @Config(sdk = 30) + public void testAsAndroidAppUriSdk30() throws URISyntaxException { + AndroidComponentAddress addr = + AndroidComponentAddress.forRemoteComponent("com.foo", "com.foo.Service"); + AndroidComponentAddress addrClone = + AndroidComponentAddress.forBindIntent( + Intent.parseUri(addr.asAndroidAppUri(), URI_ANDROID_APP_SCHEME)); + assertThat(addr).isEqualTo(addrClone); + } + + @Test + @Config(sdk = 29) + public void testAsAndroidAppUriSdk29() throws URISyntaxException { + AndroidComponentAddress addr = + AndroidComponentAddress.forRemoteComponent("com.foo", "com.foo.Service"); + AndroidComponentAddress addrClone = + AndroidComponentAddress.forBindIntent( + Intent.parseUri(addr.asAndroidAppUri(), URI_ANDROID_APP_SCHEME)); + // Can't test for equality because URI_ANDROID_APP_SCHEME adds a (redundant) package filter. + assertThat(addr.getComponent()).isEqualTo(addrClone.getComponent()); + assertThat(addr.getAuthority()).isEqualTo(addrClone.getAuthority()); + } + @Test public void testEquality() { new EqualsTester() @@ -85,4 +112,35 @@ public void testEquality() { .setComponent(hostComponent))) .testEquals(); } + + @Test + @Config(sdk = 30) + public void testPackageFilterEquality30AndUp() { + new EqualsTester() + .addEqualityGroup( + AndroidComponentAddress.forBindIntent( + new Intent().setAction("action").setComponent(new ComponentName("pkg", "cls"))), + AndroidComponentAddress.forBindIntent( + new Intent() + .setAction("action") + .setPackage("pkg") + .setComponent(new ComponentName("pkg", "cls")))) + .testEquals(); + } + + @Test + @Config(sdk = 29) + public void testPackageFilterEqualityPre30() { + new EqualsTester() + .addEqualityGroup( + AndroidComponentAddress.forBindIntent( + new Intent().setAction("action").setComponent(new ComponentName("pkg", "cls")))) + .addEqualityGroup( + AndroidComponentAddress.forBindIntent( + new Intent() + .setAction("action") + .setPackage("pkg") + .setComponent(new ComponentName("pkg", "cls")))) + .testEquals(); + } } diff --git a/binder/src/test/java/io/grpc/binder/BinderChannelBuilderTest.java b/binder/src/test/java/io/grpc/binder/BinderChannelBuilderTest.java new file mode 100644 index 00000000000..5dd7e13107e --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/BinderChannelBuilderTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static org.junit.Assert.fail; + +import android.content.Context; +import androidx.test.core.app.ApplicationProvider; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class BinderChannelBuilderTest { + private final Context appContext = ApplicationProvider.getApplicationContext(); + private final AndroidComponentAddress addr = AndroidComponentAddress.forContext(appContext); + + @Test + public void strictLifecycleManagementForbidsIdleTimers() { + BinderChannelBuilder builder = BinderChannelBuilder.forAddress(addr, appContext); + builder.strictLifecycleManagement(); + try { + builder.idleTimeout(10, TimeUnit.SECONDS); + fail(); + } catch (IllegalStateException ise) { + // Expected. + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java index 86edb5ad7df..6a692090549 100644 --- a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java +++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java @@ -16,22 +16,36 @@ package io.grpc.binder; +import static android.Manifest.permission.ACCESS_COARSE_LOCATION; +import static android.Manifest.permission.ACCESS_FINE_LOCATION; +import static android.Manifest.permission.WRITE_EXTERNAL_STORAGE; +import static android.content.pm.PackageInfo.REQUESTED_PERMISSION_GRANTED; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static org.robolectric.Shadows.shadowOf; +import android.app.admin.DevicePolicyManager; +import android.content.ComponentName; import android.content.Context; import android.content.pm.PackageInfo; import android.content.pm.PackageManager; import android.content.pm.Signature; +import android.os.Build; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; import android.os.Process; import androidx.test.core.app.ApplicationProvider; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.hash.Hashing; import io.grpc.Status; -import io.grpc.binder.SecurityPolicy; +import java.util.HashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.Config; @RunWith(RobolectricTestRunner.class) public final class SecurityPoliciesTest { @@ -39,7 +53,6 @@ public final class SecurityPoliciesTest { private static final int MY_UID = Process.myUid(); private static final int OTHER_UID = MY_UID + 1; private static final int OTHER_UID_SAME_SIGNATURE = MY_UID + 2; - private static final int OTHER_UID_NO_SIGNATURE = MY_UID + 3; private static final int OTHER_UID_UNKNOWN = MY_UID + 4; private static final String PERMISSION_DENIED_REASONS = "some reasons"; @@ -49,10 +62,10 @@ public final class SecurityPoliciesTest { private static final String OTHER_UID_PACKAGE_NAME = "other.package"; private static final String OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME = "other.package.samesignature"; - private static final String OTHER_UID_NO_SIGNATURE_PACKAGE_NAME = "other.package.nosignature"; private Context appContext; private PackageManager packageManager; + private DevicePolicyManager devicePolicyManager; private SecurityPolicy policy; @@ -60,24 +73,24 @@ public final class SecurityPoliciesTest { public void setUp() { appContext = ApplicationProvider.getApplicationContext(); packageManager = appContext.getPackageManager(); - installPackage(MY_UID, appContext.getPackageName(), SIG1); - installPackage(OTHER_UID, OTHER_UID_PACKAGE_NAME, SIG2); - installPackage(OTHER_UID_SAME_SIGNATURE, OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME, SIG1); - installPackage(OTHER_UID_NO_SIGNATURE, OTHER_UID_NO_SIGNATURE_PACKAGE_NAME); + devicePolicyManager = + (DevicePolicyManager) appContext.getSystemService(Context.DEVICE_POLICY_SERVICE); } @SuppressWarnings("deprecation") - private void installPackage(int uid, String packageName, Signature... signatures) { - PackageInfo info = new PackageInfo(); - info.packageName = packageName; - info.signatures = signatures; - shadowOf(packageManager).installPackage(info); - shadowOf(packageManager).setPackagesForUid(uid, packageName); + private void installPackages(int uid, PackageInfo... packageInfo) { + String[] packageNames = new String[packageInfo.length]; + for (int i = 0; i < packageInfo.length; i++) { + shadowOf(packageManager).installPackage(packageInfo[i]); + packageNames[i] = packageInfo[i].packageName; + } + shadowOf(packageManager).setPackagesForUid(uid, packageNames); } @Test public void testInternalOnly() throws Exception { policy = SecurityPolicies.internalOnly(); + assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode()); assertThat(policy.checkAuthorization(OTHER_UID).getCode()) .isEqualTo(Status.PERMISSION_DENIED.getCode()); @@ -99,6 +112,11 @@ public void testPermissionDenied() throws Exception { @Test public void testHasSignature_succeedsIfPackageNameAndSignaturesMatch() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + policy = SecurityPolicies.hasSignature(packageManager, OTHER_UID_PACKAGE_NAME, SIG2); // THEN UID for package that has SIG2 will be authorized @@ -107,6 +125,14 @@ public void testHasSignature_succeedsIfPackageNameAndSignaturesMatch() @Test public void testHasSignature_failsIfPackageNameDoesNotMatch() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME) + .setSignatures(SIG1) + .build(); + + installPackages(OTHER_UID_SAME_SIGNATURE, info); + policy = SecurityPolicies.hasSignature(packageManager, appContext.getPackageName(), SIG1); // THEN UID for package that has SIG1 but different package name will not be authorized @@ -116,6 +142,11 @@ public void testHasSignature_failsIfPackageNameDoesNotMatch() throws Exception { @Test public void testHasSignature_failsIfSignatureDoesNotMatch() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + policy = SecurityPolicies.hasSignature(packageManager, OTHER_UID_PACKAGE_NAME, SIG1); // THEN UID for package that doesn't have SIG1 will not be authorized @@ -126,6 +157,11 @@ public void testHasSignature_failsIfSignatureDoesNotMatch() throws Exception { @Test public void testOneOfSignatures_succeedsIfPackageNameAndSignaturesMatch() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + policy = SecurityPolicies.oneOfSignatures( packageManager, OTHER_UID_PACKAGE_NAME, ImmutableList.of(SIG2)); @@ -136,6 +172,14 @@ public void testOneOfSignatures_succeedsIfPackageNameAndSignaturesMatch() @Test public void testOneOfSignature_failsIfAllSignaturesDoNotMatch() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME) + .setSignatures(SIG1) + .build(); + + installPackages(OTHER_UID_SAME_SIGNATURE, info); + policy = SecurityPolicies.oneOfSignatures( packageManager, @@ -150,11 +194,14 @@ public void testOneOfSignature_failsIfAllSignaturesDoNotMatch() throws Exception @Test public void testOneOfSignature_succeedsIfPackageNameAndOneOfSignaturesMatch() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + policy = SecurityPolicies.oneOfSignatures( - packageManager, - OTHER_UID_PACKAGE_NAME, - ImmutableList.of(SIG1, SIG2)); + packageManager, OTHER_UID_PACKAGE_NAME, ImmutableList.of(SIG1, SIG2)); // THEN UID for package that has SIG2 will be authorized assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); @@ -164,11 +211,531 @@ public void testOneOfSignature_succeedsIfPackageNameAndOneOfSignaturesMatch() public void testHasSignature_failsIfUidUnknown() throws Exception { policy = SecurityPolicies.hasSignature( - packageManager, - appContext.getPackageName(), - SIG1); + packageManager, + appContext.getPackageName(), + SIG1); assertThat(policy.checkAuthorization(OTHER_UID_UNKNOWN).getCode()) - .isEqualTo(Status.UNAUTHENTICATED.getCode()); + .isEqualTo(Status.UNAUTHENTICATED.getCode()); + } + + @Test + public void testHasPermissions_sharedUserId_succeedsIfAllPackageHavePermissions() + throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(ACCESS_COARSE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .build(); + + PackageInfo infoSamePerms = + newBuilder() + .setPackageName(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(ACCESS_COARSE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .build(); + + installPackages(OTHER_UID, info, infoSamePerms); + + policy = + SecurityPolicies.hasPermissions( + packageManager, ImmutableSet.of(ACCESS_FINE_LOCATION, ACCESS_COARSE_LOCATION)); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void testHasPermissions_sharedUserId_failsIfOnePackageHasNoPermissions() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .build(); + + PackageInfo infoNoPerms = + newBuilder() + .setPackageName(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, 0) + .build(); + + installPackages(OTHER_UID, info, infoNoPerms); + + policy = SecurityPolicies.hasPermissions(packageManager, ImmutableSet.of(ACCESS_FINE_LOCATION)); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(ACCESS_FINE_LOCATION); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME); + } + + @Test + public void testHasPermissions_succeedsIfPackageHasPermissions() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(ACCESS_COARSE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(WRITE_EXTERNAL_STORAGE, 0) + .build(); + + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.hasPermissions( + packageManager, ImmutableSet.of(ACCESS_FINE_LOCATION, ACCESS_COARSE_LOCATION)); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void testHasPermissions_failsIfPackageDoesNotHaveOnePermission() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(ACCESS_COARSE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(WRITE_EXTERNAL_STORAGE, 0) + .build(); + + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.hasPermissions( + packageManager, ImmutableSet.of(ACCESS_FINE_LOCATION, WRITE_EXTERNAL_STORAGE)); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(WRITE_EXTERNAL_STORAGE); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(OTHER_UID_PACKAGE_NAME); + } + + @Test + public void testHasPermissions_failsIfPackageDoesNotHavePermissions() throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setPermission(ACCESS_FINE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(ACCESS_COARSE_LOCATION, REQUESTED_PERMISSION_GRANTED) + .setPermission(WRITE_EXTERNAL_STORAGE, 0) + .build(); + + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.hasPermissions(packageManager, ImmutableSet.of(WRITE_EXTERNAL_STORAGE)); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(WRITE_EXTERNAL_STORAGE); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .contains(OTHER_UID_PACKAGE_NAME); + } + + @Test + @Config(sdk = 18) + public void testIsDeviceOwner_succeedsForDeviceOwner() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + shadowOf(devicePolicyManager) + .setDeviceOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + + policy = SecurityPolicies.isDeviceOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + @Config(sdk = 18) + public void testIsDeviceOwner_failsForNotDeviceOwner() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isDeviceOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 18) + public void testIsDeviceOwner_failsWhenNoPackagesForUid() throws Exception { + policy = SecurityPolicies.isDeviceOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + } + + @Test + @Config(sdk = 17) + public void testIsDeviceOwner_failsForSdkLevelTooLow() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isDeviceOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 21) + public void testIsProfileOwner_succeedsForProfileOwner() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + shadowOf(devicePolicyManager) + .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + + policy = SecurityPolicies.isProfileOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + @Config(sdk = 21) + public void testIsProfileOwner_failsForNotProfileOwner() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isProfileOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 21) + public void testIsProfileOwner_failsWhenNoPackagesForUid() throws Exception { + policy = SecurityPolicies.isProfileOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + } + + @Test + @Config(sdk = 19) + public void testIsProfileOwner_failsForSdkLevelTooLow() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isProfileOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 30) + public void testIsProfileOwnerOnOrgOwned_succeedsForProfileOwnerOnOrgOwned() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + shadowOf(devicePolicyManager) + .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setOrganizationOwnedDeviceWithManagedProfile(true); + + policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + + } + + @Test + @Config(sdk = 30) + public void testIsProfileOwnerOnOrgOwned_failsForProfileOwnerOnNonOrgOwned() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + shadowOf(devicePolicyManager) + .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setOrganizationOwnedDeviceWithManagedProfile(false); + + policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 21) + public void testIsProfileOwnerOnOrgOwned_failsForNotProfileOwner() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + @Config(sdk = 21) + public void testIsProfileOwnerOnOrgOwned_failsWhenNoPackagesForUid() throws Exception { + policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + } + + @Test + @Config(sdk = 29) + public void testIsProfileOwnerOnOrgOwned_failsForSdkLevelTooLow() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + + installPackages(OTHER_UID, info); + + policy = SecurityPolicies.isProfileOwner(appContext); + + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + private static PackageInfoBuilder newBuilder() { + return new PackageInfoBuilder(); + } + + private static class PackageInfoBuilder { + private String packageName; + private Signature[] signatures; + private final HashMap permissions = new HashMap<>(); + + public PackageInfoBuilder setPackageName(String packageName) { + this.packageName = packageName; + return this; + } + + public PackageInfoBuilder setPermission(String permissionName, int permissionFlag) { + this.permissions.put(permissionName, permissionFlag); + return this; + } + + public PackageInfoBuilder setSignatures(Signature... signatures) { + this.signatures = signatures; + return this; + } + + public PackageInfo build() { + checkState(this.packageName != null, "packageName is a mandatory field"); + + PackageInfo packageInfo = new PackageInfo(); + + packageInfo.packageName = this.packageName; + + if (this.signatures != null) { + packageInfo.signatures = this.signatures; + } + + if (!this.permissions.isEmpty()) { + String[] requestedPermissions = + this.permissions.keySet().toArray(new String[this.permissions.size()]); + int[] requestedPermissionsFlags = new int[requestedPermissions.length]; + + for (int i = 0; i < requestedPermissions.length; i++) { + requestedPermissionsFlags[i] = this.permissions.get(requestedPermissions[i]); + } + + packageInfo.requestedPermissions = requestedPermissions; + packageInfo.requestedPermissionsFlags = requestedPermissionsFlags; + } + + return packageInfo; + } + } + + @Test + public void testAllOf_succeedsIfAllSecurityPoliciesAllowed() throws Exception { + policy = SecurityPolicies.allOf(SecurityPolicies.internalOnly()); + + assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void testAllOf_failsIfOneSecurityPoliciesNotAllowed() throws Exception { + policy = + SecurityPolicies.allOf( + SecurityPolicies.internalOnly(), + SecurityPolicies.permissionDenied("Not allowed SecurityPolicy")); + + assertThat(policy.checkAuthorization(MY_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(MY_UID).getDescription()) + .contains("Not allowed SecurityPolicy"); + } + + @Test + public void testAnyOf_succeedsIfAnySecurityPoliciesAllowed() throws Exception { + RecordingPolicy recordingPolicy = new RecordingPolicy(); + policy = SecurityPolicies.anyOf(SecurityPolicies.internalOnly(), recordingPolicy); + + assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode()); + assertThat(recordingPolicy.numCalls.get()).isEqualTo(0); + } + + @Test + public void testAnyOf_failsIfNoSecurityPolicyIsAllowed() throws Exception { + policy = + SecurityPolicies.anyOf( + new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return Status.PERMISSION_DENIED.withDescription("Not allowed: first"); + } + }, + new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return Status.UNAUTHENTICATED.withDescription("Not allowed: second"); + } + }); + + assertThat(policy.checkAuthorization(MY_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(MY_UID).getDescription()).contains("Not allowed: first"); + assertThat(policy.checkAuthorization(MY_UID).getDescription()).contains("Not allowed: second"); + } + + private static final class RecordingPolicy extends SecurityPolicy { + private final AtomicInteger numCalls = new AtomicInteger(0); + + @Override + public Status checkAuthorization(int uid) { + numCalls.incrementAndGet(); + return Status.OK; + } + } + + @Test + public void testHasSignatureSha256Hash_succeedsIfPackageNameAndSignatureHashMatch() + throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.hasSignatureSha256Hash( + packageManager, OTHER_UID_PACKAGE_NAME, getSha256Hash(SIG2)); + + // THEN UID for package that has SIG2 will be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void testHasSignatureSha256Hash_failsIfPackageNameDoesNotMatch() throws Exception { + PackageInfo info1 = + newBuilder().setPackageName(appContext.getPackageName()).setSignatures(SIG1).build(); + installPackages(MY_UID, info1); + + PackageInfo info2 = + newBuilder() + .setPackageName(OTHER_UID_SAME_SIGNATURE_PACKAGE_NAME) + .setSignatures(SIG1) + .build(); + installPackages(OTHER_UID_SAME_SIGNATURE, info2); + + policy = + SecurityPolicies.hasSignatureSha256Hash( + packageManager, appContext.getPackageName(), getSha256Hash(SIG1)); + + // THEN UID for package that has SIG1 but different package name will not be authorized + assertThat(policy.checkAuthorization(OTHER_UID_SAME_SIGNATURE).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testHasSignatureSha256Hash_failsIfSignatureHashDoesNotMatch() throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.hasSignatureSha256Hash( + packageManager, OTHER_UID_PACKAGE_NAME, getSha256Hash(SIG1)); + + // THEN UID for package that doesn't have SIG1 will not be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testOneOfSignatureSha256Hash_succeedsIfPackageNameAndSignatureHashMatch() + throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.oneOfSignatureSha256Hash( + packageManager, OTHER_UID_PACKAGE_NAME, ImmutableList.of(getSha256Hash(SIG2))); + + // THEN UID for package that has SIG2 will be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void testOneOfSignatureSha256Hash_succeedsIfPackageNameAndOneOfSignatureHashesMatch() + throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.oneOfSignatureSha256Hash( + packageManager, + OTHER_UID_PACKAGE_NAME, + ImmutableList.of(getSha256Hash(SIG1), getSha256Hash(SIG2))); + + // THEN UID for package that has SIG2 will be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); + } + + @Test + public void + testOneOfSignatureSha256Hash_failsIfPackageNameDoNotMatchAndOneOfSignatureHashesMatch() + throws Exception { + PackageInfo info = + newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.oneOfSignatureSha256Hash( + packageManager, + appContext.getPackageName(), + ImmutableList.of(getSha256Hash(SIG1), getSha256Hash(SIG2))); + + // THEN UID for package that has SIG2 but different package name will not be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testOneOfSignatureSha256Hash_failsIfPackageNameMatchAndOneOfSignatureHashesNotMatch() + throws Exception { + PackageInfo info = + newBuilder() + .setPackageName(OTHER_UID_PACKAGE_NAME) + .setSignatures(new Signature("1234")) + .build(); + installPackages(OTHER_UID, info); + + policy = + SecurityPolicies.oneOfSignatureSha256Hash( + packageManager, + appContext.getPackageName(), + ImmutableList.of(getSha256Hash(SIG1), getSha256Hash(SIG2))); + + // THEN UID for package that doesn't have SIG1 or SIG2 will not be authorized + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + private static byte[] getSha256Hash(Signature signature) { + return Hashing.sha256().hashBytes(signature.toByteArray()).asBytes(); } } diff --git a/binder/src/test/java/io/grpc/binder/internal/OneWayBinderProxyTest.java b/binder/src/test/java/io/grpc/binder/internal/OneWayBinderProxyTest.java new file mode 100644 index 00000000000..167e33e29a5 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/OneWayBinderProxyTest.java @@ -0,0 +1,197 @@ +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import android.os.RemoteException; +import io.grpc.binder.internal.OneWayBinderProxy.InProcessImpl; +import io.grpc.binder.internal.OneWayBinderProxy.OutOfProcessImpl; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; + +/** Unit tests for the {@link OneWayBinderProxy} implementations. */ +@RunWith(RobolectricTestRunner.class) +public class OneWayBinderProxyTest { + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + + QueuingExecutor queuingExecutor = new QueuingExecutor(); + + @Mock IBinder mockBinder; + + RecordingBinder recordingBinder = new RecordingBinder(); + + @Test + public void shouldProxyInProcessTransactionsOnExecutor() throws RemoteException { + InProcessImpl proxy = new InProcessImpl(recordingBinder, queuingExecutor); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(123); + proxy.transact(456, parcel); + assertThat(parcel.isEmpty()).isTrue(); + assertThat(recordingBinder.txnLog).isEmpty(); + queuingExecutor.runAllQueued(); + assertThat(recordingBinder.txnLog).hasSize(1); + assertThat(recordingBinder.txnLog.get(0).argument).isEqualTo(123); + assertThat(recordingBinder.txnLog.get(0).code).isEqualTo(456); + assertThat(recordingBinder.txnLog.get(0).flags).isEqualTo(IBinder.FLAG_ONEWAY); + } + } + + @Test + public void shouldNotLeakParcelsInCaseOfRejectedExecution() throws RemoteException { + InProcessImpl proxy = new InProcessImpl(recordingBinder, queuingExecutor); + queuingExecutor.shutdown(); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(123); + assertThrows(RejectedExecutionException.class, () -> proxy.transact(123, parcel)); + assertThat(parcel.isEmpty()).isFalse(); // Parcel didn't leak because we still own it. + } + } + + @Test + public void shouldProxyOutOfProcessTransactionsSynchronously() throws RemoteException { + OutOfProcessImpl proxy = new OutOfProcessImpl(recordingBinder); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + parcel.get().writeInt(123); + proxy.transact(456, parcel); + assertThat(parcel.isEmpty()).isTrue(); + assertThat(recordingBinder.txnLog).hasSize(1); + assertThat(recordingBinder.txnLog.get(0).argument).isEqualTo(123); + assertThat(recordingBinder.txnLog.get(0).code).isEqualTo(456); + assertThat(recordingBinder.txnLog.get(0).flags).isEqualTo(IBinder.FLAG_ONEWAY); + } + } + + @Test + public void shouldIgnoreInProcessRemoteExceptions() throws RemoteException { + when(mockBinder.transact(anyInt(), any(), any(), anyInt())).thenThrow(RemoteException.class); + InProcessImpl proxy = new InProcessImpl(mockBinder, queuingExecutor); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + proxy.transact(123, parcel); // Doesn't throw. + verify(mockBinder, never()).transact(anyInt(), any(), any(), anyInt()); + queuingExecutor.runAllQueued(); + } + } + + @Test + public void shouldExposeOutOfProcessRemoteExceptions() throws RemoteException { + when(mockBinder.transact(anyInt(), any(), any(), anyInt())).thenThrow(RemoteException.class); + OutOfProcessImpl proxy = new OutOfProcessImpl(mockBinder); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + assertThrows(RemoteException.class, () -> proxy.transact(123, parcel)); + } + } + + @Test + public void shouldIgnoreUnknownTransactionReturnValueInProcess() throws RemoteException { + when(mockBinder.transact(anyInt(), any(), any(), anyInt())).thenReturn(false); + InProcessImpl proxy = new InProcessImpl(mockBinder, queuingExecutor); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + proxy.transact(123, parcel); // Doesn't throw. + verify(mockBinder, never()).transact(anyInt(), any(), any(), anyInt()); + queuingExecutor.runAllQueued(); + verify(mockBinder).transact(eq(123), any(), any(), anyInt()); + } + } + + @Test + public void shouldReportImpossibleUnknownTransactionReturnValueOutOfProcess() + throws RemoteException { + when(mockBinder.transact(anyInt(), any(), any(), anyInt())).thenReturn(false); + OutOfProcessImpl proxy = new OutOfProcessImpl(mockBinder); + try (ParcelHolder parcel = ParcelHolder.obtain()) { + assertThrows(RemoteException.class, () -> proxy.transact(123, parcel)); + verify(mockBinder).transact(eq(123), any(), any(), anyInt()); + } + } + + /** An Executor that queues up Runnables for later manual execution by a unit test. */ + static class QueuingExecutor implements Executor { + private final Queue runnables = new ArrayDeque<>(); + private volatile boolean isShutdown; + + @Override + public void execute(Runnable r) { + if (isShutdown) { + throw new RejectedExecutionException(); + } + runnables.add(r); + } + + public void runAllQueued() { + Runnable next = null; + while ((next = runnables.poll()) != null) { + next.run(); + } + } + + public void shutdown() { + isShutdown = true; + } + } + + /** An immutable record of a call to {@link IBinder#transact(int, Parcel, Parcel, int)}. */ + static class TransactionRecord { + private final int code; + private final int argument; + private final int flags; + + private TransactionRecord(int code, int argument, int flags) { + this.code = code; + this.argument = argument; + this.flags = flags; + } + } + + /** A {@link Binder} that simply records every transaction it receives. */ + static class RecordingBinder extends Binder { + private final ArrayList txnLog = new ArrayList<>(); + + @Override + protected boolean onTransact(int code, Parcel data, Parcel reply, int flags) + throws RemoteException { + txnLog.add(new TransactionRecord(code, data.readInt(), flags)); + return true; + } + } + + interface ThrowingRunnable { + void run() throws Throwable; + } + + // TODO(jdcormie): Replace with Assert.assertThrows() once we upgrade to junit 4.13. + private static T assertThrows( + Class expectedThrowable, ThrowingRunnable runnable) { + try { + runnable.run(); + } catch (Throwable actualThrown) { + if (expectedThrowable.isInstance(actualThrown)) { + @SuppressWarnings("unchecked") + T retVal = (T) actualThrown; + return retVal; + } else { + AssertionError assertionError = new AssertionError("Unexpected type thrown"); + assertionError.initCause(actualThrown); + throw assertionError; + } + } + throw new AssertionError("Expected " + expectedThrowable + " but nothing was thrown"); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java index e17734baab8..60e7c163105 100644 --- a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java @@ -50,7 +50,7 @@ public void setUp() { callback = new TestCallback(); pingTracker = new PingTracker( - clock.getTimeProvider(), + clock.getTicker(), (id) -> { sentPings.add(id); if (pingFailureStatus != null) { diff --git a/bom/build.gradle b/bom/build.gradle index be2722d4bc0..1b1f98cff18 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -14,9 +14,7 @@ publishing { // Generate bom using subprojects def internalProjects = [ project.name, - 'grpc-authz', 'grpc-compiler', - 'grpc-gae-interop-testing-jdk8', ] def dependencyManagement = asNode().appendNode('dependencyManagement') @@ -25,6 +23,12 @@ publishing { if (internalProjects.contains(subproject.name)) { return } + if (!subproject.hasProperty('publishMavenPublicationToMavenRepository')) { + return + } + if (!subproject.publishMavenPublicationToMavenRepository.enabled) { + return + } def dependencyNode = dependencies.appendNode('dependency') dependencyNode.appendNode('groupId', subproject.group) dependencyNode.appendNode('artifactId', subproject.name) diff --git a/build.gradle b/build.gradle index efd9640e0ba..7270fcdb955 100644 --- a/build.gradle +++ b/build.gradle @@ -4,6 +4,7 @@ plugins { id "com.google.osdetector" apply false id "me.champeau.gradle.japicmp" apply false id "net.ltgt.errorprone" apply false + id 'com.google.cloud.tools.jib' apply false } import net.ltgt.gradle.errorprone.CheckSeverity @@ -19,7 +20,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.45.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.53.0" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() @@ -28,7 +29,7 @@ subprojects { mavenLocal() } - tasks.withType(JavaCompile) { + tasks.withType(JavaCompile).configureEach { it.options.compilerArgs += [ "-Xlint:all", "-Xlint:-options", @@ -41,7 +42,7 @@ subprojects { } } - tasks.withType(GenerateModuleMetadata) { + tasks.withType(GenerateModuleMetadata).configureEach { // Module metadata, introduced in Gradle 6.0, conflicts with our publishing task for // grpc-alts and grpc-compiler. enabled = false @@ -55,14 +56,6 @@ subprojects { protocPluginBaseName = 'protoc-gen-grpc-java' javaPluginPath = "$rootDir/compiler/build/exe/java_plugin/$protocPluginBaseName$exeSuffix" - nettyVersion = '4.1.72.Final' - guavaVersion = '31.0.1-android' - googleauthVersion = '1.4.0' - protobufVersion = '3.19.2' - protocVersion = protobufVersion - opencensusVersion = '0.28.0' - autovalueVersion = '1.9' - configureProtoCompilation = { String generatedSourcePath = "${projectDir}/src/generated" project.protobuf { @@ -70,7 +63,7 @@ subprojects { if (project.hasProperty('protoc')) { path = project.protoc } else { - artifact = "com.google.protobuf:protoc:${protocVersion}" + artifact = libs.protobuf.protoc.get() } } generateProtoTasks { @@ -89,7 +82,7 @@ subprojects { if (rootProject.childProjects.containsKey('grpc-compiler')) { // Only when the codegen is built along with the project, will we be able to run // the grpc code generator. - task syncGeneratedSources { } + def syncGeneratedSources = tasks.register("syncGeneratedSources") { } project.protobuf { plugins { grpc { path = javaPluginPath } } generateProtoTasks { @@ -104,16 +97,20 @@ subprojects { } dependsOn "generate${source}Proto" } - syncGeneratedSources.dependsOn syncTask - - task.dependsOn ':grpc-compiler:java_pluginExecutable' - // Recompile protos when the codegen has been changed - task.inputs.file javaPluginPath - task.plugins { grpc { option 'noversion' } } - if (isAndroid) { - task.plugins { - grpc { - option 'lite' + syncGeneratedSources.configure { + dependsOn syncTask + } + + task.configure { + dependsOn ':grpc-compiler:java_pluginExecutable' + // Recompile protos when the codegen has been changed + inputs.file javaPluginPath + plugins { grpc { option 'noversion' } } + if (isAndroid) { + plugins { + grpc { + option 'lite' + } } } } @@ -121,7 +118,9 @@ subprojects { } } // Re-sync as part of a normal build, to avoid forgetting to run the sync - assemble.dependsOn syncGeneratedSources + tasks.named("assemble").configure { + dependsOn syncGeneratedSources + } } else { // Otherwise, we just use the checked-in generated code. if (isAndroid) { @@ -130,14 +129,13 @@ subprojects { release { java { srcDir "${generatedSourcePath}/release/grpc" } } } } else { - project.sourceSets { - main { java { srcDir "${generatedSourcePath}/main/grpc" } } - test { java { srcDir "${generatedSourcePath}/test/grpc" } } + project.sourceSets.each() { sourceSet -> + sourceSet.java { srcDir "${generatedSourcePath}/${sourceSet.name}/grpc" } } } } - tasks.withType(JavaCompile) { + tasks.withType(JavaCompile).configureEach { appendToProperty( it.options.errorprone.excludedPaths, ".*/src/generated/[^/]+/java/.*" + @@ -146,72 +144,7 @@ subprojects { } } - libraries = [ - android_annotations: "com.google.android:annotations:4.1.1.4", - animalsniffer_annotations: "org.codehaus.mojo:animal-sniffer-annotations:1.19", - autovalue: "com.google.auto.value:auto-value:${autovalueVersion}", - autovalue_annotation: "com.google.auto.value:auto-value-annotations:${autovalueVersion}", - errorprone: "com.google.errorprone:error_prone_annotations:2.10.0", - cronet_api: 'org.chromium.net:cronet-api:92.4515.131', - cronet_embedded: 'org.chromium.net:cronet-embedded:92.4515.131', - gson: "com.google.code.gson:gson:2.8.9", - guava: "com.google.guava:guava:${guavaVersion}", - javax_annotation: 'org.apache.tomcat:annotations-api:6.0.53', - jsr305: 'com.google.code.findbugs:jsr305:3.0.2', - google_api_protos: 'com.google.api.grpc:proto-google-common-protos:2.0.1', - google_auth_credentials: "com.google.auth:google-auth-library-credentials:${googleauthVersion}", - google_auth_oauth2_http: "com.google.auth:google-auth-library-oauth2-http:${googleauthVersion}", - okhttp: 'com.squareup.okhttp:okhttp:2.7.4', - okio: 'com.squareup.okio:okio:1.17.5', - opencensus_api: "io.opencensus:opencensus-api:${opencensusVersion}", - opencensus_contrib_grpc_metrics: "io.opencensus:opencensus-contrib-grpc-metrics:${opencensusVersion}", - opencensus_impl: "io.opencensus:opencensus-impl:${opencensusVersion}", - opencensus_impl_lite: "io.opencensus:opencensus-impl-lite:${opencensusVersion}", - opencensus_proto: "io.opencensus:opencensus-proto:0.2.0", - instrumentation_api: 'com.google.instrumentation:instrumentation-api:0.4.3', - perfmark: 'io.perfmark:perfmark-api:0.23.0', - protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}", - protobuf_lite: "com.google.protobuf:protobuf-javalite:${protobufVersion}", - protobuf_util: "com.google.protobuf:protobuf-java-util:${protobufVersion}", - - netty: "io.netty:netty-codec-http2:[${nettyVersion}]", - netty_epoll: "io.netty:netty-transport-native-epoll:${nettyVersion}:linux-x86_64", - netty_epoll_arm64: "io.netty:netty-transport-native-epoll:${nettyVersion}:linux-aarch_64", - netty_proxy_handler: "io.netty:netty-handler-proxy:${nettyVersion}", - - // Keep the following references of tcnative version in sync whenever it's updated - // SECURITY.md (multiple occurrences) - // examples/example-tls/build.gradle - // examples/example-tls/pom.xml - netty_tcnative: 'io.netty:netty-tcnative-boringssl-static:2.0.46.Final', - - conscrypt: 'org.conscrypt:conscrypt-openjdk-uber:2.5.1', - re2j: 'com.google.re2j:re2j:1.5', - - bouncycastle: 'org.bouncycastle:bcpkix-jdk15on:1.67', - - // Test dependencies. - junit: 'junit:junit:4.12', - mockito: 'org.mockito:mockito-core:3.3.3', - mockito_android: 'org.mockito:mockito-android:3.8.0', - truth: 'com.google.truth:truth:1.0.1', - guava_testlib: "com.google.guava:guava-testlib:${guavaVersion}", - androidx_annotation: "androidx.annotation:annotation:1.1.0", - androidx_core: "androidx.core:core:1.3.0", - androidx_lifecycle_common: "androidx.lifecycle:lifecycle-common:2.3.0", - androidx_lifecycle_service: "androidx.lifecycle:lifecycle-service:2.3.0", - androidx_test: "androidx.test:core:1.3.0", - androidx_test_rules: "androidx.test:rules:1.3.0", - androidx_test_ext_junit: "androidx.test.ext:junit:1.1.2", - robolectric: "org.robolectric:robolectric:4.4", - - // Benchmark dependencies - hdrhistogram: 'org.hdrhistogram:HdrHistogram:2.1.12', - math: 'org.apache.commons:commons-math3:3.6.1', - - // Jetty ALPN dependencies - jetty_alpn_agent: 'org.mortbay.jetty.alpn:jetty-alpn-agent:2.0.10' - ] + libraries = libs appendToProperty = { Property property, String value, String separator -> if (property.present) { @@ -225,7 +158,7 @@ subprojects { // Disable JavaDoc doclint on Java 8. It's annoying. if (JavaVersion.current().isJava8Compatible()) { allprojects { - tasks.withType(Javadoc) { + tasks.withType(Javadoc).configureEach { options.addStringOption('Xdoclint:none', '-quiet') } } @@ -233,7 +166,7 @@ subprojects { checkstyle { configDirectory = file("$rootDir/buildscripts") - toolVersion = "6.17" + toolVersion = libs.versions.checkstyle.get() ignoreFailures = false if (rootProject.hasProperty("checkstyle.ignoreFailures")) { ignoreFailures = rootProject.properties["checkstyle.ignoreFailures"].toBoolean() @@ -242,12 +175,11 @@ subprojects { if (!project.hasProperty('errorProne') || errorProne.toBoolean()) { dependencies { - errorprone 'com.google.errorprone:error_prone_core:2.10.0' - errorproneJavac 'com.google.errorprone:javac:9+181-r4173-1' + errorprone libs.errorprone.core } } else { // Disable Error Prone - tasks.withType(JavaCompile) { + tasks.withType(JavaCompile).configureEach { options.errorprone.enabled = false } } @@ -258,40 +190,44 @@ subprojects { dependencies { testImplementation libraries.junit, - libraries.mockito, + libraries.mockito.core, libraries.truth } - compileTestJava { + tasks.named("compileTestJava").configure { // serialVersionUID is basically guaranteed to be useless in our tests options.compilerArgs += [ "-Xlint:-serial" ] } - jar.manifest { - attributes('Implementation-Title': name, - 'Implementation-Version': version) + tasks.named("jar").configure { + manifest { + attributes('Implementation-Title': name, + 'Implementation-Version': project.version) + } } - javadoc.options { - encoding = 'UTF-8' - use = true - links 'https://docs.oracle.com/javase/8/docs/api/' - source = "8" + tasks.named("javadoc").configure { + options { + encoding = 'UTF-8' + use = true + links 'https://docs.oracle.com/javase/8/docs/api/' + source = "8" + } } - checkstyleMain { + tasks.named("checkstyleMain").configure { source = fileTree(dir: "$projectDir/src/main", include: "**/*.java") } - checkstyleTest { + tasks.named("checkstyleTest").configure { source = fileTree(dir: "$projectDir/src/test", include: "**/*.java") } // At a test failure, log the stack trace to the console so that we don't // have to open the HTML in a browser. - test { + tasks.named("test").configure { testLogging { exceptionFormat = 'full' showExceptions true @@ -303,11 +239,11 @@ subprojects { if (!project.hasProperty('errorProne') || errorProne.toBoolean()) { dependencies { - annotationProcessor 'com.google.guava:guava-beta-checker:1.0' + annotationProcessor libs.guava.betaChecker } } - compileJava { + tasks.named("compileJava").configure { // This project targets Java 7 (no method references) options.errorprone.check("UnnecessaryAnonymousClass", CheckSeverity.OFF) // This project targets Java 7 (no time.Duration class) @@ -316,7 +252,7 @@ subprojects { // The warning fails to provide a source location options.errorprone.check("MissingSummary", CheckSeverity.OFF) } - compileTestJava { + tasks.named("compileTestJava").configure { // LinkedList doesn't hurt much in tests and has lots of usages options.errorprone.check("JdkObsolete", CheckSeverity.OFF) options.errorprone.check("UnnecessaryAnonymousClass", CheckSeverity.OFF) @@ -327,8 +263,7 @@ subprojects { plugins.withId("ru.vyarus.animalsniffer") { // Only available after java plugin has loaded animalsniffer { - // Breaks on upgrade: https://github.com/mojohaus/animal-sniffer/issues/131 - toolVersion = '1.18' + toolVersion = libs.versions.animalsniffer.get() } } } @@ -337,11 +272,13 @@ subprojects { // Detect Maven Enforcer's dependencyConvergence failures. We only care // for artifacts used as libraries by others with Maven. tasks.register('checkUpperBoundDeps') { + inputs.files(configurations.runtimeClasspath).withNormalizer(ClasspathNormalizer) + outputs.file("${buildDir}/tmp/${name}") // Fake output for UP-TO-DATE checking doLast { requireUpperBoundDepsMatch(configurations.runtimeClasspath, project) } } - tasks.named('compileJava') { + tasks.named('compileJava').configure { dependsOn checkUpperBoundDeps } } @@ -349,11 +286,11 @@ subprojects { plugins.withId("me.champeau.jmh") { // invoke jmh on a single benchmark class like so: // ./gradlew -PjmhIncludeSingleClass=StatsTraceContextBenchmark clean :grpc-core:jmh - compileJmhJava { + tasks.named("compileJmhJava").configure { sourceCompatibility = 1.8 targetCompatibility = 1.8 } - jmh { + tasks.named("jmh").configure { warmupIterations = 10 iterations = 10 fork = 1 @@ -362,13 +299,31 @@ subprojects { // depends on core; core's testCompile depends on testing) includeTests = false if (project.hasProperty('jmhIncludeSingleClass')) { - include = [ + includes = [ project.property('jmhIncludeSingleClass') ] } } } + plugins.withId("com.github.johnrengelman.shadow") { + tasks.named("shadowJar").configure { + // Do a dance to remove Class-Path. This needs to run after the doFirst() from the + // shadow plugin that adds Class-Path and before the core jar action. Using doFirst will + // have this run before the shadow plugin, and doLast will run after the core jar + // action. See #8606. + // The shadow plugin adds another doFirst when application is used for setting + // Main-Class. Ordering with it doesn't matter. + actions.add(plugins.hasPlugin("application") ? 2 : 1, new Action() { + @Override public void execute(Task task) { + if (!task.manifest.attributes.remove("Class-Path")) { + throw new AssertionError("Did not find Class-Path to remove from manifest") + } + } + }) + } + } + plugins.withId("maven-publish") { publishing { publications { @@ -498,7 +453,8 @@ subprojects { } // Add a japicmp task that compares the current .jar with baseline .jar - task japicmp(type: me.champeau.gradle.japicmp.JapicmpTask, dependsOn: jar) { + tasks.register("japicmp", me.champeau.gradle.japicmp.JapicmpTask) { + dependsOn jar oldClasspath = files(baselineArtifact) newClasspath = files(jar.archivePath) onlyBinaryIncompatibleModified = false diff --git a/buildscripts/checkstyle.xml b/buildscripts/checkstyle.xml index 52b564201c8..a5aded93a80 100644 --- a/buildscripts/checkstyle.xml +++ b/buildscripts/checkstyle.xml @@ -33,6 +33,11 @@ + + + + + @@ -45,10 +50,6 @@ - - - - @@ -59,12 +60,8 @@ - - - - + - @@ -204,13 +201,10 @@ - - - diff --git a/buildscripts/kokoro/bazel.sh b/buildscripts/kokoro/bazel.sh index 107a305bbf9..fa7224d8b0a 100755 --- a/buildscripts/kokoro/bazel.sh +++ b/buildscripts/kokoro/bazel.sh @@ -3,7 +3,7 @@ set -exu -o pipefail cat /VERSION -use_bazel.sh 4.0.0 +use_bazel.sh 5.0.0 bazel version cd github/grpc-java diff --git a/buildscripts/kokoro/gae-interop.sh b/buildscripts/kokoro/gae-interop.sh index b3973031d85..c4ce56cac52 100755 --- a/buildscripts/kokoro/gae-interop.sh +++ b/buildscripts/kokoro/gae-interop.sh @@ -29,6 +29,7 @@ cd "$GRPC_JAVA_DIR" ## Deploy the dummy 'default' version of the service ## GRADLE_FLAGS="--stacktrace -DgaeStopPreviousVersion=false -PskipCodegen=true -PskipAndroid=true" +export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" # Deploy the dummy 'default' version. We only require that it exists when cleanup() is called. # It ok if we race with another run and fail here, because the end result is idempotent. diff --git a/buildscripts/kokoro/linux_artifacts.sh b/buildscripts/kokoro/linux_artifacts.sh index e23b2bcc628..619917bdceb 100755 --- a/buildscripts/kokoro/linux_artifacts.sh +++ b/buildscripts/kokoro/linux_artifacts.sh @@ -49,3 +49,7 @@ cp -r "$LOCAL_MVN_TEMP"/* "$MVN_ARTIFACT_DIR"/ # for aarch64 platform sudo apt-get install -y g++-aarch64-linux-gnu SKIP_TESTS=true ARCH=aarch_64 "$GRPC_JAVA_DIR"/buildscripts/kokoro/unix.sh + +# for ppc64le platform +sudo apt-get install -y g++-powerpc64le-linux-gnu +SKIP_TESTS=true ARCH=ppcle_64 "$GRPC_JAVA_DIR"/buildscripts/kokoro/unix.sh diff --git a/buildscripts/kokoro/psm-security.cfg b/buildscripts/kokoro/psm-security.cfg index f2cfd7babff..9df0bbb7867 100644 --- a/buildscripts/kokoro/psm-security.cfg +++ b/buildscripts/kokoro/psm-security.cfg @@ -7,7 +7,7 @@ timeout_mins: 180 action { define_artifacts { regex: "artifacts/**/*sponge_log.xml" - regex: "artifacts/**/*sponge_log.log" + regex: "artifacts/**/*.log" strip_prefix: "artifacts" } } diff --git a/buildscripts/kokoro/psm-security.sh b/buildscripts/kokoro/psm-security.sh index 105e67b2d0f..4f9ef07cc93 100755 --- a/buildscripts/kokoro/psm-security.sh +++ b/buildscripts/kokoro/psm-security.sh @@ -23,6 +23,7 @@ readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" build_java_test_app() { echo "Building Java test app" cd "${SRC_DIR}" + GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ -PskipCodegen=true -PskipAndroid=true --console=plain @@ -38,6 +39,7 @@ build_java_test_app() { # SERVER_IMAGE_NAME: Test server Docker image name # CLIENT_IMAGE_NAME: Test client Docker image name # GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test, f.e. v1.42.x, master # Arguments: # None # Outputs: @@ -53,10 +55,9 @@ build_test_app_docker_images() { cp -v "${docker_dir}/"*.properties "${build_dir}" cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" # Pick a branch name for the built image - if [[ -n $KOKORO_JOB_NAME ]]; then - branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|') - else - branch_name='experimental' + local branch_name='experimental' + if is_version_branch "${TESTING_VERSION}"; then + branch_name="${TESTING_VERSION}" fi # Run Google Cloud Build gcloud builds submit "${build_dir}" \ @@ -106,6 +107,8 @@ build_docker_images_if_needed() { # SERVER_IMAGE_NAME: Test server Docker image name # CLIENT_IMAGE_NAME: Test client Docker image name # GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test: used by the framework to +# determine the supported PSM features. # Arguments: # Test case name # Outputs: @@ -116,15 +119,20 @@ run_test() { # Test driver usage: # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage local test_name="${1:?Usage: run_test test_name}" + local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" + mkdir -pv "${out_dir}" set -x python -m "tests.${test_name}" \ --flagfile="${TEST_DRIVER_FLAGFILE}" \ --kube_context="${KUBE_CONTEXT}" \ --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ - --force_cleanup - set +x + --testing_version="${TESTING_VERSION}" \ + --force_cleanup \ + --collect_app_logs \ + --log_dir="${out_dir}" \ + --xml_output_file="${out_dir}/sponge_log.xml" \ + |& tee "${out_dir}/sponge_log.log" } ####################################### @@ -166,9 +174,15 @@ main() { build_docker_images_if_needed # Run tests cd "${TEST_DRIVER_FULL_DIR}" - run_test baseline_test - run_test security_test - run_test authz_test + local failed_tests=0 + test_suites=("baseline_test" "security_test" "authz_test") + for test in "${test_suites[@]}"; do + run_test $test || (( ++failed_tests )) + done + echo "Failed test suites: ${failed_tests}" + if (( failed_tests > 0 )); then + exit 1 + fi } main "$@" diff --git a/buildscripts/kokoro/unix.sh b/buildscripts/kokoro/unix.sh index 91ee67f1d98..828a599ff7c 100755 --- a/buildscripts/kokoro/unix.sh +++ b/buildscripts/kokoro/unix.sh @@ -9,6 +9,8 @@ # ARCH=x86_32 ./buildscripts/kokoro/unix.sh # For aarch64 arch: # ARCH=aarch_64 ./buildscripts/kokoro/unix.sh +# For ppc64le arch: +# ARCH=ppcle_64 ./buildscripts/kokoro/unix.sh # This script assumes `set -e`. Removing it may lead to undefined behavior. set -exu -o pipefail @@ -37,7 +39,7 @@ GRADLE_FLAGS+=" -PfailOnWarnings=true" GRADLE_FLAGS+=" -PerrorProne=true" GRADLE_FLAGS+=" -PskipAndroid=true" GRADLE_FLAGS+=" -Dorg.gradle.parallel=true" -export GRADLE_OPTS="-Xmx512m" +export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" # Make protobuf discoverable by :grpc-compiler export LD_LIBRARY_PATH=/tmp/protobuf/lib @@ -81,12 +83,15 @@ if [[ -z "${SKIP_TESTS:-}" ]]; then ../gradlew build $GRADLE_FLAGS popd # TODO(zpencer): also build the GAE examples + pushd examples/example-orca + ../gradlew build $GRADLE_FLAGS + popd fi LOCAL_MVN_TEMP=$(mktemp -d) # Note that this disables parallel=true from GRADLE_FLAGS if [[ -z "${ALL_ARTIFACTS:-}" ]]; then - if [[ $ARCH == "aarch_64" ]]; then + if [[ "$ARCH" = "aarch_64" || "$ARCH" = "ppcle_64" ]]; then GRADLE_FLAGS+=" -x grpc-compiler:generateTestProto -x grpc-compiler:generateTestLiteProto" GRADLE_FLAGS+=" -x grpc-compiler:testGolden -x grpc-compiler:testLiteGolden" GRADLE_FLAGS+=" -x grpc-compiler:testDeprecatedGolden -x grpc-compiler:testDeprecatedLiteGolden" diff --git a/buildscripts/kokoro/upload_artifacts.sh b/buildscripts/kokoro/upload_artifacts.sh index 06d037831ec..ade37ee89bb 100644 --- a/buildscripts/kokoro/upload_artifacts.sh +++ b/buildscripts/kokoro/upload_artifacts.sh @@ -34,6 +34,9 @@ LOCAL_OTHER_ARTIFACTS="$KOKORO_GFILE_DIR"/github/grpc-java/artifacts/ # for linux aarch64 platform [[ "$(find "$LOCAL_MVN_ARTIFACTS" -type f -iname 'protoc-gen-grpc-java-*-linux-aarch_64.exe' | wc -l)" != '0' ]] +# for linux ppc64le platform +[[ "$(find "$LOCAL_MVN_ARTIFACTS" -type f -iname 'protoc-gen-grpc-java-*-linux-ppcle_64.exe' | wc -l)" != '0' ]] + # from macos job: [[ "$(find "$LOCAL_MVN_ARTIFACTS" -type f -iname 'protoc-gen-grpc-java-*-osx-x86_64.exe' | wc -l)" != '0' ]] # copy all x86 artifacts to aarch until native artifacts are built diff --git a/buildscripts/kokoro/windows32.bat b/buildscripts/kokoro/windows32.bat index 5a87b779da3..7c6491cf978 100644 --- a/buildscripts/kokoro/windows32.bat +++ b/buildscripts/kokoro/windows32.bat @@ -26,9 +26,10 @@ cd "%WORKSPACE%" SET TARGET_ARCH=x86_32 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\cmake\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\cmake\\build\\include +SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\Release +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\include SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true +SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" cmd.exe /C "%WORKSPACE%\gradlew.bat %GRADLE_FLAGS% build" set GRADLEEXIT=%ERRORLEVEL% diff --git a/buildscripts/kokoro/windows64.bat b/buildscripts/kokoro/windows64.bat index eaa1fcf845e..cf02336be36 100644 --- a/buildscripts/kokoro/windows64.bat +++ b/buildscripts/kokoro/windows64.bat @@ -25,9 +25,10 @@ cd "%WORKSPACE%" SET TARGET_ARCH=x86_64 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\cmake\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\cmake\\build\\include +SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\Release +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\include SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true +SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" @rem make sure no daemons have any files open cmd.exe /C "%WORKSPACE%\gradlew.bat --stop" diff --git a/buildscripts/kokoro/xds.cfg b/buildscripts/kokoro/xds.cfg deleted file mode 100644 index fdff1eb0df2..00000000000 --- a/buildscripts/kokoro/xds.cfg +++ /dev/null @@ -1,11 +0,0 @@ -# Config file for internal CI - -# Location of the continuous shell script in repository. -build_file: "grpc-java/buildscripts/kokoro/xds.sh" -timeout_mins: 360 -action { - define_artifacts { - regex: "**/*sponge_log.*" - regex: "github/grpc/reports/**" - } -} diff --git a/buildscripts/kokoro/xds.sh b/buildscripts/kokoro/xds.sh deleted file mode 100755 index 83ba7fc864f..00000000000 --- a/buildscripts/kokoro/xds.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -set -exu -o pipefail -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -cd github - -pushd grpc-java/interop-testing -branch=$(git branch --all --no-color --contains "${KOKORO_GITHUB_COMMIT}" \ - | grep -v HEAD | head -1) -shopt -s extglob -branch="${branch//[[:space:]]}" -branch="${branch##remotes/origin/}" -shopt -u extglob -../gradlew installDist -x test -PskipCodegen=true -PskipAndroid=true -popd - -git clone -b "${branch}" --single-branch --depth=1 https://github.com/grpc/grpc.git - -grpc/tools/run_tests/helper_scripts/prep_xds.sh - -# Test cases "path_matching" and "header_matching" are not included in "all", -# because not all interop clients in all languages support these new tests. -# -# TODO(ericgribkoff): remove "path_matching" and "header_matching" from -# --test_case after they are added into "all". -JAVA_OPTS=-Djava.util.logging.config.file=grpc-java/buildscripts/xds_logging.properties \ - python3 grpc/tools/run_tests/run_xds_tests.py \ - --test_case="all,circuit_breaking,timeout,fault_injection,csds" \ - --project_id=grpc-testing \ - --project_num=830293263384 \ - --source_image=projects/grpc-testing/global/images/xds-test-server-5 \ - --path_to_server_binary=/java_server/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-server \ - --gcp_suffix=$(date '+%s') \ - --verbose \ - ${XDS_V3_OPT-} \ - --client_cmd="grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-client \ - --server=xds:///{server_uri} \ - --stats_port={stats_port} \ - --qps={qps} \ - {rpcs_to_send} \ - {metadata_to_send}" diff --git a/buildscripts/kokoro/xds_k8s_lb.cfg b/buildscripts/kokoro/xds_k8s_lb.cfg index 43971896cd4..10ea2d43b5d 100644 --- a/buildscripts/kokoro/xds_k8s_lb.cfg +++ b/buildscripts/kokoro/xds_k8s_lb.cfg @@ -7,7 +7,7 @@ timeout_mins: 180 action { define_artifacts { regex: "artifacts/**/*sponge_log.xml" - regex: "artifacts/**/*sponge_log.log" + regex: "artifacts/**/*.log" strip_prefix: "artifacts" } } diff --git a/buildscripts/kokoro/xds_k8s_lb.sh b/buildscripts/kokoro/xds_k8s_lb.sh index 22f1e383440..9efb124465b 100755 --- a/buildscripts/kokoro/xds_k8s_lb.sh +++ b/buildscripts/kokoro/xds_k8s_lb.sh @@ -23,6 +23,7 @@ readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" build_java_test_app() { echo "Building Java test app" cd "${SRC_DIR}" + GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ -PskipCodegen=true -PskipAndroid=true --console=plain @@ -38,6 +39,7 @@ build_java_test_app() { # SERVER_IMAGE_NAME: Test server Docker image name # CLIENT_IMAGE_NAME: Test client Docker image name # GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test, f.e. v1.42.x, master # Arguments: # None # Outputs: @@ -53,10 +55,9 @@ build_test_app_docker_images() { cp -v "${docker_dir}/"*.properties "${build_dir}" cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" # Pick a branch name for the built image - if [[ -n $KOKORO_JOB_NAME ]]; then - branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|') - else - branch_name='experimental' + local branch_name='experimental' + if is_version_branch "${TESTING_VERSION}"; then + branch_name="${TESTING_VERSION}" fi # Run Google Cloud Build gcloud builds submit "${build_dir}" \ @@ -102,6 +103,7 @@ build_docker_images_if_needed() { # Globals: # TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile # KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# SECONDARY_KUBE_CONTEXT: The name of kubectl context with secondary GKE cluster access, if any # TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report # SERVER_IMAGE_NAME: Test server Docker image name # CLIENT_IMAGE_NAME: Test client Docker image name @@ -116,15 +118,21 @@ run_test() { # Test driver usage: # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage local test_name="${1:?Usage: run_test test_name}" + local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" + mkdir -pv "${out_dir}" set -x python -m "tests.${test_name}" \ --flagfile="${TEST_DRIVER_FLAGFILE}" \ --kube_context="${KUBE_CONTEXT}" \ + --secondary_kube_context="${SECONDARY_KUBE_CONTEXT}" \ --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ - --force_cleanup - set +x + --testing_version="${TESTING_VERSION}" \ + --force_cleanup \ + --collect_app_logs \ + --log_dir="${out_dir}" \ + --xml_output_file="${out_dir}/sponge_log.xml" \ + |& tee "${out_dir}/sponge_log.log" } ####################################### @@ -155,7 +163,8 @@ main() { echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" - activate_gke_cluster GKE_CLUSTER_PSM_BASIC + activate_gke_cluster GKE_CLUSTER_PSM_LB + activate_secondary_gke_cluster GKE_CLUSTER_PSM_LB set -x if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then @@ -166,7 +175,15 @@ main() { build_docker_images_if_needed # Run tests cd "${TEST_DRIVER_FULL_DIR}" - run_test api_listener_test + local failed_tests=0 + test_suites=("api_listener_test" "change_backend_service_test" "failover_test" "remove_neg_test" "round_robin_test" "affinity_test" "outlier_detection_test" "custom_lb_test") + for test in "${test_suites[@]}"; do + run_test $test || (( ++failed_tests )) + done + echo "Failed test suites: ${failed_tests}" + if (( failed_tests > 0 )); then + exit 1 + fi } main "$@" diff --git a/buildscripts/kokoro/xds_url_map.cfg b/buildscripts/kokoro/xds_url_map.cfg index 4b5be84f880..1fa6c0141cb 100644 --- a/buildscripts/kokoro/xds_url_map.cfg +++ b/buildscripts/kokoro/xds_url_map.cfg @@ -7,7 +7,7 @@ timeout_mins: 90 action { define_artifacts { regex: "artifacts/**/*sponge_log.xml" - regex: "artifacts/**/*sponge_log.log" + regex: "artifacts/**/*.log" strip_prefix: "artifacts" } } diff --git a/buildscripts/kokoro/xds_url_map.sh b/buildscripts/kokoro/xds_url_map.sh index af856a693af..50737e1569d 100755 --- a/buildscripts/kokoro/xds_url_map.sh +++ b/buildscripts/kokoro/xds_url_map.sh @@ -23,6 +23,7 @@ readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" build_java_test_app() { echo "Building Java test app" cd "${SRC_DIR}" + GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ -PskipCodegen=true -PskipAndroid=true --console=plain @@ -38,6 +39,7 @@ build_java_test_app() { # SERVER_IMAGE_NAME: Test server Docker image name # CLIENT_IMAGE_NAME: Test client Docker image name # GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test, f.e. v1.42.x, master # Arguments: # None # Outputs: @@ -53,10 +55,9 @@ build_test_app_docker_images() { cp -v "${docker_dir}/"*.properties "${build_dir}" cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" # Pick a branch name for the built image - if [[ -n $KOKORO_JOB_NAME ]]; then - branch_name="$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|')" - else - branch_name='experimental' + local branch_name='experimental' + if is_version_branch "${TESTING_VERSION}"; then + branch_name="${TESTING_VERSION}" fi # Run Google Cloud Build gcloud builds submit "${build_dir}" \ @@ -105,6 +106,8 @@ build_docker_images_if_needed() { # TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report # CLIENT_IMAGE_NAME: Test client Docker image name # GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test: used by the framework to +# determine the supported PSM features. # Arguments: # Test case name # Outputs: @@ -115,15 +118,19 @@ run_test() { # Test driver usage: # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage local test_name="${1:?Usage: run_test test_name}" + local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" + mkdir -pv "${out_dir}" set -x python -m "tests.${test_name}" \ --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --flagfile="config/url-map.cfg" \ --kube_context="${KUBE_CONTEXT}" \ --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --testing_version="$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|')" \ - --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ - --flagfile="config/url-map.cfg" - set +x + --testing_version="${TESTING_VERSION}" \ + --collect_app_logs \ + --log_dir="${out_dir}" \ + --xml_output_file="${out_dir}/sponge_log.xml" \ + |& tee "${out_dir}/sponge_log.log" } ####################################### diff --git a/buildscripts/kokoro/xds_v3.sh b/buildscripts/kokoro/xds_v3.sh index 73eb50d248a..8c879e1095b 100755 --- a/buildscripts/kokoro/xds_v3.sh +++ b/buildscripts/kokoro/xds_v3.sh @@ -1,3 +1,34 @@ #!/bin/bash -XDS_V3_OPT="--xds_v3_support" `dirname $0`/xds.sh +set -exu -o pipefail +if [[ -f /VERSION ]]; then + cat /VERSION +fi + +cd github + +pushd grpc-java/interop-testing +GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ + ../gradlew installDist -x test -PskipCodegen=true -PskipAndroid=true +popd + +git clone -b master --single-branch --depth=1 https://github.com/grpc/grpc.git + +grpc/tools/run_tests/helper_scripts/prep_xds.sh + +JAVA_OPTS=-Djava.util.logging.config.file=grpc-java/buildscripts/xds_logging.properties \ + python3 grpc/tools/run_tests/run_xds_tests.py \ + --test_case="ping_pong,circuit_breaking" \ + --project_id=grpc-testing \ + --project_num=830293263384 \ + --source_image=projects/grpc-testing/global/images/xds-test-server-5 \ + --path_to_server_binary=/java_server/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-server \ + --gcp_suffix=$(date '+%s') \ + --verbose \ + --xds_v3_support \ + --client_cmd="grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-client \ + --server=xds:///{server_uri} \ + --stats_port={stats_port} \ + --qps={qps} \ + {rpcs_to_send} \ + {metadata_to_send}" diff --git a/buildscripts/make_dependencies.bat b/buildscripts/make_dependencies.bat index 18c6086adab..2bbfd394d46 100644 --- a/buildscripts/make_dependencies.bat +++ b/buildscripts/make_dependencies.bat @@ -1,12 +1,12 @@ -set PROTOBUF_VER=3.19.2 +set PROTOBUF_VER=21.7 set CMAKE_NAME=cmake-3.3.2-win32-x86 -if not exist "protobuf-%PROTOBUF_VER%\cmake\build\Release\" ( +if not exist "protobuf-%PROTOBUF_VER%\build\Release\" ( call :installProto || exit /b 1 ) echo Compile gRPC-Java with something like: -echo -PtargetArch=x86_32 -PvcProtobufLibs=%cd%\protobuf-%PROTOBUF_VER%\cmake\build\Release -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\cmake\build\include +echo -PtargetArch=x86_32 -PvcProtobufLibs=%cd%\protobuf-%PROTOBUF_VER%\build\Release -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\build\include goto :eof @@ -23,10 +23,11 @@ set PATH=%PATH%;%cd%\%CMAKE_NAME%\bin powershell -command "$ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/archive/v%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('protobuf.zip', '.') }" || exit /b 1 del protobuf.zip -pushd protobuf-%PROTOBUF_VER%\cmake -mkdir build -cd build +mkdir protobuf-%PROTOBUF_VER%\build +pushd protobuf-%PROTOBUF_VER%\build +@rem Workaround https://github.com/protocolbuffers/protobuf/issues/10174 +powershell -command "(Get-Content ..\cmake\extract_includes.bat.in) -replace '\.\.\\', '' | Out-File -encoding ascii ..\cmake\extract_includes.bat.in" @rem cmake does not detect x86_64 from the vcvars64.bat variables. @rem If vcvars64.bat has set PLATFORM to X64, then inform cmake to use the Win64 version of VS if "%PLATFORM%" == "X64" ( diff --git a/buildscripts/make_dependencies.sh b/buildscripts/make_dependencies.sh index 5e7561c4313..0940132eea2 100755 --- a/buildscripts/make_dependencies.sh +++ b/buildscripts/make_dependencies.sh @@ -3,7 +3,7 @@ # Build protoc set -evux -o pipefail -PROTOBUF_VERSION=3.19.2 +PROTOBUF_VERSION=21.7 # ARCH is x86_64 bit unless otherwise specified. ARCH="${ARCH:-x86_64}" @@ -36,6 +36,10 @@ else --prefix="$INSTALL_DIR" elif [[ "$ARCH" == aarch* ]]; then ./configure --disable-shared --host=aarch64-linux-gnu --prefix="$INSTALL_DIR" + elif [[ "$ARCH" == ppc* ]]; then + ./configure --disable-shared --host=powerpc64le-linux-gnu --prefix="$INSTALL_DIR" + elif [[ "$ARCH" == loongarch* ]]; then + ./configure --disable-shared --host=loongarch64-unknown-linux-gnu --prefix="$INSTALL_DIR" fi # the same source dir is used for 32 and 64 bit builds, so we need to clean stale data first make clean diff --git a/buildscripts/run_arm64_tests_in_docker.sh b/buildscripts/run_arm64_tests_in_docker.sh index e7bf82022b7..76ef64ac6b6 100755 --- a/buildscripts/run_arm64_tests_in_docker.sh +++ b/buildscripts/run_arm64_tests_in_docker.sh @@ -10,15 +10,24 @@ else DOCKER_ARGS= fi + +cat <> "${grpc_java_dir}/gradle.properties" +skipAndroid=true +skipCodegen=true +org.gradle.parallel=true +org.gradle.jvmargs=-Xmx1024m +EOF + +export JAVA_OPTS="-Duser.home=/grpc-java/.current-user-home -Djava.util.prefs.userRoot=/grpc-java/.current-user-home/.java/.userPrefs" + # build under x64 docker image to save time over building everything under # aarch64 emulator. We've already built and tested the protoc binaries # so for the rest of the build we will be using "-PskipCodegen=true" # avoid further complicating the build. docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" \ - -e "JAVA_OPTS=-Duser.home=/grpc-java/.current-user-home -Djava.util.prefs.userRoot=/grpc-java/.current-user-home/.java/.userPrefs" \ + --user "$(id -u):$(id -g)" -e JAVA_OPTS \ openjdk:11-jdk-slim-buster \ - ./gradlew build -x test -PskipAndroid=true -PskipCodegen=true + ./gradlew build -x test # Build and run java tests under aarch64 image. # To be able to run this docker container on x64 machine, one needs to have @@ -33,7 +42,6 @@ docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java # - run docker container under current user's UID to avoid polluting the workspace # - set the user.home property to avoid creating a "?" directory under grpc-java docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" \ - -e "JAVA_OPTS=-Duser.home=/grpc-java/.current-user-home -Djava.util.prefs.userRoot=/grpc-java/.current-user-home/.java/.userPrefs" \ + --user "$(id -u):$(id -g)" -e JAVA_OPTS \ arm64v8/openjdk:11-jdk-slim-buster \ - ./gradlew build -PskipAndroid=true -PskipCodegen=true + ./gradlew build diff --git a/buildscripts/sync-protos.sh b/buildscripts/sync-protos.sh index 968147ccac1..5f01be2e5c9 100755 --- a/buildscripts/sync-protos.sh +++ b/buildscripts/sync-protos.sh @@ -8,7 +8,7 @@ curl -Ls https://github.com/grpc/grpc-proto/archive/master.tar.gz | tar xz -C "$ base="$tmpdir/grpc-proto-master" # Copy protos in 'src/main/proto' from grpc-proto for these projects -for project in alts grpclb services rls; do +for project in alts grpclb services rls interop-testing; do while read -r proto; do [ -f "$base/$proto" ] && cp "$base/$proto" "$project/src/main/proto/$proto" echo "$proto" diff --git a/census/build.gradle b/census/build.gradle index 35973a5f016..02b7e6395b1 100644 --- a/census/build.gradle +++ b/census/build.gradle @@ -1,6 +1,8 @@ plugins { id "java-library" id "maven-publish" + + id "ru.vyarus.animalsniffer" } description = 'gRPC: Census' @@ -10,17 +12,20 @@ evaluationDependsOn(project(':grpc-api').path) dependencies { api project(':grpc-api') implementation libraries.guava, - libraries.opencensus_api, - libraries.opencensus_contrib_grpc_metrics + libraries.opencensus.api, + libraries.opencensus.contrib.grpc.metrics testImplementation project(':grpc-api').sourceSets.test.output, project(':grpc-context').sourceSets.test.output, project(':grpc-core').sourceSets.test.output, project(':grpc-testing'), - libraries.opencensus_impl + libraries.opencensus.impl + + signature libraries.signature.java + signature libraries.signature.android } -javadoc { +tasks.named("javadoc").configure { failOnError false // no public or protected classes found to document exclude 'io/grpc/census/internal/**' exclude 'io/grpc/census/Internal*' diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index 366be55de68..03eaf73570d 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -188,7 +188,7 @@ private static final class ClientTracer extends ClientStreamTracer { @Nullable private static final AtomicLongFieldUpdater inboundUncompressedSizeUpdater; - /** + /* * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their * JDK reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to * (potentially racy) direct updates of the volatile variables. @@ -268,7 +268,7 @@ public void streamCreated(Attributes transportAttrs, Metadata headers) { } @Override - @SuppressWarnings("NonAtomicVolatileUpdate") + @SuppressWarnings({"NonAtomicVolatileUpdate", "NonAtomicOperationOnVolatileField"}) public void outboundWireSize(long bytes) { if (outboundWireSizeUpdater != null) { outboundWireSizeUpdater.getAndAdd(this, bytes); @@ -369,13 +369,11 @@ void recordFinishedAttempt() { // TODO(songya): remove the deprecated measure constants once they are completed removed. .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) // The latency is double value - .put( - DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, - roundtripNanos / NANOS_PER_MILLI) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT, outboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT, inboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES, outboundWireSize) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES, inboundWireSize) + .put(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY, roundtripNanos / NANOS_PER_MILLI) + .put(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC, outboundMessageCount) + .put(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC, inboundMessageCount) + .put(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC, outboundWireSize) + .put(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC, inboundWireSize) .put( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, outboundUncompressedSize) @@ -443,7 +441,7 @@ static final class CallAttemptsTracerFactory extends if (module.recordStartedRpcs) { // Record here in case newClientStreamTracer() would never be called. module.statsRecorder.newMeasureMap() - .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) + .put(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS, 1) .record(startCtx); } } @@ -462,7 +460,7 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metada } if (module.recordStartedRpcs && attemptsPerCall.get() > 0) { module.statsRecorder.newMeasureMap() - .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) + .put(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS, 1) .record(startCtx); } if (info.isTransparentRetry()) { @@ -562,7 +560,7 @@ private static final class ServerTracer extends ServerStreamTracer { @Nullable private static final AtomicLongFieldUpdater inboundUncompressedSizeUpdater; - /** + /* * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their * JDK reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to * (potentially racy) direct updates of the volatile variables. @@ -628,7 +626,7 @@ private static final class ServerTracer extends ServerStreamTracer { this.stopwatch = module.stopwatchSupplier.get().start(); if (module.recordStartedRpcs) { module.statsRecorder.newMeasureMap() - .put(DeprecatedCensusConstants.RPC_SERVER_STARTED_COUNT, 1) + .put(RpcMeasureConstants.GRPC_SERVER_STARTED_RPCS, 1) .record(parentCtx); } } @@ -728,13 +726,11 @@ public void streamClosed(Status status) { // TODO(songya): remove the deprecated measure constants once they are completed removed. .put(DeprecatedCensusConstants.RPC_SERVER_FINISHED_COUNT, 1) // The latency is double value - .put( - DeprecatedCensusConstants.RPC_SERVER_SERVER_LATENCY, - elapsedTimeNanos / NANOS_PER_MILLI) - .put(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_COUNT, outboundMessageCount) - .put(DeprecatedCensusConstants.RPC_SERVER_REQUEST_COUNT, inboundMessageCount) - .put(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_BYTES, outboundWireSize) - .put(DeprecatedCensusConstants.RPC_SERVER_REQUEST_BYTES, inboundWireSize) + .put(RpcMeasureConstants.GRPC_SERVER_SERVER_LATENCY, elapsedTimeNanos / NANOS_PER_MILLI) + .put(RpcMeasureConstants.GRPC_SERVER_SENT_MESSAGES_PER_RPC, outboundMessageCount) + .put(RpcMeasureConstants.GRPC_SERVER_RECEIVED_MESSAGES_PER_RPC, inboundMessageCount) + .put(RpcMeasureConstants.GRPC_SERVER_SENT_BYTES_PER_RPC, outboundWireSize) + .put(RpcMeasureConstants.GRPC_SERVER_RECEIVED_BYTES_PER_RPC, inboundWireSize) .put( DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES, outboundUncompressedSize) @@ -755,10 +751,7 @@ public void streamClosed(Status status) { @Override public Context filterContext(Context context) { - if (!module.tagger.empty().equals(parentCtx)) { - return ContextUtils.withValue(context, parentCtx); - } - return context; + return ContextUtils.withValue(context, parentCtx); } } diff --git a/census/src/main/java/io/grpc/census/CensusTracingModule.java b/census/src/main/java/io/grpc/census/CensusTracingModule.java index 5c635613f33..f413da94388 100644 --- a/census/src/main/java/io/grpc/census/CensusTracingModule.java +++ b/census/src/main/java/io/grpc/census/CensusTracingModule.java @@ -41,7 +41,6 @@ import io.opencensus.trace.Status; import io.opencensus.trace.Tracer; import io.opencensus.trace.propagation.BinaryFormat; -import io.opencensus.trace.unsafe.ContextUtils; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.logging.Level; import java.util.logging.Logger; @@ -66,7 +65,7 @@ final class CensusTracingModule { @Nullable private static final AtomicIntegerFieldUpdater streamClosedUpdater; - /** + /* * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their JDK * reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to * (potentially racy) direct updates of the volatile variables. @@ -93,9 +92,12 @@ final class CensusTracingModule { final Metadata.Key tracingHeader; private final TracingClientInterceptor clientInterceptor = new TracingClientInterceptor(); private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); + private final boolean addMessageEvents; CensusTracingModule( - Tracer censusTracer, final BinaryFormat censusPropagationBinaryFormat) { + Tracer censusTracer, + final BinaryFormat censusPropagationBinaryFormat, + boolean addMessageEvents) { this.censusTracer = checkNotNull(censusTracer, "censusTracer"); checkNotNull(censusPropagationBinaryFormat, "censusPropagationBinaryFormat"); this.tracingHeader = @@ -115,6 +117,7 @@ public SpanContext parseBytes(byte[] serialized) { } } }); + this.addMessageEvents = addMessageEvents; } /** @@ -212,9 +215,12 @@ private static EndSpanOptions createEndSpanOptions( .build(); } - private static void recordMessageEvent( + private void recordMessageEvent( Span span, MessageEvent.Type type, int seqNo, long optionalWireSize, long optionalUncompressedSize) { + if (!addMessageEvents) { + return; + } MessageEvent.Builder eventBuilder = MessageEvent.builder(type, seqNo); if (optionalUncompressedSize != -1) { eventBuilder.setUncompressedMessageSize(optionalUncompressedSize); @@ -283,7 +289,7 @@ void callEnded(io.grpc.Status status) { } } - private static final class ClientTracer extends ClientStreamTracer { + private final class ClientTracer extends ClientStreamTracer { private final Span span; final Metadata.Key tracingHeader; final boolean isSampledToLocalTracing; @@ -366,12 +372,18 @@ public void streamClosed(io.grpc.Status status) { span.end(createEndSpanOptions(status, isSampledToLocalTracing)); } + /* + TODO(dnvindhya): Replace deprecated ContextUtils usage with ContextHandleUtils to interact + with io.grpc.Context as described in {@link io.opencensus.trace.unsafeContextUtils} to remove + SuppressWarnings annotation. + */ + @SuppressWarnings("deprecation") @Override public Context filterContext(Context context) { // Access directly the unsafe trace API to create the new Context. This is a safe usage // because gRPC always creates a new Context for each of the server calls and does not // inherit from the parent Context. - return ContextUtils.withValue(context, span); + return io.opencensus.trace.unsafe.ContextUtils.withValue(context, span); } @Override @@ -404,6 +416,8 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata @VisibleForTesting final class TracingClientInterceptor implements ClientInterceptor { + + @SuppressWarnings("deprecation") @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { @@ -412,7 +426,8 @@ public ClientCall interceptCall( // as Tracer.getCurrentSpan() except when no value available when the return value is null // for the direct access and BlankSpan when Tracer API is used. final CallAttemptsTracerFactory tracerFactory = - newClientCallTracer(ContextUtils.getValue(Context.current()), method); + newClientCallTracer( + io.opencensus.trace.unsafe.ContextUtils.getValue(Context.current()), method); ClientCall call = next.newCall( method, diff --git a/census/src/main/java/io/grpc/census/InternalCensusTracingAccessor.java b/census/src/main/java/io/grpc/census/InternalCensusTracingAccessor.java index 2df6c5fb4bd..d04d10739ea 100644 --- a/census/src/main/java/io/grpc/census/InternalCensusTracingAccessor.java +++ b/census/src/main/java/io/grpc/census/InternalCensusTracingAccessor.java @@ -36,10 +36,22 @@ private InternalCensusTracingAccessor() { * Returns a {@link ClientInterceptor} with default tracing implementation. */ public static ClientInterceptor getClientInterceptor() { + return getClientInterceptor(true); + } + + /** + * Returns the client interceptor that facilitates Census-based stats reporting. + * + * @param addMessageEvents add message events to Spans + * @return a {@link ClientInterceptor} with default tracing implementation. + */ + public static ClientInterceptor getClientInterceptor( + boolean addMessageEvents) { CensusTracingModule censusTracing = new CensusTracingModule( Tracing.getTracer(), - Tracing.getPropagationComponent().getBinaryFormat()); + Tracing.getPropagationComponent().getBinaryFormat(), + addMessageEvents); return censusTracing.getClientInterceptor(); } @@ -47,10 +59,19 @@ public static ClientInterceptor getClientInterceptor() { * Returns a {@link ServerStreamTracer.Factory} with default stats implementation. */ public static ServerStreamTracer.Factory getServerStreamTracerFactory() { + return getServerStreamTracerFactory(true); + } + + /** + * Returns a {@link ServerStreamTracer.Factory} with default stats implementation. + */ + public static ServerStreamTracer.Factory getServerStreamTracerFactory( + boolean addMessageEvents) { CensusTracingModule censusTracing = new CensusTracingModule( Tracing.getTracer(), - Tracing.getPropagationComponent().getBinaryFormat()); + Tracing.getPropagationComponent().getBinaryFormat(), + addMessageEvents); return censusTracing.getServerTracerFactory(); } } diff --git a/census/src/main/java/io/grpc/census/internal/DeprecatedCensusConstants.java b/census/src/main/java/io/grpc/census/internal/DeprecatedCensusConstants.java index 2b0a4763a28..7470bc81b15 100644 --- a/census/src/main/java/io/grpc/census/internal/DeprecatedCensusConstants.java +++ b/census/src/main/java/io/grpc/census/internal/DeprecatedCensusConstants.java @@ -27,49 +27,23 @@ public final class DeprecatedCensusConstants { public static final MeasureLong RPC_CLIENT_ERROR_COUNT = RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT; - public static final MeasureDouble RPC_CLIENT_REQUEST_BYTES = - RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES; - public static final MeasureDouble RPC_CLIENT_RESPONSE_BYTES = - RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES; - public static final MeasureDouble RPC_CLIENT_ROUNDTRIP_LATENCY = - RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY; - public static final MeasureDouble RPC_CLIENT_SERVER_ELAPSED_TIME = - RpcMeasureConstants.RPC_CLIENT_SERVER_ELAPSED_TIME; public static final MeasureDouble RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES = RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES; public static final MeasureDouble RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES = RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES; - public static final MeasureLong RPC_CLIENT_STARTED_COUNT = - RpcMeasureConstants.RPC_CLIENT_STARTED_COUNT; public static final MeasureLong RPC_CLIENT_FINISHED_COUNT = RpcMeasureConstants.RPC_CLIENT_FINISHED_COUNT; - public static final MeasureLong RPC_CLIENT_REQUEST_COUNT = - RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT; - public static final MeasureLong RPC_CLIENT_RESPONSE_COUNT = - RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT; public static final MeasureLong RPC_SERVER_ERROR_COUNT = RpcMeasureConstants.RPC_SERVER_ERROR_COUNT; - public static final MeasureDouble RPC_SERVER_REQUEST_BYTES = - RpcMeasureConstants.RPC_SERVER_REQUEST_BYTES; - public static final MeasureDouble RPC_SERVER_RESPONSE_BYTES = - RpcMeasureConstants.RPC_SERVER_RESPONSE_BYTES; public static final MeasureDouble RPC_SERVER_SERVER_ELAPSED_TIME = RpcMeasureConstants.RPC_SERVER_SERVER_ELAPSED_TIME; - public static final MeasureDouble RPC_SERVER_SERVER_LATENCY = - RpcMeasureConstants.RPC_SERVER_SERVER_LATENCY; public static final MeasureDouble RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES = RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES; public static final MeasureDouble RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES = RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES; - public static final MeasureLong RPC_SERVER_STARTED_COUNT = - RpcMeasureConstants.RPC_SERVER_STARTED_COUNT; public static final MeasureLong RPC_SERVER_FINISHED_COUNT = RpcMeasureConstants.RPC_SERVER_FINISHED_COUNT; - public static final MeasureLong RPC_SERVER_REQUEST_COUNT = - RpcMeasureConstants.RPC_SERVER_REQUEST_COUNT; - public static final MeasureLong RPC_SERVER_RESPONSE_COUNT = - RpcMeasureConstants.RPC_SERVER_RESPONSE_COUNT; private DeprecatedCensusConstants() {} } diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index b710d1b4112..9768797d579 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -95,7 +95,6 @@ import io.opencensus.trace.Tracer; import io.opencensus.trace.propagation.BinaryFormat; import io.opencensus.trace.propagation.SpanContextParseException; -import io.opencensus.trace.unsafe.ContextUtils; import java.io.InputStream; import java.util.HashSet; import java.util.List; @@ -226,7 +225,7 @@ public void setUp() throws Exception { new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true, true, true, false /* real-time */, true); - censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler); + censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler, true); } @After @@ -247,6 +246,7 @@ public void clientInterceptorCustomTag() { // Test that Census ClientInterceptors uses the TagContext and Span out of the current Context // to create the ClientCallTracer, and that it intercepts ClientCall.Listener.onClose() to call // ClientCallTracer.callEnded(). + @SuppressWarnings("deprecation") private void testClientInterceptors(boolean nonDefaultContext) { grpcServerRule.getServiceRegistry().addService( ServerServiceDefinition.builder("package1.service2").addMethod( @@ -284,7 +284,7 @@ public ClientCall interceptCall( .emptyBuilder() .putLocal(StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")) .build()); - ctx = ContextUtils.withValue(ctx, fakeClientParentSpan); + ctx = io.opencensus.trace.unsafe.ContextUtils.withValue(ctx, fakeClientParentSpan); Context origCtx = ctx.attach(); try { call = interceptedChannel.newCall(method, CALL_OPTIONS); @@ -295,7 +295,8 @@ public ClientCall interceptCall( assertEquals( io.opencensus.tags.unsafe.ContextUtils.getValue(Context.ROOT), io.opencensus.tags.unsafe.ContextUtils.getValue(Context.current())); - assertEquals(ContextUtils.getValue(Context.current()), BlankSpan.INSTANCE); + assertEquals(io.opencensus.trace.unsafe.ContextUtils.getValue(Context.current()), + BlankSpan.INSTANCE); call = interceptedChannel.newCall(method, CALL_OPTIONS); } @@ -415,8 +416,7 @@ private void subtestClientBasicStatsDefaultContext( assertEquals(1, record.tags.size()); TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); } else { assertNull(statsRecorder.pollRecord()); } @@ -483,25 +483,26 @@ private void subtestClientBasicStatsDefaultContext( 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 2, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( 1028 + 99, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); assertEquals( 1128 + 865, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + 2, + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC)); assertEquals( 33 + 154, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC)); assertEquals( 67 + 552, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals(30 + 100 + 16 + 24, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); assertZeroRetryRecorded(); } else { assertNull(statsRecorder.pollRecord()); @@ -525,8 +526,7 @@ public void recordRetryStats() { assertEquals(1, record.tags.size()); TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); fakeClock.forwardTime(30, MILLISECONDS); tracer.outboundHeaders(); @@ -552,16 +552,16 @@ record = statsRecorder.pollRecord(); 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 2, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( - 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + 1028, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); assertEquals( 1128, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals( 30 + 100 + 24, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); // faking retry fakeClock.forwardTime(1000, MILLISECONDS); @@ -570,8 +570,7 @@ record = statsRecorder.pollRecord(); assertEquals(1, record.tags.size()); methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); tracer.outboundHeaders(); tracer.outboundMessage(0); assertRealTimeMetric( @@ -594,16 +593,16 @@ record = statsRecorder.pollRecord(); 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 2, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( - 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + 1028, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); assertEquals( 1128, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals( 100 , - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); // fake transparent retry fakeClock.forwardTime(10, MILLISECONDS); @@ -613,8 +612,7 @@ record = statsRecorder.pollRecord(); assertEquals(1, record.tags.size()); methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); tracer.streamClosed(Status.UNAVAILABLE); record = statsRecorder.pollRecord(); statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); @@ -623,9 +621,9 @@ record = statsRecorder.pollRecord(); 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); // fake another transparent retry fakeClock.forwardTime(10, MILLISECONDS); @@ -633,8 +631,7 @@ record = statsRecorder.pollRecord(); STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); record = statsRecorder.pollRecord(); assertEquals(1, record.tags.size()); - assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); tracer.outboundHeaders(); tracer.outboundMessage(0); assertRealTimeMetric( @@ -666,25 +663,25 @@ record = statsRecorder.pollRecord(); 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); assertThat(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)).isNull(); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 2, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( - 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + 1028, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); assertEquals( 1128, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals( - 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + 1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC)); assertEquals( 33, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC)); assertEquals( 67, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals( 16 + 24 , - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); record = statsRecorder.pollRecord(); methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); @@ -817,9 +814,7 @@ public void clientStreamNeverCreatedStillRecordStats() { assertEquals(1, record.tags.size()); TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)); // Completion record record = statsRecorder.pollRecord(); @@ -836,24 +831,24 @@ record = statsRecorder.pollRecord(); 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); assertEquals( 0, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC)); assertEquals( - 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + 0, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC)); assertEquals(0, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals( 3000, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_SERVER_LATENCY)); assertZeroRetryRecorded(); } @@ -1035,6 +1030,7 @@ public void statsHeaderMalformed() { assertSame(tagger.empty(), headers.get(censusStats.statsHeader)); } + @SuppressWarnings("deprecation") @Test public void traceHeadersPropagateSpanContext() throws Exception { CallAttemptsTracerFactory callTracer = @@ -1062,7 +1058,7 @@ public void traceHeadersPropagateSpanContext() throws Exception { verify(spyServerSpanBuilder).setRecordEvents(eq(true)); Context filteredContext = serverTracer.filterContext(Context.ROOT); - assertSame(spyServerSpan, ContextUtils.getValue(filteredContext)); + assertSame(spyServerSpan, io.opencensus.trace.unsafe.ContextUtils.getValue(filteredContext)); } @Test @@ -1180,9 +1176,7 @@ private void subtestServerBasicStatsNoHeaders( assertEquals(1, record.tags.size()); TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_SERVER_METHOD); assertEquals(method.getFullMethodName(), methodTag.asString()); - assertEquals( - 1, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_STARTED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_STARTED_RPCS)); } else { assertNull(statsRecorder.pollRecord()); } @@ -1256,29 +1250,31 @@ private void subtestServerBasicStatsNoHeaders( assertEquals( 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_ERROR_COUNT)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_COUNT)); + 2, record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_SENT_MESSAGES_PER_RPC)); assertEquals( 1028 + 99, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_BYTES)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_SENT_BYTES_PER_RPC)); assertEquals( 1128 + 865, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals( - 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_REQUEST_COUNT)); + 2, + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_RECEIVED_MESSAGES_PER_RPC)); assertEquals( 34 + 154, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_REQUEST_BYTES)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_RECEIVED_BYTES_PER_RPC)); assertEquals(67 + 552, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); assertEquals(100 + 16 + 24, - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_SERVER_LATENCY)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_SERVER_LATENCY)); } else { assertNull(statsRecorder.pollRecord()); } } + @SuppressWarnings("deprecation") @Test public void serverBasicTracingNoHeaders() { ServerStreamTracer.Factory tracerFactory = censusTracing.getServerTracerFactory(); @@ -1290,7 +1286,7 @@ public void serverBasicTracingNoHeaders() { verify(spyServerSpanBuilder).setRecordEvents(eq(true)); Context filteredContext = serverStreamTracer.filterContext(Context.ROOT); - assertSame(spyServerSpan, ContextUtils.getValue(filteredContext)); + assertSame(spyServerSpan, io.opencensus.trace.unsafe.ContextUtils.getValue(filteredContext)); serverStreamTracer.serverCallStarted( new CallInfo<>(method, Attributes.EMPTY, null)); @@ -1396,24 +1392,24 @@ public void generateTraceSpanName() { private static void assertNoServerContent(StatsTestUtils.MetricsRecord record) { assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_ERROR_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_REQUEST_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_REQUEST_BYTES)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_RECEIVED_MESSAGES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_SENT_MESSAGES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_RECEIVED_BYTES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_SENT_BYTES_PER_RPC)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_SERVER_ELAPSED_TIME)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_SERVER_LATENCY)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_SERVER_LATENCY)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); } private static void assertNoClientContent(StatsTestUtils.MetricsRecord record) { assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); - assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); + assertNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_SERVER_LATENCY)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); } @@ -1511,7 +1507,7 @@ public Long apply(AggregationData arg) { }); } - private static class CallInfo extends ServerCallInfo { + static class CallInfo extends ServerCallInfo { private final MethodDescriptor methodDescriptor; private final Attributes attributes; private final String authority; diff --git a/census/src/test/java/io/grpc/census/CensusTracingNoMessageEventTest.java b/census/src/test/java/io/grpc/census/CensusTracingNoMessageEventTest.java new file mode 100644 index 00000000000..1bdefe4e749 --- /dev/null +++ b/census/src/test/java/io/grpc/census/CensusTracingNoMessageEventTest.java @@ -0,0 +1,200 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.census; + +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; + +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.census.CensusTracingModule.CallAttemptsTracerFactory; +import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder; +import io.grpc.internal.testing.StatsTestUtils.MockableSpan; +import io.grpc.testing.GrpcServerRule; +import io.opencensus.trace.MessageEvent; +import io.opencensus.trace.Span; +import io.opencensus.trace.SpanBuilder; +import io.opencensus.trace.SpanContext; +import io.opencensus.trace.Tracer; +import io.opencensus.trace.propagation.BinaryFormat; +import java.io.InputStream; +import java.util.Random; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Test for {@link CensusTracingModule}. + */ +@RunWith(JUnit4.class) +public class CensusTracingNoMessageEventTest { + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder().build(); + + private static class StringInputStream extends InputStream { + final String string; + + StringInputStream(String string) { + this.string = string; + } + + @Override + public int read() { + // InProcessTransport doesn't actually read bytes from the InputStream. The InputStream is + // passed to the InProcess server and consumed by MARSHALLER.parse(). + throw new UnsupportedOperationException("Should not be called"); + } + } + + private static final MethodDescriptor.Marshaller MARSHALLER = + new MethodDescriptor.Marshaller() { + @Override + public InputStream stream(String value) { + return new StringInputStream(value); + } + + @Override + public String parse(InputStream stream) { + return ((StringInputStream) stream).string; + } + }; + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setRequestMarshaller(MARSHALLER) + .setResponseMarshaller(MARSHALLER) + .setFullMethodName("package1.service2/method3") + .build(); + + private final FakeStatsRecorder statsRecorder = new FakeStatsRecorder(); + private final Random random = new Random(1234); + private final Span spyClientSpan = spy(MockableSpan.generateRandomSpan(random)); + private final Span spyAttemptSpan = spy(MockableSpan.generateRandomSpan(random)); + private final SpanContext fakeAttemptSpanContext = spyAttemptSpan.getContext(); + private final Span spyServerSpan = spy(MockableSpan.generateRandomSpan(random)); + private final byte[] binarySpanContext = new byte[]{3, 1, 5}; + private final SpanBuilder spyClientSpanBuilder = spy(new MockableSpan.Builder()); + private final SpanBuilder spyAttemptSpanBuilder = spy(new MockableSpan.Builder()); + private final SpanBuilder spyServerSpanBuilder = spy(new MockableSpan.Builder()); + + @Rule + public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); + + @Mock + private Tracer tracer; + @Mock + private BinaryFormat mockTracingPropagationHandler; + + @Captor + private ArgumentCaptor messageEventCaptor; + + private CensusTracingModule censusTracing; + + @Before + public void setUp() throws Exception { + when(spyClientSpanBuilder.startSpan()).thenReturn(spyClientSpan); + when(spyAttemptSpanBuilder.startSpan()).thenReturn(spyAttemptSpan); + when(tracer.spanBuilderWithExplicitParent( + eq("Sent.package1.service2.method3"), ArgumentMatchers.any())) + .thenReturn(spyClientSpanBuilder); + when(tracer.spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), ArgumentMatchers.any())) + .thenReturn(spyAttemptSpanBuilder); + when(spyServerSpanBuilder.startSpan()).thenReturn(spyServerSpan); + when(tracer.spanBuilderWithRemoteParent(anyString(), ArgumentMatchers.any())) + .thenReturn(spyServerSpanBuilder); + when(mockTracingPropagationHandler.toByteArray(any(SpanContext.class))) + .thenReturn(binarySpanContext); + when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) + .thenReturn(fakeAttemptSpanContext); + + censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler, false); + } + + @After + public void wrapUp() { + assertNull(statsRecorder.pollRecord()); + } + + @Test + public void clientBasicTracingNoMessageEvents() { + CallAttemptsTracerFactory callTracer = + censusTracing.newClientCallTracer(null, method); + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + + clientStreamTracer.outboundMessage(0); + clientStreamTracer.outboundMessageSent(0, 882, -1); + clientStreamTracer.inboundMessage(0); + clientStreamTracer.outboundMessage(1); + clientStreamTracer.outboundMessageSent(1, -1, 27); + clientStreamTracer.inboundMessageRead(0, 255, 90); + + clientStreamTracer.streamClosed(Status.OK); + callTracer.callEnded(Status.OK); + + InOrder inOrder = inOrder(spyClientSpan, spyAttemptSpan); + inOrder.verify(spyAttemptSpan, times(0)).addMessageEvent(messageEventCaptor.capture()); + } + + @Test + public void serverBasicTracingNoMessageEvents() { + ServerStreamTracer.Factory tracerFactory = censusTracing.getServerTracerFactory(); + ServerStreamTracer serverStreamTracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); + + serverStreamTracer.serverCallStarted( + new CensusModulesTest.CallInfo<>(method, Attributes.EMPTY, null)); + + serverStreamTracer.outboundMessage(0); + serverStreamTracer.outboundMessageSent(0, 882, -1); + serverStreamTracer.inboundMessage(0); + serverStreamTracer.outboundMessage(1); + serverStreamTracer.outboundMessageSent(1, -1, 27); + serverStreamTracer.inboundMessageRead(0, 255, 90); + + serverStreamTracer.streamClosed(Status.CANCELLED); + + InOrder inOrder = inOrder(spyServerSpan); + inOrder.verify(spyServerSpan, times(0)).addMessageEvent(messageEventCaptor.capture()); + } +} diff --git a/compiler/build.gradle b/compiler/build.gradle index 0b5766578a8..c942442e6de 100644 --- a/compiler/build.gradle +++ b/compiler/build.gradle @@ -50,12 +50,16 @@ model { } } gcc(Gcc) { - target("ppcle_64") + target("ppcle_64") { + cppCompiler.executable = 'powerpc64le-linux-gnu-g++' + linker.executable = 'powerpc64le-linux-gnu-g++' + } target("aarch_64") { cppCompiler.executable = 'aarch64-linux-gnu-g++' linker.executable = 'aarch64-linux-gnu-g++' } target("s390_64") + target("loongarch_64") } clang(Clang) { } @@ -67,6 +71,7 @@ model { ppcle_64 { architecture "ppcle_64" } aarch_64 { architecture "aarch_64" } s390_64 { architecture "s390_64" } + loongarch_64 { architecture "loongarch_64" } } components { @@ -76,7 +81,8 @@ model { 'x86_64', 'ppcle_64', 'aarch_64', - 's390_64' + 's390_64', + 'loongarch_64' ]) { // If arch is not within the defined platforms, we do not specify the // targetPlatform so that Gradle will choose what is appropriate. @@ -134,10 +140,10 @@ configurations { dependencies { testImplementation project(':grpc-protobuf'), project(':grpc-stub'), - libraries.javax_annotation + libraries.javax.annotation testLiteImplementation project(':grpc-protobuf-lite'), project(':grpc-stub'), - libraries.javax_annotation + libraries.javax.annotation } sourceSets { @@ -146,11 +152,11 @@ sourceSets { } } -compileTestJava { +tasks.named("compileTestJava").configure { options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" } -compileTestLiteJava { +tasks.named("compileTestLiteJava").configure { options.compilerArgs = compileTestJava.options.compilerArgs options.compilerArgs += [ "-Xlint:-cast" @@ -158,12 +164,16 @@ compileTestLiteJava { options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" } +tasks.named("checkstyleTestLite").configure { + enabled = false +} + protobuf { protoc { if (project.hasProperty('protoc')) { path = project.protoc } else { - artifact = "com.google.protobuf:protoc:${protocVersion}" + artifact = libs.protobuf.protoc.get() } } plugins { @@ -171,25 +181,33 @@ protobuf { } generateProtoTasks { all().each { task -> - task.dependsOn 'java_pluginExecutable' - task.inputs.file javaPluginPath + task.configure { + dependsOn 'java_pluginExecutable' + inputs.file javaPluginPath + } } - ofSourceSet('test')*.plugins { grpc {} } - ofSourceSet('testLite')*.each { task -> - task.builtins { - java { option 'lite' } + ofSourceSet('test').each { task -> + task.configure { + plugins { grpc {} } } - task.plugins { - grpc { option 'lite' } + } + ofSourceSet('testLite').each { task -> + task.configure { + builtins { + java { option 'lite' } + } + plugins { + grpc { option 'lite' } + } } } } } -println "*** Building codegen requires Protobuf version ${protocVersion}" +println "*** Building codegen requires Protobuf version ${libs.versions.protobuf.get()}" println "*** Please refer to https://github.com/grpc/grpc-java/blob/master/COMPILING.md#how-to-build-code-generation-plugin" -task buildArtifacts(type: Copy) { +tasks.register("buildArtifacts", Copy) { dependsOn 'java_pluginExecutable' from("$buildDir/exe") { if (osdetector.os != 'windows') { @@ -201,7 +219,7 @@ task buildArtifacts(type: Copy) { archivesBaseName = "$protocPluginBaseName" -task checkArtifacts { +def checkArtifacts = tasks.register("checkArtifacts") { dependsOn buildArtifacts doLast { if (!usingVisualCpp) { @@ -284,11 +302,15 @@ def configureTestTask(Task task, String dep, String extraPackage, String service "$projectDir/src/test${dep}/golden/${serviceName}.java.txt" } -task testGolden(type: Exec) -task testLiteGolden(type: Exec) -task testDeprecatedGolden(type: Exec) -task testDeprecatedLiteGolden(type: Exec) -configureTestTask(testGolden, '', '', 'TestService') -configureTestTask(testLiteGolden, 'Lite', '', 'TestService') -configureTestTask(testDeprecatedGolden, '', '', 'TestDeprecatedService') -configureTestTask(testDeprecatedLiteGolden, 'Lite', '', 'TestDeprecatedService') +tasks.register("testGolden", Exec) { + configureTestTask(it, '', '', 'TestService') +} +tasks.register("testLiteGolden", Exec) { + configureTestTask(it, 'Lite', '', 'TestService') +} +tasks.register("testDeprecatedGolden", Exec) { + configureTestTask(it, '', '', 'TestDeprecatedService') +} +tasks.register("testDeprecatedLiteGolden", Exec) { + configureTestTask(it, 'Lite', '', 'TestDeprecatedService') +} diff --git a/compiler/check-artifact.sh b/compiler/check-artifact.sh index 13ae89c744a..67f01aa97cd 100755 --- a/compiler/check-artifact.sh +++ b/compiler/check-artifact.sh @@ -61,6 +61,13 @@ checkArch () assertEq "$format" "elf64-x86-64" $LINENO elif [[ "$ARCH" == aarch_64 ]]; then assertEq "$format" "elf64-little" $LINENO + elif [[ "$ARCH" == loongarch_64 ]]; then + echo $format + assertEq "$format" "elf64-loongarch" $LINENO + elif [[ "$ARCH" == ppcle_64 ]]; then + format="$(powerpc64le-linux-gnu-objdump -f "$1" | grep -o "file format .*$" | grep -o "[^ ]*$")" + echo Format=$format + assertEq "$format" "elf64-powerpcle" $LINENO else fail "Unsupported arch: $ARCH" fi @@ -108,6 +115,12 @@ checkDependencies () elif [[ "$ARCH" == aarch_64 ]]; then dump_cmd='aarch64-linux-gnu-objdump -x '"$1"' |grep "NEEDED"' white_list="linux-vdso\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld-linux-aarch64\.so\.1" + elif [[ "$ARCH" == loongarch_64 ]]; then + dump_cmd='objdump -x '"$1"' | grep NEEDED' + white_list="linux-vdso\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld\.so\.1" + elif [[ "$ARCH" == ppcle_64 ]]; then + dump_cmd='powerpc64le-linux-gnu-objdump -x '"$1"' |grep "NEEDED"' + white_list="linux-vdso64\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld64\.so\.2" fi elif [[ "$OS" == osx ]]; then dump_cmd='otool -L '"$1"' | fgrep dylib' diff --git a/compiler/src/java_plugin/cpp/java_generator.cpp b/compiler/src/java_plugin/cpp/java_generator.cpp index 6882940378f..3bb56ae12ee 100644 --- a/compiler/src/java_plugin/cpp/java_generator.cpp +++ b/compiler/src/java_plugin/cpp/java_generator.cpp @@ -22,11 +22,18 @@ #include #include #include -#include #include #include #include #include +#include + +// Protobuf 3.21 changed the name of this file. +#if GOOGLE_PROTOBUF_VERSION >= 3021000 + #include +#else + #include +#endif // Stringify helpers used solely to cast GRPC_VERSION #ifndef STR diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index bae773cf5cb..f09966db336 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.45.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.53.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index ac5f7b1e84a..c823a9d1542 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.45.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.53.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 70b348a33b0..ad31021c560 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.45.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.53.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index 2ca755211f3..2dbc8478050 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.45.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.53.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/context/src/main/java/io/grpc/Context.java b/context/src/main/java/io/grpc/Context.java index 41d3a5c94a6..f63f021216c 100644 --- a/context/src/main/java/io/grpc/Context.java +++ b/context/src/main/java/io/grpc/Context.java @@ -1000,6 +1000,8 @@ public String toString() { */ public abstract static class Storage { /** + * Unused. + * * @deprecated This is an old API that is no longer used. */ @Deprecated @@ -1029,7 +1031,7 @@ public Context doAttach(Context toAttach) { } /** - * Implements {@link io.grpc.Context#detach} + * Implements {@link io.grpc.Context#detach}. * * @param toDetach the context to be detached. Should be, or be equivalent to, the current * context of the current scope diff --git a/context/src/main/java/io/grpc/Deadline.java b/context/src/main/java/io/grpc/Deadline.java index 73c87605953..62b803267a8 100644 --- a/context/src/main/java/io/grpc/Deadline.java +++ b/context/src/main/java/io/grpc/Deadline.java @@ -45,7 +45,7 @@ public final class Deadline implements Comparable { * *

This is EXPERIMENTAL API and may subject to change. If you'd like it to be * stabilized or have any feedback, please - * let us know. + * let us know. * * @since 1.24.0 */ @@ -81,7 +81,7 @@ public static Deadline after(long duration, TimeUnit units) { * *

This is EXPERIMENTAL API and may subject to change. If you'd like it to be * stabilized or have any feedback, please - * let us know. + * let us know. * * @param duration A non-negative duration. * @param units The time unit for the duration. @@ -113,7 +113,8 @@ private Deadline(Ticker ticker, long baseInstant, long offset, } /** - * Has this deadline expired + * Returns whether this has deadline expired. + * * @return {@code true} if it has, otherwise {@code false}. */ public boolean isExpired() { @@ -266,7 +267,7 @@ public boolean equals(final Object o) { * *

This is EXPERIMENTAL API and may subject to change. If you'd like it to be * stabilized or have any feedback, please - * let us know. + * let us know. * *

In general implementations should be thread-safe, unless it's implemented and used in a * localized environment (like unit tests) where you are sure the usages are synchronized. diff --git a/context/src/test/java/io/grpc/ContextTest.java b/context/src/test/java/io/grpc/ContextTest.java index 5171d4839a9..c2f36f41d71 100644 --- a/context/src/test/java/io/grpc/ContextTest.java +++ b/context/src/test/java/io/grpc/ContextTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.TruthJUnit.assume; import static io.grpc.Context.cancellableAncestor; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.IsInstanceOf.instanceOf; @@ -872,6 +873,20 @@ public String call() { @Test public void storageReturnsNullTest() throws Exception { + // TODO(sergiitk): JDK-8210522 changes the behaviour of Java reflection to filter out + // security-sensitive fields in the java.lang.reflect.Field. This prohibits + // Field.class.getDeclaredFields("modifiers") call we rely on in this test. + // Until we have a good solution for setting a custom storage for testing purposes, + // we'll have to skip this test for JDK >= 11. Ref https://bugs.openjdk.org/browse/JDK-8210522 + double javaVersion; + // Graceful version check. Run the test if the version undetermined. + try { + javaVersion = Double.parseDouble(System.getProperty("java.specification.version", "0")); + } catch (NumberFormatException e) { + javaVersion = 0; + } + assume().that(javaVersion).isLessThan(11); + Class lazyStorageClass = Class.forName("io.grpc.Context$LazyStorage"); Field storage = lazyStorageClass.getDeclaredField("storage"); assertTrue(Modifier.isFinal(storage.getModifiers())); diff --git a/context/src/test/java/io/grpc/PersistentHashArrayMappedTrieTest.java b/context/src/test/java/io/grpc/PersistentHashArrayMappedTrieTest.java index f02c2916e36..aeac2f51a5b 100644 --- a/context/src/test/java/io/grpc/PersistentHashArrayMappedTrieTest.java +++ b/context/src/test/java/io/grpc/PersistentHashArrayMappedTrieTest.java @@ -84,15 +84,25 @@ public void leaf_insert() { assertEquals(2, ret.size()); } - @Test(expected = AssertionError.class) + @SuppressWarnings("CheckReturnValue") + @Test public void collisionLeaf_assertKeysDifferent() { Key key1 = new Key(0); - new CollisionLeaf<>(key1, new Object(), key1, new Object()); + try { + new CollisionLeaf<>(key1, new Object(), key1, new Object()); + throw new Error(); + } catch (AssertionError expected) { + } } - @Test(expected = AssertionError.class) + @SuppressWarnings("CheckReturnValue") + @Test public void collisionLeaf_assertHashesSame() { - new CollisionLeaf<>(new Key(0), new Object(), new Key(1), new Object()); + try { + new CollisionLeaf<>(new Key(0), new Object(), new Key(1), new Object()); + throw new Error(); + } catch (AssertionError expected) { + } } @Test diff --git a/context/src/test/java/io/grpc/testing/DeadlineSubject.java b/context/src/test/java/io/grpc/testing/DeadlineSubject.java index 820f91248a1..5d4e86fac15 100644 --- a/context/src/test/java/io/grpc/testing/DeadlineSubject.java +++ b/context/src/test/java/io/grpc/testing/DeadlineSubject.java @@ -19,12 +19,12 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Fact.fact; import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import com.google.common.truth.ComparableSubject; import com.google.common.truth.FailureMetadata; import com.google.common.truth.Subject; import io.grpc.Deadline; -import java.math.BigInteger; import java.util.concurrent.TimeUnit; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; @@ -32,6 +32,7 @@ /** Propositions for {@link Deadline} subjects. */ @SuppressWarnings("rawtypes") // Generics in this class are going away in a subsequent Truth. public final class DeadlineSubject extends ComparableSubject { + public static final double NANOSECONDS_IN_A_SECOND = SECONDS.toNanos(1) * 1.0; private static final Subject.Factory deadlineFactory = new Factory(); @@ -60,14 +61,14 @@ public void of(Deadline expected) { checkNotNull(actual, "actual value cannot be null. expected=%s", expected); // This is probably overkill, but easier than thinking about overflow. - BigInteger actualTimeRemaining = BigInteger.valueOf(actual.timeRemaining(NANOSECONDS)); - BigInteger expectedTimeRemaining = BigInteger.valueOf(expected.timeRemaining(NANOSECONDS)); - BigInteger deltaNanos = BigInteger.valueOf(timeUnit.toNanos(delta)); - if (actualTimeRemaining.subtract(expectedTimeRemaining).abs().compareTo(deltaNanos) > 0) { + long actualNanos = actual.timeRemaining(NANOSECONDS); + long expectedNanos = expected.timeRemaining(NANOSECONDS); + long deltaNanos = timeUnit.toNanos(delta) ; + if (Math.abs(actualNanos - expectedNanos) > deltaNanos) { failWithoutActual( - fact("expected", expected), - fact("but was", actual), - fact("outside tolerance in ns", deltaNanos)); + fact("expected", expectedNanos / NANOSECONDS_IN_A_SECOND), + fact("but was", expectedNanos / NANOSECONDS_IN_A_SECOND), + fact("outside tolerance in seconds", deltaNanos / NANOSECONDS_IN_A_SECOND)); } } }; diff --git a/core/BUILD.bazel b/core/BUILD.bazel index 60a08798d58..3ca51e66c94 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -65,13 +65,12 @@ java_library( ) # Mirrors the dependencies included in the artifact on Maven Central for usage -# with maven_install's override_targets. Purposefully does not export any -# symbols, as it should only be used as a dep for pre-compiled binaries on -# Maven Central. +# with maven_install's override_targets. Should only be used as a dep for +# pre-compiled binaries on Maven Central. java_library( name = "core_maven", visibility = ["//visibility:public"], - runtime_deps = [ + exports = [ ":inprocess", ":internal", ":util", diff --git a/core/build.gradle b/core/build.gradle index bc8231fd9e1..9778c035f5c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -25,16 +25,16 @@ evaluationDependsOn(project(':grpc-api').path) dependencies { api project(':grpc-api') implementation libraries.gson, - libraries.android_annotations, - libraries.animalsniffer_annotations, - libraries.errorprone, + libraries.android.annotations, + libraries.animalsniffer.annotations, + libraries.errorprone.annotations, libraries.guava, - libraries.perfmark + libraries.perfmark.api testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-testing'), project(':grpc-grpclb') - testImplementation (libraries.guava_testlib) { + testImplementation (libraries.guava.testlib) { exclude group: 'junit', module: 'junit' } @@ -42,11 +42,11 @@ dependencies { jmh project(':grpc-testing') - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + signature libraries.signature.java + signature libraries.signature.android } -javadoc { +tasks.named("javadoc").configure { exclude 'io/grpc/internal/**' exclude 'io/grpc/inprocess/Internal*' // Disabled until kinda stable. @@ -87,7 +87,7 @@ def replaceConstant(File file, String needle, String replacement) { } plugins.withId("java") { - compileJava { + tasks.named("compileJava").configure { doLast { // Replace value of Signature Attribute. // https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.7.9 @@ -109,13 +109,13 @@ plugins.withId("java") { } } - compileJmhJava { + tasks.named("compileJmhJava").configure { // This project targets Java 7 (no method references) options.errorprone.check("UnnecessaryAnonymousClass", CheckSeverity.OFF) } } -task versionFile() { +tasks.register("versionFile") { doLast { new File(buildDir, "version").write("${project.version}\n") } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index df396ae2f66..35998a535e2 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -114,6 +114,10 @@ public ClientTransportFactory buildClientTransportFactory() { managedChannelImplBuilder.setStatsRecordStartedRpcs(false); managedChannelImplBuilder.setStatsRecordFinishedRpcs(false); managedChannelImplBuilder.setStatsRecordRetryMetrics(false); + + // By default, In-process transport should not be retriable as that leaks memory. Since + // there is no wire, bytes aren't calculated so buffer limit isn't respected + managedChannelImplBuilder.disableRetry(); } @Internal @@ -123,7 +127,7 @@ protected ManagedChannelBuilder delegate() { } @Override - public final InProcessChannelBuilder maxInboundMessageSize(int max) { + public InProcessChannelBuilder maxInboundMessageSize(int max) { // TODO(carl-mastrangelo): maybe throw an exception since this not enforced? return super.maxInboundMessageSize(max); } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessSocketAddress.java b/core/src/main/java/io/grpc/inprocess/InProcessSocketAddress.java index e5f0515f1d0..98ecf7e1a5f 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessSocketAddress.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessSocketAddress.java @@ -29,7 +29,9 @@ public final class InProcessSocketAddress extends SocketAddress { private final String name; /** - * @param name - The name of the inprocess channel or server. + * Construct an address for a server identified by name. + * + * @param name The name of the inprocess server. * @since 1.0.0 */ public InProcessSocketAddress(String name) { @@ -37,7 +39,7 @@ public InProcessSocketAddress(String name) { } /** - * Gets the name of the inprocess channel or server. + * Gets the name of the inprocess server. * * @since 1.0.0 */ @@ -46,6 +48,8 @@ public String getName() { } /** + * Returns {@link #getName}. + * * @since 1.14.0 */ @Override @@ -53,15 +57,14 @@ public String toString() { return name; } - /** - * @since 1.15.0 - */ @Override public int hashCode() { return name.hashCode(); } /** + * Returns {@code true} if the object is of the same type and server names match. + * * @since 1.15.0 */ @Override diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 2f4870fdcc2..1c2ac3df22b 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -40,6 +40,7 @@ import io.grpc.SecurityLevel; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -65,6 +66,7 @@ import java.util.Collections; import java.util.IdentityHashMap; import java.util.List; +import java.util.Locale; import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -106,6 +108,18 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans private List serverStreamTracerFactories; private final Attributes attributes; + private Thread.UncaughtExceptionHandler uncaughtExceptionHandler = + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + if (e instanceof Error) { + throw new Error(e); + } + throw new RuntimeException(e); + } + }; + + @GuardedBy("this") private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -150,9 +164,9 @@ public InProcessTransport( String name, int maxInboundMetadataSize, String authority, String userAgent, Attributes eagAttrs, ObjectPool serverSchedulerPool, List serverStreamTracerFactories, - ServerListener serverListener) { + ServerListener serverListener, boolean includeCauseWithStatus) { this(new InProcessSocketAddress(name), maxInboundMetadataSize, authority, userAgent, eagAttrs, - Optional.of(serverListener), false); + Optional.of(serverListener), includeCauseWithStatus); this.serverMaxInboundMetadataSize = maxInboundMetadataSize; this.serverSchedulerPool = serverSchedulerPool; this.serverStreamTracerFactories = serverStreamTracerFactories; @@ -227,6 +241,7 @@ public synchronized ClientStream newStream( // statuscodes.md is updated. Status status = Status.RESOURCE_EXHAUSTED.withDescription( String.format( + Locale.US, "Request metadata larger than %d: %d", serverMaxInboundMetadataSize, metadataSize)); @@ -407,8 +422,10 @@ private void streamClosed() { private class InProcessServerStream implements ServerStream { final StatsTraceContext statsTraceCtx; - @GuardedBy("this") + // All callbacks must run in syncContext to avoid possibility of deadlock in direct executors private ClientStreamListener clientStreamListener; + private final SynchronizationContext syncContext = + new SynchronizationContext(uncaughtExceptionHandler); @GuardedBy("this") private int clientRequested; @GuardedBy("this") @@ -444,10 +461,11 @@ public void request(int numMessages) { if (onReady) { synchronized (this) { if (!closed) { - clientStreamListener.onReady(); + syncContext.executeLater(() -> clientStreamListener.onReady()); } } } + syncContext.drain(); } // This method is the only reason we have to synchronize field accesses. @@ -456,28 +474,36 @@ public void request(int numMessages) { * * @return whether onReady should be called on the server */ - private synchronized boolean clientRequested(int numMessages) { - if (closed) { - return false; - } - boolean previouslyReady = clientRequested > 0; - clientRequested += numMessages; - while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) { - clientRequested--; - clientStreamListener.messagesAvailable(clientReceiveQueue.poll()); - } - // Attempt being reentrant-safe - if (closed) { - return false; - } - if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) { - closed = true; - clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers); - clientStream.statsTraceCtx.streamClosed(clientNotifyStatus); - clientStreamListener.closed( - clientNotifyStatus, RpcProgress.PROCESSED, clientNotifyTrailers); + private boolean clientRequested(int numMessages) { + boolean previouslyReady; + boolean nowReady; + synchronized (this) { + if (closed) { + return false; + } + + previouslyReady = clientRequested > 0; + clientRequested += numMessages; + while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) { + clientRequested--; + StreamListener.MessageProducer producer = clientReceiveQueue.poll(); + syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer)); + } + + if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) { + closed = true; + clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers); + clientStream.statsTraceCtx.streamClosed(clientNotifyStatus); + Status notifyStatus = this.clientNotifyStatus; + Metadata notifyTrailers = this.clientNotifyTrailers; + syncContext.executeLater(() -> + clientStreamListener.closed(notifyStatus, RpcProgress.PROCESSED, notifyTrailers)); + } + + nowReady = clientRequested > 0; } - boolean nowReady = clientRequested > 0; + + syncContext.drain(); return !previouslyReady && nowReady; } @@ -486,22 +512,26 @@ private void clientCancelled(Status status) { } @Override - public synchronized void writeMessage(InputStream message) { - if (closed) { - return; - } - statsTraceCtx.outboundMessage(outboundSeqNo); - statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); - clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); - clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); - outboundSeqNo++; - StreamListener.MessageProducer producer = new SingleMessageProducer(message); - if (clientRequested > 0) { - clientRequested--; - clientStreamListener.messagesAvailable(producer); - } else { - clientReceiveQueue.add(producer); + public void writeMessage(InputStream message) { + synchronized (this) { + if (closed) { + return; + } + statsTraceCtx.outboundMessage(outboundSeqNo); + statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); + clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); + clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + outboundSeqNo++; + StreamListener.MessageProducer producer = new SingleMessageProducer(message); + if (clientRequested > 0) { + clientRequested--; + syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer)); + } else { + clientReceiveQueue.add(producer); + } } + + syncContext.drain(); } @Override @@ -526,6 +556,7 @@ public void writeHeaders(Metadata headers) { // Status, which may need to be updated if statuscodes.md is updated. Status failedStatus = Status.RESOURCE_EXHAUSTED.withDescription( String.format( + Locale.US, "Response header metadata larger than %d: %d", clientMaxInboundMetadataSize, metadataSize)); @@ -540,8 +571,9 @@ public void writeHeaders(Metadata headers) { } clientStream.statsTraceCtx.clientInboundHeaders(); - clientStreamListener.headersRead(headers); + syncContext.executeLater(() -> clientStreamListener.headersRead(headers)); } + syncContext.drain(); } @Override @@ -564,6 +596,7 @@ public void close(Status status, Metadata trailers) { // Status, which may need to be updated if statuscodes.md is updated. status = Status.RESOURCE_EXHAUSTED.withDescription( String.format( + Locale.US, "Response header metadata larger than %d: %d", clientMaxInboundMetadataSize, metadataSize)); @@ -585,13 +618,14 @@ private void notifyClientClose(Status status, Metadata trailers) { closed = true; clientStream.statsTraceCtx.clientInboundTrailers(trailers); clientStream.statsTraceCtx.streamClosed(clientStatus); - clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, trailers); + syncContext.executeLater( + () -> clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, trailers)); } else { clientNotifyStatus = clientStatus; clientNotifyTrailers = trailers; } } - + syncContext.drain(); streamClosed(); } @@ -604,24 +638,29 @@ public void cancel(Status status) { streamClosed(); } - private synchronized boolean internalCancel(Status clientStatus) { - if (closed) { - return false; - } - closed = true; - StreamListener.MessageProducer producer; - while ((producer = clientReceiveQueue.poll()) != null) { - InputStream message; - while ((message = producer.next()) != null) { - try { - message.close(); - } catch (Throwable t) { - log.log(Level.WARNING, "Exception closing stream", t); + private boolean internalCancel(Status clientStatus) { + synchronized (this) { + if (closed) { + return false; + } + closed = true; + StreamListener.MessageProducer producer; + while ((producer = clientReceiveQueue.poll()) != null) { + InputStream message; + while ((message = producer.next()) != null) { + try { + message.close(); + } catch (Throwable t) { + log.log(Level.WARNING, "Exception closing stream", t); + } } } + clientStream.statsTraceCtx.streamClosed(clientStatus); + syncContext.executeLater( + () -> + clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, new Metadata())); } - clientStream.statsTraceCtx.streamClosed(clientStatus); - clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, new Metadata()); + syncContext.drain(); return true; } @@ -662,8 +701,10 @@ public int streamId() { private class InProcessClientStream implements ClientStream { final StatsTraceContext statsTraceCtx; final CallOptions callOptions; - @GuardedBy("this") + // All callbacks must run in syncContext to avoid possibility of deadlock in direct executors private ServerStreamListener serverStreamListener; + private final SynchronizationContext syncContext = + new SynchronizationContext(uncaughtExceptionHandler); @GuardedBy("this") private int serverRequested; @GuardedBy("this") @@ -693,9 +734,10 @@ public void request(int numMessages) { if (onReady) { synchronized (this) { if (!closed) { - serverStreamListener.onReady(); + syncContext.executeLater(() -> serverStreamListener.onReady()); } } + syncContext.drain(); } } @@ -705,21 +747,29 @@ public void request(int numMessages) { * * @return whether onReady should be called on the server */ - private synchronized boolean serverRequested(int numMessages) { - if (closed) { - return false; - } - boolean previouslyReady = serverRequested > 0; - serverRequested += numMessages; - while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) { - serverRequested--; - serverStreamListener.messagesAvailable(serverReceiveQueue.poll()); - } - if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) { - serverNotifyHalfClose = false; - serverStreamListener.halfClosed(); + private boolean serverRequested(int numMessages) { + boolean previouslyReady; + boolean nowReady; + synchronized (this) { + if (closed) { + return false; + } + previouslyReady = serverRequested > 0; + serverRequested += numMessages; + + while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) { + serverRequested--; + StreamListener.MessageProducer producer = serverReceiveQueue.poll(); + syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer)); + } + + if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) { + serverNotifyHalfClose = false; + syncContext.executeLater(() -> serverStreamListener.halfClosed()); + } + nowReady = serverRequested > 0; } - boolean nowReady = serverRequested > 0; + syncContext.drain(); return !previouslyReady && nowReady; } @@ -728,22 +778,25 @@ private void serverClosed(Status serverListenerStatus, Status serverTracerStatus } @Override - public synchronized void writeMessage(InputStream message) { - if (closed) { - return; - } - statsTraceCtx.outboundMessage(outboundSeqNo); - statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); - serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); - serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); - outboundSeqNo++; - StreamListener.MessageProducer producer = new SingleMessageProducer(message); - if (serverRequested > 0) { - serverRequested--; - serverStreamListener.messagesAvailable(producer); - } else { - serverReceiveQueue.add(producer); + public void writeMessage(InputStream message) { + synchronized (this) { + if (closed) { + return; + } + statsTraceCtx.outboundMessage(outboundSeqNo); + statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); + serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); + serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + outboundSeqNo++; + StreamListener.MessageProducer producer = new SingleMessageProducer(message); + if (serverRequested > 0) { + serverRequested--; + syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer)); + } else { + serverReceiveQueue.add(producer); + } } + syncContext.drain(); } @Override @@ -768,39 +821,45 @@ public void cancel(Status reason) { streamClosed(); } - private synchronized boolean internalCancel( + private boolean internalCancel( Status serverListenerStatus, Status serverTracerStatus) { - if (closed) { - return false; - } - closed = true; - - StreamListener.MessageProducer producer; - while ((producer = serverReceiveQueue.poll()) != null) { - InputStream message; - while ((message = producer.next()) != null) { - try { - message.close(); - } catch (Throwable t) { - log.log(Level.WARNING, "Exception closing stream", t); + synchronized (this) { + if (closed) { + return false; + } + closed = true; + + StreamListener.MessageProducer producer; + while ((producer = serverReceiveQueue.poll()) != null) { + InputStream message; + while ((message = producer.next()) != null) { + try { + message.close(); + } catch (Throwable t) { + log.log(Level.WARNING, "Exception closing stream", t); + } } } + serverStream.statsTraceCtx.streamClosed(serverTracerStatus); + syncContext.executeLater(() -> serverStreamListener.closed(serverListenerStatus)); } - serverStream.statsTraceCtx.streamClosed(serverTracerStatus); - serverStreamListener.closed(serverListenerStatus); + syncContext.drain(); return true; } @Override - public synchronized void halfClose() { - if (closed) { - return; - } - if (serverReceiveQueue.isEmpty()) { - serverStreamListener.halfClosed(); - } else { - serverNotifyHalfClose = true; + public void halfClose() { + synchronized (this) { + if (closed) { + return; + } + if (serverReceiveQueue.isEmpty()) { + syncContext.executeLater(() -> serverStreamListener.halfClosed()); + } else { + serverNotifyHalfClose = true; + } } + syncContext.drain(); } @Override diff --git a/core/src/main/java/io/grpc/inprocess/InternalInProcess.java b/core/src/main/java/io/grpc/inprocess/InternalInProcess.java index 021b07a80bc..680373533c8 100644 --- a/core/src/main/java/io/grpc/inprocess/InternalInProcess.java +++ b/core/src/main/java/io/grpc/inprocess/InternalInProcess.java @@ -51,7 +51,8 @@ public static ConnectionClientTransport createInProcessTransport( Attributes eagAttrs, ObjectPool serverSchedulerPool, List serverStreamTracerFactories, - ServerListener serverListener) { + ServerListener serverListener, + boolean includeCauseWithStatus) { return new InProcessTransport( name, maxInboundMetadataSize, @@ -60,6 +61,7 @@ public static ConnectionClientTransport createInProcessTransport( eagAttrs, serverSchedulerPool, serverStreamTracerFactories, - serverListener); + serverListener, + includeCauseWithStatus); } } diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 531206b29ca..4ef743bf96d 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -331,8 +331,7 @@ protected void inboundHeadersReceived(Metadata headers) { if (compressedStream) { deframeFailed( Status.INTERNAL - .withDescription( - String.format("Full stream and gRPC message encoding cannot both be set")) + .withDescription("Full stream and gRPC message encoding cannot both be set") .asRuntimeException()); return; } diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index c1268cfdb09..574117d0461 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -148,6 +148,48 @@ public T handshakeTimeout(long timeout, TimeUnit unit) { return thisT(); } + @Override + public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + delegate().keepAliveTime(keepAliveTime, timeUnit); + return thisT(); + } + + @Override + public T keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { + delegate().keepAliveTimeout(keepAliveTimeout, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionIdle(long maxConnectionIdle, TimeUnit timeUnit) { + delegate().maxConnectionIdle(maxConnectionIdle, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { + delegate().maxConnectionAge(maxConnectionAge, timeUnit); + return thisT(); + } + + @Override + public T maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { + delegate().maxConnectionAgeGrace(maxConnectionAgeGrace, timeUnit); + return thisT(); + } + + @Override + public T permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + delegate().permitKeepAliveTime(keepAliveTime, timeUnit); + return thisT(); + } + + @Override + public T permitKeepAliveWithoutCalls(boolean permit) { + delegate().permitKeepAliveWithoutCalls(permit); + return thisT(); + } + @Override public T maxInboundMessageSize(int bytes) { delegate().maxInboundMessageSize(bytes); diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index 3513ec9346f..94cdfa4a572 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -50,7 +50,7 @@ protected interface Sink { * @param flush {@code true} if more data may not be arriving soon * @param numMessages the number of messages this frame represents */ - void writeFrame(@Nullable WritableBuffer frame, boolean flush, int numMessages); + void writeFrame(WritableBuffer frame, boolean flush, int numMessages); /** * Sends trailers to the remote end point. This call implies end of stream. @@ -108,7 +108,14 @@ public final void deliverFrame( WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) { // Since endOfStream is triggered by the sending of trailers, avoid flush here and just flush // after the trailers. - abstractServerStreamSink().writeFrame(frame, endOfStream ? false : flush, numMessages); + if (frame == null) { + assert endOfStream; + return; + } + if (endOfStream) { + flush = false; + } + abstractServerStreamSink().writeFrame(frame, flush, numMessages); } @Override @@ -217,8 +224,8 @@ public final void onStreamAllocated() { @Override public void deframerClosed(boolean hasPartialMessage) { deframerClosed = true; - if (endOfStream) { - if (!immediateCloseRequested && hasPartialMessage) { + if (endOfStream && !immediateCloseRequested) { + if (hasPartialMessage) { // We've received the entire stream and have data available but we don't have // enough to read the next frame ... this is bad. deframeFailed( diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java index 01c48b9efcf..b8c9cab7459 100644 --- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java @@ -20,11 +20,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import io.grpc.Attributes; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; -import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; @@ -67,10 +65,14 @@ private static final class NoopLoadBalancer extends LoadBalancer { @Override @Deprecated - public void handleResolvedAddressGroups(List s, Attributes a) {} + @SuppressWarnings("InlineMeSuggester") + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {} + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return true; + } @Override public void handleNameResolutionError(Status error) {} @@ -97,14 +99,10 @@ public final class AutoConfiguredLoadBalancer { } /** - * Returns non-OK status if resolvedAddresses is empty and delegate lb requires address ({@link - * LoadBalancer#canHandleEmptyAddressListFromNameResolution()} returns {@code false}). {@code - * AutoConfiguredLoadBalancer} doesn't expose {@code - * canHandleEmptyAddressListFromNameResolution} because it depends on the delegated LB. + * Returns non-OK status if the delegate rejects the resolvedAddresses (e.g. if it does not + * support an empty list). */ - Status tryHandleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - List servers = resolvedAddresses.getAddresses(); - Attributes attributes = resolvedAddresses.getAttributes(); + boolean tryAcceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PolicySelection policySelection = (PolicySelection) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -118,7 +116,7 @@ Status tryHandleResolvedAddresses(ResolvedAddresses resolvedAddresses) { delegate.shutdown(); delegateProvider = null; delegate = new NoopLoadBalancer(); - return Status.OK; + return true; } policySelection = new PolicySelection(defaultProvider, /* config= */ null); @@ -141,20 +139,12 @@ Status tryHandleResolvedAddresses(ResolvedAddresses resolvedAddresses) { ChannelLogLevel.DEBUG, "Load-balancing config: {0}", policySelection.config); } - LoadBalancer delegate = getDelegate(); - if (resolvedAddresses.getAddresses().isEmpty() - && !delegate.canHandleEmptyAddressListFromNameResolution()) { - return Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. addrs=" + servers + ", attrs=" + attributes); - } else { - delegate.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(resolvedAddresses.getAddresses()) - .setAttributes(attributes) - .setLoadBalancingPolicyConfig(lbConfig) - .build()); - return Status.OK; - } + return getDelegate().acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(resolvedAddresses.getAddresses()) + .setAttributes(resolvedAddresses.getAttributes()) + .setLoadBalancingPolicyConfig(lbConfig) + .build()); } void handleNameResolutionError(Status error) { diff --git a/core/src/main/java/io/grpc/internal/BackoffPolicy.java b/core/src/main/java/io/grpc/internal/BackoffPolicy.java index cdca4a22606..c80ef9e1f9d 100644 --- a/core/src/main/java/io/grpc/internal/BackoffPolicy.java +++ b/core/src/main/java/io/grpc/internal/BackoffPolicy.java @@ -20,7 +20,7 @@ * Determines how long to wait before doing some action (typically a retry, or a reconnect). */ public interface BackoffPolicy { - public interface Provider { + interface Provider { BackoffPolicy get(); } diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 6b6472825d2..1537d1c664f 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 The gRPC Authors + * Copyright 2016,2022 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; import io.grpc.CompositeCallCredentials; +import io.grpc.InternalMayRequireSpecificExecutor; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -126,6 +127,11 @@ public ClientStream newStream( return method; } + @Override + public CallOptions getCallOptions() { + return callOptions; + } + @Override public SecurityLevel getSecurityLevel() { return firstNonNull( @@ -144,8 +150,21 @@ public Attributes getTransportAttrs() { } }; try { - creds.applyRequestMetadata( - requestInfo, firstNonNull(callOptions.getExecutor(), appExecutor), applier); + // Hack to allow appengine to work when using AppEngineCredentials (b/244209681) + // since processing must happen on a specific thread. + // + // Ideally would always use appExecutor and we could eliminate the interface + // InternalMayRequireSpecificExecutor + Executor executor; + if (creds instanceof InternalMayRequireSpecificExecutor + && ((InternalMayRequireSpecificExecutor)creds).isSpecificExecutorRequired() + && callOptions.getExecutor() != null) { + executor = callOptions.getExecutor(); + } else { + executor = appExecutor; + } + + creds.applyRequestMetadata(requestInfo, executor, applier); } catch (Throwable t) { applier.fail(Status.UNAUTHENTICATED .withDescription("Credentials should use fail() instead of throwing exceptions") diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index db1a992b968..000ede77057 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -72,6 +72,7 @@ final class ClientCallImpl extends ClientCall { private static final Logger log = Logger.getLogger(ClientCallImpl.class.getName()); private static final byte[] FULL_STREAM_DECOMPRESSION_ENCODINGS = "gzip".getBytes(Charset.forName("US-ASCII")); + private static final double NANO_TO_SECS = 1.0 * TimeUnit.SECONDS.toNanos(1); private final MethodDescriptor method; private final Tag tag; @@ -259,10 +260,12 @@ public void runInContext() { } else { ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers(callOptions, headers, 0, false); - stream = new FailingClientStream( - DEADLINE_EXCEEDED.withDescription( - "ClientCall started after deadline exceeded: " + effectiveDeadline), - tracers); + String deadlineName = + isFirstMin(callOptions.getDeadline(), context.getDeadline()) ? "CallOptions" : "Context"; + String description = String.format( + "ClientCall started after %s deadline was exceeded .9%f seconds ago", deadlineName, + effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS) / NANO_TO_SECS); + stream = new FailingClientStream(DEADLINE_EXCEEDED.withDescription(description), tracers); } if (callExecutorIsDirect) { @@ -358,12 +361,13 @@ private static void logIfContextNarrowedTimeout( long effectiveTimeout = max(0, effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS)); StringBuilder builder = new StringBuilder(String.format( + Locale.US, "Call timeout set to '%d' ns, due to context deadline.", effectiveTimeout)); if (callDeadline == null) { builder.append(" Explicit call timeout was not set."); } else { long callTimeout = callDeadline.timeRemaining(TimeUnit.NANOSECONDS); - builder.append(String.format(" Explicit call timeout was '%d' ns.", callTimeout)); + builder.append(String.format(Locale.US, " Explicit call timeout was '%d' ns.", callTimeout)); } log.fine(builder.toString()); @@ -430,6 +434,16 @@ private static Deadline min(@Nullable Deadline deadline0, @Nullable Deadline dea return deadline0.minimum(deadline1); } + private static boolean isFirstMin(@Nullable Deadline deadline0, @Nullable Deadline deadline1) { + if (deadline0 == null) { + return false; + } + if (deadline1 == null) { + return true; + } + return deadline0.isBefore(deadline1); + } + @Override public void request(int numMessages) { PerfMark.startTask("ClientCall.request", tag); diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index fbb24633d72..ee600e52d68 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -81,6 +81,17 @@ protected DelayedClientCall( initialDeadlineMonitor = scheduleDeadlineIfNeeded(scheduler, deadline); } + // If one argument is null, consider the other the "Before" + private boolean isAbeforeB(@Nullable Deadline a, @Nullable Deadline b) { + if (b == null) { + return true; + } else if (a == null) { + return false; + } + + return a.isBefore(b); + } + @Nullable private ScheduledFuture scheduleDeadlineIfNeeded( ScheduledExecutorService scheduler, @Nullable Deadline deadline) { @@ -90,35 +101,45 @@ private ScheduledFuture scheduleDeadlineIfNeeded( } long remainingNanos = Long.MAX_VALUE; if (deadline != null) { - remainingNanos = Math.min(remainingNanos, deadline.timeRemaining(NANOSECONDS)); + remainingNanos = deadline.timeRemaining(NANOSECONDS); } + if (contextDeadline != null && contextDeadline.timeRemaining(NANOSECONDS) < remainingNanos) { remainingNanos = contextDeadline.timeRemaining(NANOSECONDS); if (logger.isLoggable(Level.FINE)) { StringBuilder builder = new StringBuilder( String.format( + Locale.US, "Call timeout set to '%d' ns, due to context deadline.", remainingNanos)); if (deadline == null) { builder.append(" Explicit call timeout was not set."); } else { long callTimeout = deadline.timeRemaining(TimeUnit.NANOSECONDS); - builder.append(String.format(" Explicit call timeout was '%d' ns.", callTimeout)); + builder.append(String.format( + Locale.US, " Explicit call timeout was '%d' ns.", callTimeout)); } logger.fine(builder.toString()); } } + long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); final StringBuilder buf = new StringBuilder(); + String deadlineName = isAbeforeB(contextDeadline, deadline) ? "Context" : "CallOptions"; if (remainingNanos < 0) { - buf.append("ClientCall started after deadline exceeded. Deadline exceeded after -"); + buf.append("ClientCall started after "); + buf.append(deadlineName); + buf.append(" deadline was exceeded. Deadline has been exceeded for "); } else { - buf.append("Deadline exceeded after "); + buf.append("Deadline "); + buf.append(deadlineName); + buf.append(" will be exceeded in "); } buf.append(seconds); buf.append(String.format(Locale.US, ".%09d", nanos)); buf.append("s. "); + /** Cancels the call if deadline exceeded prior to the real call being set. */ class DeadlineExceededRunnable implements Runnable { @Override @@ -141,15 +162,20 @@ public void run() { *

No-op if either this method or {@link #cancel} have already been called. */ // When this method returns, passThrough is guaranteed to be true - public final void setCall(ClientCall call) { + public final Runnable setCall(ClientCall call) { synchronized (this) { // If realCall != null, then either setCall() or cancel() has been called. if (realCall != null) { - return; + return null; } setRealCall(checkNotNull(call, "call")); } - drainPendingCalls(); + return new ContextRunnable(context) { + @Override + public void runInContext() { + drainPendingCalls(); + } + }; } @Override diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolver.java b/core/src/main/java/io/grpc/internal/DnsNameResolver.java index 5418a0bd32d..5ef6dd863c2 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolver.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolver.java @@ -411,6 +411,7 @@ final int getPort() { } /** + * Parse TXT service config records as JSON. * * @throws IOException if one of the txt records contains improperly formatted JSON. */ diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index 1c9290d2fc0..8078aa0d4c9 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -21,7 +21,11 @@ import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@link DnsNameResolver}. @@ -75,4 +79,9 @@ protected boolean isAvailable() { public int priority() { return 5; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index ab4bf7f7657..762ddd162b6 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -56,7 +56,10 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; import java.util.List; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -75,6 +78,17 @@ public final class GrpcUtil { private static final Logger log = Logger.getLogger(GrpcUtil.class.getName()); + private static final Set INAPPROPRIATE_CONTROL_PLANE_STATUS + = Collections.unmodifiableSet(EnumSet.of( + Status.Code.OK, + Status.Code.INVALID_ARGUMENT, + Status.Code.NOT_FOUND, + Status.Code.ALREADY_EXISTS, + Status.Code.FAILED_PRECONDITION, + Status.Code.ABORTED, + Status.Code.OUT_OF_RANGE, + Status.Code.DATA_LOSS)); + public static final Charset US_ASCII = Charset.forName("US-ASCII"); /** @@ -203,7 +217,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.45.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.53.0"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. @@ -636,7 +650,7 @@ public static String getHost(InetSocketAddress addr) { /** * Marshals a nanoseconds representation of the timeout to and from a string representation, * consisting of an ASCII decimal representation of a number with at most 8 digits, followed by a - * unit: + * unit. Available units: * n = nanoseconds * u = microseconds * m = milliseconds @@ -747,10 +761,12 @@ public ListenableFuture getStats() { } if (!result.getStatus().isOk()) { if (result.isDrop()) { - return new FailingClientTransport(result.getStatus(), RpcProgress.DROPPED); + return new FailingClientTransport( + replaceInappropriateControlPlaneStatus(result.getStatus()), RpcProgress.DROPPED); } if (!isWaitForReady) { - return new FailingClientTransport(result.getStatus(), RpcProgress.PROCESSED); + return new FailingClientTransport( + replaceInappropriateControlPlaneStatus(result.getStatus()), RpcProgress.PROCESSED); } } return null; @@ -799,6 +815,25 @@ public static void closeQuietly(@Nullable Closeable message) { } } + /** Reads {@code in} until end of stream. */ + public static void exhaust(InputStream in) throws IOException { + byte[] buf = new byte[256]; + while (in.read(buf) != -1) {} + } + + /** + * Some status codes from the control plane are not appropritate to use in the data plane. If one + * is given it will be replaced with INTERNAL, indicating a bug in the control plane + * implementation. + */ + public static Status replaceInappropriateControlPlaneStatus(Status status) { + checkArgument(status != null); + return INAPPROPRIATE_CONTROL_PLANE_STATUS.contains(status.getCode()) + ? Status.INTERNAL.withDescription( + "Inappropriate status code from control plane: " + status.getCode() + " " + + status.getDescription()).withCause(status.getCause()) : status; + } + /** * Checks whether the given item exists in the iterable. This is copied from Guava Collect's * {@code Iterables.contains()} because Guava Collect is not Android-friendly thus core can't diff --git a/core/src/main/java/io/grpc/internal/InternalServer.java b/core/src/main/java/io/grpc/internal/InternalServer.java index 0445ae3dfab..a6079081233 100644 --- a/core/src/main/java/io/grpc/internal/InternalServer.java +++ b/core/src/main/java/io/grpc/internal/InternalServer.java @@ -50,7 +50,7 @@ public interface InternalServer { void shutdown(); /** - * Returns the first listening socket address. May change after {@link start(ServerListener)} is + * Returns the first listening socket address. May change after {@link #start(ServerListener)} is * called. */ SocketAddress getListenSocketAddress(); @@ -61,7 +61,7 @@ public interface InternalServer { @Nullable InternalInstrumented getListenSocketStats(); /** - * Returns a list of listening socket addresses. May change after {@link start(ServerListener)} + * Returns a list of listening socket addresses. May change after {@link #start(ServerListener)} * is called. */ List getListenSocketAddresses(); diff --git a/core/src/main/java/io/grpc/internal/JsonUtil.java b/core/src/main/java/io/grpc/internal/JsonUtil.java index 117135fe634..65f7cf5649e 100644 --- a/core/src/main/java/io/grpc/internal/JsonUtil.java +++ b/core/src/main/java/io/grpc/internal/JsonUtil.java @@ -20,6 +20,7 @@ import java.text.ParseException; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -244,7 +245,8 @@ public static Boolean getBoolean(Map obj, String key) { for (int i = 0; i < rawList.size(); i++) { if (!(rawList.get(i) instanceof Map)) { throw new ClassCastException( - String.format("value %s for idx %d in %s is not object", rawList.get(i), i, rawList)); + String.format( + Locale.US, "value %s for idx %d in %s is not object", rawList.get(i), i, rawList)); } } return (List>) rawList; @@ -260,6 +262,7 @@ public static List checkStringList(List rawList) { if (!(rawList.get(i) instanceof String)) { throw new ClassCastException( String.format( + Locale.US, "value '%s' for idx %d in '%s' is not string", rawList.get(i), i, rawList)); } } diff --git a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java similarity index 94% rename from netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java rename to core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java index 6470e440327..dd539e75a18 100644 --- a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -22,11 +22,11 @@ import javax.annotation.CheckReturnValue; /** Monitors the client's PING usage to make sure the rate is permitted. */ -class KeepAliveEnforcer { +public final class KeepAliveEnforcer { @VisibleForTesting - static final int MAX_PING_STRIKES = 2; + public static final int MAX_PING_STRIKES = 2; @VisibleForTesting - static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2); + public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2); private final boolean permitWithoutCalls; private final long minTimeNanos; diff --git a/core/src/main/java/io/grpc/internal/LogExceptionRunnable.java b/core/src/main/java/io/grpc/internal/LogExceptionRunnable.java index 546e72911ca..f4c6f2fc8ef 100644 --- a/core/src/main/java/io/grpc/internal/LogExceptionRunnable.java +++ b/core/src/main/java/io/grpc/internal/LogExceptionRunnable.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.Throwables; - import java.util.logging.Level; import java.util.logging.Logger; diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 601c7740ca4..2606083db94 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -349,6 +349,9 @@ private class IdleModeTimer implements Runnable { @Override public void run() { + // Workaround timer scheduled while in idle mode. This can happen from handleNotInUse() after + // an explicit enterIdleMode() by the user. Protecting here as other locations are a bit too + // subtle to change rapidly to resolve the channel panic. See #8714 if (lbHelper == null) { return; } @@ -491,6 +494,8 @@ private void refreshNameResolution() { } private final class ChannelStreamProvider implements ClientStreamProvider { + volatile Throttle throttle; + private ClientTransport getTransport(PickSubchannelArgs args) { SubchannelPicker pickerCopy = subchannelPicker; if (shutdown.get()) { @@ -546,7 +551,6 @@ public ClientStream newStream( context.detach(origContext); } } else { - final Throttle throttle = lastServiceConfig.getRetryThrottling(); MethodInfo methodInfo = callOptions.getOption(MethodInfo.KEY); final RetryPolicy retryPolicy = methodInfo == null ? null : methodInfo.retryPolicy; final HedgingPolicy hedgingPolicy = methodInfo == null ? null : methodInfo.hedgingPolicy; @@ -599,7 +603,7 @@ ClientStream newSubstream( } } - private final ClientStreamProvider transportProvider = new ChannelStreamProvider(); + private final ChannelStreamProvider transportProvider = new ChannelStreamProvider(); private final Rescheduler idleTimer; @@ -618,10 +622,12 @@ ClientStream newSubstream( this.executor = checkNotNull(executorPool.getObject(), "executor"); this.originalChannelCreds = builder.channelCredentials; this.originalTransportFactory = clientTransportFactory; + this.offloadExecutorHolder = + new ExecutorHolder(checkNotNull(builder.offloadExecutorPool, "offloadExecutorPool")); this.transportFactory = new CallCredentialsApplyingTransportFactory( - clientTransportFactory, builder.callCredentials, this.executor); + clientTransportFactory, builder.callCredentials, this.offloadExecutorHolder); this.oobTransportFactory = new CallCredentialsApplyingTransportFactory( - clientTransportFactory, null, this.executor); + clientTransportFactory, null, this.offloadExecutorHolder); this.scheduledExecutor = new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService()); maxTraceEvents = builder.maxTraceEvents; @@ -633,9 +639,6 @@ ClientStream newSubstream( builder.proxyDetector != null ? builder.proxyDetector : GrpcUtil.DEFAULT_PROXY_DETECTOR; this.retryEnabled = builder.retryEnabled; this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); - this.offloadExecutorHolder = - new ExecutorHolder( - checkNotNull(builder.offloadExecutorPool, "offloadExecutorPool")); this.nameResolverRegistry = builder.nameResolverRegistry; ScParser serviceConfigParser = new ScParser( @@ -643,6 +646,7 @@ ClientStream newSubstream( builder.maxRetryAttempts, builder.maxHedgedAttempts, loadBalancerFactory); + this.authorityOverride = builder.authorityOverride; this.nameResolverArgs = NameResolver.Args.newBuilder() .setDefaultPort(builder.getDefaultPort()) @@ -651,16 +655,9 @@ ClientStream newSubstream( .setScheduledExecutorService(scheduledExecutor) .setServiceConfigParser(serviceConfigParser) .setChannelLogger(channelLogger) - .setOffloadExecutor( - // Avoid creating the offloadExecutor until it is first used - new Executor() { - @Override - public void execute(Runnable command) { - offloadExecutorHolder.getExecutor().execute(command); - } - }) + .setOffloadExecutor(this.offloadExecutorHolder) + .setOverrideAuthority(this.authorityOverride) .build(); - this.authorityOverride = builder.authorityOverride; this.nameResolverFactory = builder.nameResolverFactory; this.nameResolver = getNameResolver( target, authorityOverride, nameResolverFactory, nameResolverArgs); @@ -884,6 +881,7 @@ public String toString() { } updateSubchannelPicker(new PanicSubchannelPicker()); + realChannel.updateConfigSelector(null); channelLogger.log(ChannelLogLevel.ERROR, "PANIC! Entering TRANSIENT_FAILURE"); channelStateManager.gotoState(TRANSIENT_FAILURE); } @@ -1099,22 +1097,25 @@ private final class PendingCall extends DelayedClientCall realCall; - Context previous = context.attach(); - try { - realCall = newClientCall(method, callOptions); - } finally { - context.detach(previous); - } - setCall(realCall); - syncContext.execute(new PendingCallRemoval()); - } + ClientCall realCall; + Context previous = context.attach(); + try { + realCall = newClientCall(method, callOptions); + } finally { + context.detach(previous); + } + Runnable toRun = setCall(realCall); + if (toRun == null) { + syncContext.execute(new PendingCallRemoval()); + } else { + getCallExecutor(callOptions).execute(new Runnable() { + @Override + public void run() { + toRun.run(); + syncContext.execute(new PendingCallRemoval()); } - ); + }); + } } @Override @@ -1199,7 +1200,8 @@ public void start(Listener observer, Metadata headers) { InternalConfigSelector.Result result = configSelector.selectConfig(args); Status status = result.getStatus(); if (!status.isOk()) { - executeCloseObserverInContext(observer, status); + executeCloseObserverInContext(observer, + GrpcUtil.replaceInappropriateControlPlaneStatus(status)); delegate = (ClientCall) NOOP_CALL; return; } @@ -1453,15 +1455,13 @@ void remove(RetriableStream retriableStream) { private final class LbHelperImpl extends LoadBalancer.Helper { AutoConfiguredLoadBalancer lb; - boolean nsRefreshedByLb; - boolean ignoreRefreshNsCheck; @Override public AbstractSubchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); // No new subchannel should be created after load balancer has been shutdown. checkState(!terminating, "Channel is being terminated"); - return new SubchannelImpl(args, this); + return new SubchannelImpl(args); } @Override @@ -1493,7 +1493,6 @@ public void run() { @Override public void refreshNameResolution() { syncContext.throwIfNotInThisSynchronizationContext(); - nsRefreshedByLb = true; final class LoadBalancerRefreshNameResolution implements Runnable { @Override public void run() { @@ -1504,11 +1503,6 @@ public void run() { syncContext.execute(new LoadBalancerRefreshNameResolution()); } - @Override - public void ignoreRefreshNameResolutionCheck() { - ignoreRefreshNsCheck = true; - } - @Override public ManagedChannel createOobChannel(EquivalentAddressGroup addressGroup, String authority) { return createOobChannel(Collections.singletonList(addressGroup), authority); @@ -1749,6 +1743,9 @@ final class NamesResolved implements Runnable { @SuppressWarnings("ReferenceEquality") @Override public void run() { + if (ManagedChannelImpl.this.nameResolver != resolver) { + return; + } List servers = resolutionResult.getAddresses(); channelLogger.log( @@ -1815,6 +1812,10 @@ public void run() { channelLogger.log( ChannelLogLevel.INFO, "Fallback to error due to invalid first service config without default config"); + // This error could be an "inappropriate" control plane error that should not bleed + // through to client code using gRPC. We let them flow through here to the LB as + // we later check for these error codes when investigating pick results in + // GrpcUtil.getTransportFromPickResult(). onError(configOrError.getError()); return; } else { @@ -1830,6 +1831,7 @@ public void run() { "Service config changed{0}", effectiveServiceConfig == EMPTY_SERVICE_CONFIG ? " to empty" : ""); lastServiceConfig = effectiveServiceConfig; + transportProvider.throttle = effectiveServiceConfig.getRetryThrottling(); } try { @@ -1857,16 +1859,17 @@ public void run() { .set(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG, healthCheckingConfig) .build(); } + Attributes attributes = attrBuilder.build(); - Status handleResult = helper.lb.tryHandleResolvedAddresses( + boolean addressesAccepted = helper.lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) - .setAttributes(attrBuilder.build()) + .setAttributes(attributes) .setLoadBalancingPolicyConfig(effectiveServiceConfig.getLoadBalancingConfig()) .build()); - if (!handleResult.isOk()) { - handleErrorInSyncContext(handleResult.augmentDescription(resolver + " was used")); + if (!addressesAccepted) { + scheduleExponentialBackOffInSyncContext(); } } } @@ -1930,7 +1933,6 @@ private void scheduleExponentialBackOffInSyncContext() { private final class SubchannelImpl extends AbstractSubchannel { final CreateSubchannelArgs args; - final LbHelperImpl helper; final InternalLogId subchannelLogId; final ChannelLoggerImpl subchannelLogger; final ChannelTracer subchannelTracer; @@ -1940,15 +1942,15 @@ private final class SubchannelImpl extends AbstractSubchannel { boolean shutdown; ScheduledHandle delayedShutdownTask; - SubchannelImpl(CreateSubchannelArgs args, LbHelperImpl helper) { + SubchannelImpl(CreateSubchannelArgs args) { + checkNotNull(args, "args"); addressGroups = args.getAddresses(); if (authorityOverride != null) { List eagsWithoutOverrideAttr = stripOverrideAuthorityAttributes(args.getAddresses()); args = args.toBuilder().setAddresses(eagsWithoutOverrideAttr).build(); } - this.args = checkNotNull(args, "args"); - this.helper = checkNotNull(helper, "helper"); + this.args = args; subchannelLogId = InternalLogId.allocate("Subchannel", /*details=*/ authority()); subchannelTracer = new ChannelTracer( subchannelLogId, maxTraceEvents, timeProvider.currentTimeNanos(), @@ -1976,16 +1978,6 @@ void onTerminated(InternalSubchannel is) { void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) { checkState(listener != null, "listener is null"); listener.onSubchannelState(newState); - if (newState.getState() == TRANSIENT_FAILURE || newState.getState() == IDLE) { - if (!helper.ignoreRefreshNsCheck && !helper.nsRefreshedByLb) { - logger.log(Level.WARNING, - "LoadBalancer should call Helper.refreshNameResolution() to refresh name " - + "resolution if subchannel state becomes TRANSIENT_FAILURE or IDLE. " - + "This will no longer happen automatically in the future releases"); - refreshAndResetNameResolution(); - helper.nsRefreshedByLb = true; - } - } } @Override @@ -2209,8 +2201,10 @@ protected void handleNotInUse() { /** * Lazily request for Executor from an executor pool. + * Also act as an Executor directly to simply run a cmd */ - private static final class ExecutorHolder { + @VisibleForTesting + static final class ExecutorHolder implements Executor { private final ObjectPool pool; private Executor executor; @@ -2230,6 +2224,11 @@ synchronized void release() { executor = pool.returnObject(executor); } } + + @Override + public void execute(Runnable command) { + getExecutor().execute(command); + } } private static final class RestrictedScheduledExecutor implements ScheduledExecutorService { @@ -2288,7 +2287,7 @@ public T invokeAny(Collection> tasks) @Override public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) - throws InterruptedException, ExecutionException, TimeoutException { + throws InterruptedException, ExecutionException, TimeoutException { return delegate.invokeAny(tasks, timeout, unit); } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 243a555ad1a..536216b20a9 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -31,6 +31,7 @@ import io.grpc.DecompressorRegistry; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalChannelz; +import io.grpc.InternalGlobalInterceptors; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver; @@ -280,8 +281,8 @@ static String makeTargetStringForDirectAddress(SocketAddress address) { public ManagedChannelImplBuilder(SocketAddress directServerAddress, String authority, ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { - this(directServerAddress, authority, null, null, clientTransportFactoryBuilder, - channelBuilderDefaultPortProvider); + this(directServerAddress, authority, null, null, clientTransportFactoryBuilder, + channelBuilderDefaultPortProvider); } /** @@ -636,9 +637,15 @@ public ManagedChannel build() { // TODO(zdapeng): FIX IT @VisibleForTesting List getEffectiveInterceptors() { - List effectiveInterceptors = - new ArrayList<>(this.interceptors); - if (statsEnabled) { + List effectiveInterceptors = new ArrayList<>(this.interceptors); + boolean isGlobalInterceptorsSet = false; + List globalClientInterceptors = + InternalGlobalInterceptors.getClientInterceptors(); + if (globalClientInterceptors != null) { + effectiveInterceptors.addAll(globalClientInterceptors); + isGlobalInterceptorsSet = true; + } + if (!isGlobalInterceptorsSet && statsEnabled) { ClientInterceptor statsInterceptor = null; try { Class censusStatsAccessor = @@ -674,7 +681,7 @@ List getEffectiveInterceptors() { effectiveInterceptors.add(0, statsInterceptor); } } - if (tracingEnabled) { + if (!isGlobalInterceptorsSet && tracingEnabled) { ClientInterceptor tracingInterceptor = null; try { Class censusTracingAccessor = diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java index 542e84b9c8b..aed3a461fb4 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java @@ -29,6 +29,13 @@ import java.util.logging.LogRecord; import java.util.logging.Logger; +/** + * Best effort detecting channels that has not been properly cleaned up. + * Use {@link WeakReference} to avoid keeping the channel alive and retaining too much memory. + * Check lost references only on new channel creation and log message to indicate + * the previous channel (id and target) that has not been shutdown. This is done to avoid Object + * finalizers. + */ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel { private static final ReferenceQueue refqueue = new ReferenceQueue<>(); @@ -148,7 +155,7 @@ static int cleanQueue(ReferenceQueue refqueue) { Level level = Level.SEVERE; if (logger.isLoggable(level)) { String fmt = - "*~*~*~ Channel {0} was not shutdown properly!!! ~*~*~*" + "*~*~*~ Previous channel {0} was not shutdown properly!!! ~*~*~*" + System.getProperty("line.separator") + " Make sure to call shutdown()/shutdownNow() and wait " + "until awaitTermination() returns true."; diff --git a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java index 47cf53a7251..d38721af78d 100644 --- a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java @@ -94,6 +94,8 @@ interface Listener { /** * The transport is ready to accept traffic, because the connection is established. This is * called at most once. + * + *

Streams created before this milestone are not guaranteed to function. */ void transportReady(); diff --git a/netty/src/main/java/io/grpc/netty/MaxConnectionIdleManager.java b/core/src/main/java/io/grpc/internal/MaxConnectionIdleManager.java similarity index 77% rename from netty/src/main/java/io/grpc/netty/MaxConnectionIdleManager.java rename to core/src/main/java/io/grpc/internal/MaxConnectionIdleManager.java index 964ae44b178..4d4a36dda01 100644 --- a/netty/src/main/java/io/grpc/netty/MaxConnectionIdleManager.java +++ b/core/src/main/java/io/grpc/internal/MaxConnectionIdleManager.java @@ -14,11 +14,9 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; import com.google.common.annotations.VisibleForTesting; -import io.grpc.internal.LogExceptionRunnable; -import io.netty.channel.ChannelHandlerContext; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -27,7 +25,7 @@ /** * Monitors connection idle time; shutdowns the connection if the max connection idle is reached. */ -abstract class MaxConnectionIdleManager { +public final class MaxConnectionIdleManager { private static final Ticker systemTicker = new Ticker() { @Override public long nanoTime() { @@ -46,23 +44,23 @@ public long nanoTime() { private boolean shutdownDelayed; private boolean isActive; - MaxConnectionIdleManager(long maxConnectionIdleInNanos) { + public MaxConnectionIdleManager(long maxConnectionIdleInNanos) { this(maxConnectionIdleInNanos, systemTicker); } @VisibleForTesting - MaxConnectionIdleManager(long maxConnectionIdleInNanos, Ticker ticker) { + public MaxConnectionIdleManager(long maxConnectionIdleInNanos, Ticker ticker) { this.maxConnectionIdleInNanos = maxConnectionIdleInNanos; this.ticker = ticker; } - /** A {@link NettyServerHandler} was added to the transport. */ - void start(ChannelHandlerContext ctx) { - start(ctx, ctx.executor()); - } - - @VisibleForTesting - void start(final ChannelHandlerContext ctx, final ScheduledExecutorService scheduler) { + /** + * Start the initial scheduled shutdown given the transport status reaches max connection idle. + * + * @param closeJob Closes the connection by sending GO_AWAY with status code NO_ERROR and ASCII + * debug data max_idle and then doing the graceful connection termination. + */ + public void start(final Runnable closeJob, final ScheduledExecutorService scheduler) { this.scheduler = scheduler; nextIdleMonitorTime = ticker.nanoTime() + maxConnectionIdleInNanos; @@ -78,7 +76,7 @@ public void run() { } // if isActive, exit. Will schedule a new shutdownFuture once onTransportIdle } else { - close(ctx); + closeJob.run(); shutdownFuture = null; } } @@ -88,20 +86,15 @@ public void run() { scheduler.schedule(shutdownTask, maxConnectionIdleInNanos, TimeUnit.NANOSECONDS); } - /** - * Closes the connection by sending GO_AWAY with status code NO_ERROR and ASCII debug data - * max_idle and then doing the graceful connection termination. - */ - abstract void close(ChannelHandlerContext ctx); /** There are outstanding RPCs on the transport. */ - void onTransportActive() { + public void onTransportActive() { isActive = true; shutdownDelayed = true; } /** There are no outstanding RPCs on the transport. */ - void onTransportIdle() { + public void onTransportIdle() { isActive = false; if (shutdownFuture == null) { return; @@ -116,7 +109,7 @@ void onTransportIdle() { } /** Transport is being terminated. */ - void onTransportTermination() { + public void onTransportTermination() { if (shutdownFuture != null) { shutdownFuture.cancel(false); shutdownFuture = null; @@ -124,7 +117,7 @@ void onTransportTermination() { } @VisibleForTesting - interface Ticker { + public interface Ticker { long nanoTime(); } } diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index 534398315e8..f6da1d0a670 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -28,6 +28,7 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.Locale; import java.util.zip.DataFormatException; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; @@ -386,7 +387,7 @@ private void processHeader() { requiredLength = nextFrame.readInt(); if (requiredLength < 0 || requiredLength > maxInboundMessageSize) { throw Status.RESOURCE_EXHAUSTED.withDescription( - String.format("gRPC message exceeds maximum size %d: %d", + String.format(Locale.US, "gRPC message exceeds maximum size %d: %d", maxInboundMessageSize, requiredLength)) .asRuntimeException(); } @@ -516,9 +517,9 @@ private void reportCount() { private void verifySize() { if (count > maxMessageSize) { - throw Status.RESOURCE_EXHAUSTED.withDescription(String.format( - "Decompressed gRPC message exceeds maximum size %d", - maxMessageSize)).asRuntimeException(); + throw Status.RESOURCE_EXHAUSTED + .withDescription("Decompressed gRPC message exceeds maximum size " + maxMessageSize) + .asRuntimeException(); } } } diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 2042bddca03..93d35250a0f 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -34,6 +34,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import javax.annotation.Nullable; /** @@ -172,7 +173,8 @@ private int writeUncompressed(InputStream message, int messageLength) throws IOE if (maxOutboundMessageSize >= 0 && written > maxOutboundMessageSize) { throw Status.RESOURCE_EXHAUSTED .withDescription( - String.format("message too large %d > %d", written , maxOutboundMessageSize)) + String.format( + Locale.US, "message too large %d > %d", written , maxOutboundMessageSize)) .asRuntimeException(); } writeBufferChain(bufferChain, false); @@ -192,7 +194,8 @@ private int writeCompressed(InputStream message, int unusedMessageLength) throws if (maxOutboundMessageSize >= 0 && written > maxOutboundMessageSize) { throw Status.RESOURCE_EXHAUSTED .withDescription( - String.format("message too large %d > %d", written , maxOutboundMessageSize)) + String.format( + Locale.US, "message too large %d > %d", written , maxOutboundMessageSize)) .asRuntimeException(); } @@ -215,7 +218,8 @@ private int writeKnownLengthUncompressed(InputStream message, int messageLength) if (maxOutboundMessageSize >= 0 && messageLength > maxOutboundMessageSize) { throw Status.RESOURCE_EXHAUSTED .withDescription( - String.format("message too large %d > %d", messageLength , maxOutboundMessageSize)) + String.format( + Locale.US, "message too large %d > %d", messageLength , maxOutboundMessageSize)) .asRuntimeException(); } headerScratch.clear(); diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 6893713c1d2..12cab15053f 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -82,7 +82,8 @@ public void apply(Metadata headers) { public void fail(Status status) { checkArgument(!status.isOk(), "Cannot fail with OK status"); checkState(!finalized, "apply() or fail() already called"); - finalizeWith(new FailingClientStream(status, tracers)); + finalizeWith( + new FailingClientStream(GrpcUtil.replaceInappropriateControlPlaneStatus(status), tracers)); } private void finalizeWith(ClientStream stream) { diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index d5f74db54a7..e9c4d79150a 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -45,8 +45,15 @@ final class PickFirstLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { List servers = resolvedAddresses.getAddresses(); + if (servers.isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } + if (subchannel == null) { final Subchannel subchannel = helper.createSubchannel( CreateSubchannelArgs.newBuilder() @@ -67,6 +74,8 @@ public void onSubchannelState(ConnectivityStateInfo stateInfo) { } else { subchannel.updateAddresses(servers); } + + return true; } @Override diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java index 0b11af94e2d..7f7b366564e 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java @@ -18,6 +18,7 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import java.util.Map; diff --git a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java index 3e7dd010e22..c69ee25ef31 100644 --- a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java +++ b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java @@ -133,7 +133,7 @@ public PasswordAuthentication requestPasswordAuthentication( // let url be null log.log( Level.WARNING, - String.format("failed to create URL for Authenticator: %s %s", protocol, host)); + "failed to create URL for Authenticator: {0} {1}", new Object[] {protocol, host}); } // TODO(spencerfang): consider using java.security.AccessController here return Authenticator.requestPasswordAuthentication( @@ -150,6 +150,8 @@ public ProxySelector get() { }; /** + * Experimental environment variable name for enabling proxy support. + * * @deprecated Use the standard Java proxy configuration instead with flags such as: * -Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT */ diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index d0430b5edbd..ab85ec1bdb5 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -107,6 +107,8 @@ public void uncaughtException(Thread t, Throwable e) { */ private final AtomicBoolean noMoreTransparentRetry = new AtomicBoolean(); private final AtomicInteger localOnlyTransparentRetries = new AtomicInteger(); + private final AtomicInteger inFlightSubStreams = new AtomicInteger(); + private SavedCloseMasterListenerReason savedCloseMasterListenerReason; // Used for recording the share of buffer used for the current call out of the channel buffer. // This field would not be necessary if there is no channel buffer limit. @@ -220,7 +222,17 @@ private void commitAndRun(Substream winningSubstream) { } } + // returns null means we should not create new sub streams, e.g. cancelled or + // other close condition is met for retriableStream. + @Nullable private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { + int inFlight; + do { + inFlight = inFlightSubStreams.get(); + if (inFlight < 0) { + return null; + } + } while (!inFlightSubStreams.compareAndSet(inFlight, inFlight + 1)); Substream sub = new Substream(previousAttemptCount); // one tracer per substream final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub); @@ -367,6 +379,9 @@ public final void start(ClientStreamListener listener) { } Substream substream = createSubstream(0, false); + if (substream == null) { + return; + } if (isHedging) { FutureCanceller scheduledHedgingRef = null; @@ -434,16 +449,19 @@ private final class HedgingRunnable implements Runnable { @Override public void run() { + // It's safe to read state.hedgingAttemptCount here. + // If this run is not cancelled, the value of state.hedgingAttemptCount won't change + // until state.addActiveHedge() is called subsequently, even the state could possibly + // change. + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); + if (newSubstream == null) { + return; + } callExecutor.execute( new Runnable() { @SuppressWarnings("GuardedBy") @Override public void run() { - // It's safe to read state.hedgingAttemptCount here. - // If this run is not cancelled, the value of state.hedgingAttemptCount won't change - // until state.addActiveHedge() is called subsequently, even the state could possibly - // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); boolean cancelled = false; FutureCanceller future = null; @@ -490,15 +508,7 @@ public final void cancel(final Status reason) { if (runnable != null) { runnable.run(); - listenerSerializeExecutor.execute( - new Runnable() { - @Override - public void run() { - isClosed = true; - masterListener.closed(reason, RpcProgress.PROCESSED, new Metadata()); - - } - }); + safeCloseMasterListener(reason, RpcProgress.PROCESSED, new Metadata()); return; } @@ -550,6 +560,10 @@ class SendMessageEntry implements BufferEntry { @Override public void runWith(Substream substream) { substream.stream.writeMessage(method.streamRequest(message)); + // TODO(ejona): Workaround Netty memory leak. Message writes always need to be followed by + // flushes (or half close), but retry appears to have a code path that the flushes may + // not happen. The code needs to be fixed and this removed. See #9340. + substream.stream.flush(); } } @@ -799,6 +813,33 @@ private void freezeHedging() { } } + private void safeCloseMasterListener(Status status, RpcProgress progress, Metadata metadata) { + savedCloseMasterListenerReason = new SavedCloseMasterListenerReason(status, progress, + metadata); + if (inFlightSubStreams.addAndGet(Integer.MIN_VALUE) == Integer.MIN_VALUE) { + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, progress, metadata); + } + }); + } + } + + private static final class SavedCloseMasterListenerReason { + private final Status status; + private final RpcProgress progress; + private final Metadata metadata; + + SavedCloseMasterListenerReason(Status status, RpcProgress progress, Metadata metadata) { + this.status = status; + this.progress = progress; + this.metadata = metadata; + } + } + private interface BufferEntry { /** Replays the buffer entry with the given stream. */ void runWith(Substream substream); @@ -813,6 +854,10 @@ private final class Sublistener implements ClientStreamListener { @Override public void headersRead(final Metadata headers) { + if (substream.previousAttemptCount > 0) { + headers.discardAll(GRPC_PREVIOUS_RPC_ATTEMPTS); + headers.put(GRPC_PREVIOUS_RPC_ATTEMPTS, String.valueOf(substream.previousAttemptCount)); + } commitAndRun(substream); if (state.winningSubstream == substream) { if (throttle != null) { @@ -836,37 +881,38 @@ public void closed( closedSubstreamsInsight.append(status.getCode()); } + if (inFlightSubStreams.decrementAndGet() == Integer.MIN_VALUE) { + assert savedCloseMasterListenerReason != null; + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(savedCloseMasterListenerReason.status, + savedCloseMasterListenerReason.progress, + savedCloseMasterListenerReason.metadata); + } + }); + return; + } + // handle a race between buffer limit exceeded and closed, when setting // substream.bufferLimitExceeded = true happens before state.substreamClosed(substream). if (substream.bufferLimitExceeded) { commitAndRun(substream); if (state.winningSubstream == substream) { - listenerSerializeExecutor.execute( - new Runnable() { - @Override - public void run() { - isClosed = true; - masterListener.closed(status, rpcProgress, trailers); - } - }); + safeCloseMasterListener(status, rpcProgress, trailers); } return; } if (rpcProgress == RpcProgress.MISCARRIED - && localOnlyTransparentRetries.incrementAndGet() > 10_000) { + && localOnlyTransparentRetries.incrementAndGet() > 1_000) { commitAndRun(substream); if (state.winningSubstream == substream) { Status tooManyTransparentRetries = Status.INTERNAL .withDescription("Too many transparent retries. Might be a bug in gRPC") .withCause(status.asRuntimeException()); - listenerSerializeExecutor.execute( - new Runnable() { - @Override - public void run() { - isClosed = true; - masterListener.closed(tooManyTransparentRetries, rpcProgress, trailers); - } - }); + safeCloseMasterListener(tooManyTransparentRetries, rpcProgress, trailers); } return; } @@ -877,6 +923,9 @@ public void run() { && noMoreTransparentRetry.compareAndSet(false, true))) { // transparent retry final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); + if (newSubstream == null) { + return; + } if (isHedging) { boolean commit = false; synchronized (lock) { @@ -938,6 +987,11 @@ public void run() { } else { RetryPlan retryPlan = makeRetryDecision(status, trailers); if (retryPlan.shouldRetry) { + // retry + Substream newSubstream = createSubstream(substream.previousAttemptCount + 1, false); + if (newSubstream == null) { + return; + } // The check state.winningSubstream == null, checking if is not already committed, is // racy, but is still safe b/c the retry will also handle committed/cancellation FutureCanceller scheduledRetryCopy; @@ -951,10 +1005,6 @@ public void run() { new Runnable() { @Override public void run() { - // retry - Substream newSubstream = createSubstream( - substream.previousAttemptCount + 1, - false); drain(newSubstream); } }); @@ -974,14 +1024,7 @@ public void run() { commitAndRun(substream); if (state.winningSubstream == substream) { - listenerSerializeExecutor.execute( - new Runnable() { - @Override - public void run() { - isClosed = true; - masterListener.closed(status, rpcProgress, trailers); - } - }); + safeCloseMasterListener(status, rpcProgress, trailers); } } @@ -1056,6 +1099,7 @@ public void messagesAvailable(final MessageProducer producer) { checkState( savedState.winningSubstream != null, "Headers should be received prior to messages."); if (savedState.winningSubstream != substream) { + GrpcUtil.closeQuietly(producer); return; } listenerSerializeExecutor.execute( diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index b31aadd08a9..47ffdf9caaf 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.internal.GrpcAttributes.ATTR_SECURITY_LEVEL; import static io.grpc.internal.GrpcUtil.ACCEPT_ENCODING_SPLITTER; import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; @@ -34,8 +35,10 @@ import io.grpc.Context; import io.grpc.DecompressorRegistry; import io.grpc.InternalDecompressorRegistry; +import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; import io.grpc.ServerCall; import io.grpc.Status; import io.perfmark.PerfMark; @@ -167,7 +170,9 @@ private void sendMessageInternal(RespT message) { try { InputStream resp = method.streamResponse(message); stream.writeMessage(resp); - stream.flush(); + if (!getMethodDescriptor().getType().serverSendsOneMessage()) { + stream.flush(); + } } catch (RuntimeException e) { close(Status.fromThrowable(e), new Metadata()); } catch (Error e) { @@ -250,6 +255,16 @@ public MethodDescriptor getMethodDescriptor() { return method; } + @Override + public SecurityLevel getSecurityLevel() { + final Attributes attributes = getAttributes(); + if (attributes == null) { + return super.getSecurityLevel(); + } + final SecurityLevel securityLevel = attributes.get(ATTR_SECURITY_LEVEL); + return securityLevel == null ? super.getSecurityLevel() : securityLevel; + } + /** * Close the {@link ServerStream} because an internal error occurred. Allow the application to * run until completion, but silently ignore interactions with the {@link ServerStream} from now @@ -354,19 +369,22 @@ public void closed(Status status) { } private void closedInternal(Status status) { + Throwable cancelCause = null; try { if (status.isOk()) { listener.onComplete(); } else { call.cancelled = true; listener.onCancel(); + // The status will not have a cause in all failure scenarios but we want to make sure + // we always cancel the context with one to keep the context cancelled state consistent. + cancelCause = InternalStatus.asRuntimeException( + Status.CANCELLED.withDescription("RPC cancelled"), null, false); } } finally { // Cancel context after delivering RPC closure notification to allow the application to // clean up and update any state based on whether onComplete or onCancel was called. - // Note that in failure situations JumpToApplicationThreadServerStreamListener has already - // closed the context. In these situations this cancel() call will be a no-op. - context.cancel(null); + context.cancel(cancelCause); } } diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 6bfe2d38ab3..bbd52c14bba 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -45,6 +45,7 @@ import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalServerInterceptors; +import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallExecutorSupplier; @@ -894,9 +895,18 @@ private void closedInternal(final Status status) { // For cancellations, promptly inform any users of the context that their work should be // aborted. Otherwise, we can wait until pending work is done. if (!status.isOk()) { + // Since status was not OK we know that the call did not complete and got cancelled. To + // reflect this on the context we need to close it with a cause exception. Since not every + // failed status has an exception we will create one here if needed. + Throwable cancelCause = status.getCause(); + if (cancelCause == null) { + cancelCause = InternalStatus.asRuntimeException( + Status.CANCELLED.withDescription("RPC cancelled"), null, false); + } + // The callExecutor might be busy doing user work. To avoid waiting, use an executor that // is not serializing. - cancelExecutor.execute(new ContextCloser(context, status.getCause())); + cancelExecutor.execute(new ContextCloser(context, cancelCause)); } final Link link = PerfMark.linkOut(); diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java index 277e476143d..cd18457d51b 100644 --- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java @@ -30,6 +30,7 @@ import io.grpc.DecompressorRegistry; import io.grpc.HandlerRegistry; import io.grpc.InternalChannelz; +import io.grpc.InternalGlobalInterceptors; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCallExecutorSupplier; @@ -246,7 +247,17 @@ public Server build() { @VisibleForTesting List getTracerFactories() { ArrayList tracerFactories = new ArrayList<>(); - if (statsEnabled) { + boolean isGlobalInterceptorsTracersSet = false; + List globalServerInterceptors + = InternalGlobalInterceptors.getServerInterceptors(); + List globalServerStreamTracerFactories + = InternalGlobalInterceptors.getServerStreamTracerFactories(); + if (globalServerInterceptors != null) { + tracerFactories.addAll(globalServerStreamTracerFactories); + interceptors.addAll(globalServerInterceptors); + isGlobalInterceptorsTracersSet = true; + } + if (!isGlobalInterceptorsTracersSet && statsEnabled) { ServerStreamTracer.Factory censusStatsTracerFactory = null; try { Class censusStatsAccessor = @@ -278,7 +289,7 @@ List getTracerFactories() { tracerFactories.add(censusStatsTracerFactory); } } - if (tracingEnabled) { + if (!isGlobalInterceptorsTracersSet && tracingEnabled) { ServerStreamTracer.Factory tracingStreamTracerFactory = null; try { Class censusTracingAccessor = diff --git a/core/src/main/java/io/grpc/internal/ServiceConfigState.java b/core/src/main/java/io/grpc/internal/ServiceConfigState.java index 73486bb0bd9..d916c2d0e9f 100644 --- a/core/src/main/java/io/grpc/internal/ServiceConfigState.java +++ b/core/src/main/java/io/grpc/internal/ServiceConfigState.java @@ -35,6 +35,8 @@ final class ServiceConfigState { private boolean updated; /** + * Construct new instance. + * * @param defaultServiceConfig The initial service config, or {@code null} if absent. * @param lookUpServiceConfig {@code true} if service config updates might occur. */ diff --git a/core/src/main/java/io/grpc/internal/StreamListener.java b/core/src/main/java/io/grpc/internal/StreamListener.java index 090c0555b0a..a893a9c84b6 100644 --- a/core/src/main/java/io/grpc/internal/StreamListener.java +++ b/core/src/main/java/io/grpc/internal/StreamListener.java @@ -59,6 +59,6 @@ interface MessageProducer { * messages until the producer returns null, at which point the producer may be discarded. */ @Nullable - public InputStream next(); + InputStream next(); } } diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 9c9102b12cb..1530834d609 100644 --- a/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -246,7 +246,8 @@ private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, * Mainly used to avoid throwing IO Exceptions in java.io.Closeable. */ public interface Closeable extends java.io.Closeable { - @Override public void close(); + @Override + void close(); } } diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java index 51bf57aeb34..7465e632104 100644 --- a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -295,7 +295,8 @@ private long readAndUpdate(File trustCertFile, long oldTime) // Mainly used to avoid throwing IO Exceptions in java.io.Closeable. public interface Closeable extends java.io.Closeable { - @Override public void close(); + @Override + void close(); } public static Builder newBuilder() { diff --git a/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java b/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java index 628eda3b71b..cefcbf344ea 100644 --- a/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/ForwardingLoadBalancer.java @@ -17,14 +17,10 @@ package io.grpc.util; import com.google.common.base.MoreObjects; -import io.grpc.Attributes; import io.grpc.ConnectivityStateInfo; -import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; -import io.grpc.NameResolver; import io.grpc.Status; -import java.util.List; @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public abstract class ForwardingLoadBalancer extends LoadBalancer { @@ -33,14 +29,6 @@ public abstract class ForwardingLoadBalancer extends LoadBalancer { */ protected abstract LoadBalancer delegate(); - @Override - @Deprecated - public void handleResolvedAddressGroups( - List servers, - @NameResolver.ResolutionResultAttr Attributes attributes) { - delegate().handleResolvedAddressGroups(servers, attributes); - } - @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { delegate().handleResolvedAddresses(resolvedAddresses); diff --git a/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java b/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java index 05db207806d..c684051f0b2 100644 --- a/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java +++ b/core/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java @@ -22,10 +22,10 @@ import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver; @@ -95,6 +95,7 @@ public void refreshNameResolution() { } @Override + @Deprecated public void ignoreRefreshNameResolutionCheck() { delegate().ignoreRefreshNameResolutionCheck(); } diff --git a/core/src/main/java/io/grpc/util/MutableHandlerRegistry.java b/core/src/main/java/io/grpc/util/MutableHandlerRegistry.java index 1e923418923..c31102f4213 100644 --- a/core/src/main/java/io/grpc/util/MutableHandlerRegistry.java +++ b/core/src/main/java/io/grpc/util/MutableHandlerRegistry.java @@ -65,7 +65,7 @@ public ServerServiceDefinition addService(BindableService bindableService) { } /** - * Removes a registered service + * Removes a registered service. * * @return true if the service was found to be removed. */ diff --git a/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java new file mode 100644 index 00000000000..e7d93d9d52f --- /dev/null +++ b/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -0,0 +1,1082 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ForwardingMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.TimeProvider; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.Nullable; + +/** + * Wraps a child {@code LoadBalancer} while monitoring for outlier backends and removing them from + * the use of the child LB. + * + *

This implements the outlier detection gRFC: + * https://github.com/grpc/proposal/blob/master/A50-xds-outlier-detection.md + */ +@Internal +public final class OutlierDetectionLoadBalancer extends LoadBalancer { + + @VisibleForTesting + final AddressTrackerMap trackerMap; + + private final SynchronizationContext syncContext; + private final Helper childHelper; + private final GracefulSwitchLoadBalancer switchLb; + private TimeProvider timeProvider; + private final ScheduledExecutorService timeService; + private ScheduledHandle detectionTimerHandle; + private Long detectionTimerStartNanos; + + private static final Attributes.Key ADDRESS_TRACKER_ATTR_KEY + = Attributes.Key.create("addressTrackerKey"); + + /** + * Creates a new instance of {@link OutlierDetectionLoadBalancer}. + */ + public OutlierDetectionLoadBalancer(Helper helper, TimeProvider timeProvider) { + childHelper = new ChildHelper(checkNotNull(helper, "helper")); + switchLb = new GracefulSwitchLoadBalancer(childHelper); + trackerMap = new AddressTrackerMap(); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); + this.timeProvider = timeProvider; + } + + @Override + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + OutlierDetectionLoadBalancerConfig config + = (OutlierDetectionLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + + // The map should only retain entries for addresses in this latest update. + ArrayList addresses = new ArrayList<>(); + for (EquivalentAddressGroup addressGroup : resolvedAddresses.getAddresses()) { + addresses.addAll(addressGroup.getAddresses()); + } + trackerMap.keySet().retainAll(addresses); + + trackerMap.updateTrackerConfigs(config); + + // Add any new ones. + trackerMap.putNewTrackers(config, addresses); + + switchLb.switchTo(config.childPolicy.getProvider()); + + // If outlier detection is actually configured, start a timer that will periodically try to + // detect outliers. + if (config.outlierDetectionEnabled()) { + Long initialDelayNanos; + + if (detectionTimerStartNanos == null) { + // On the first go we use the configured interval. + initialDelayNanos = config.intervalNanos; + } else { + // If a timer has started earlier we cancel it and use the difference between the start + // time and now as the interval. + initialDelayNanos = Math.max(0L, + config.intervalNanos - (timeProvider.currentTimeNanos() - detectionTimerStartNanos)); + } + + // If a timer has been previously created we need to cancel it and reset all the call counters + // for a fresh start. + if (detectionTimerHandle != null) { + detectionTimerHandle.cancel(); + trackerMap.resetCallCounters(); + } + + detectionTimerHandle = syncContext.scheduleWithFixedDelay(new DetectionTimer(config), + initialDelayNanos, config.intervalNanos, NANOSECONDS, timeService); + } else if (detectionTimerHandle != null) { + // Outlier detection is not configured, but we have a lingering timer. Let's cancel it and + // uneject any addresses we may have ejected. + detectionTimerHandle.cancel(); + detectionTimerStartNanos = null; + trackerMap.cancelTracking(); + } + + switchLb.handleResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) + .build()); + return true; + } + + @Override + public void handleNameResolutionError(Status error) { + switchLb.handleNameResolutionError(error); + } + + @Override + public void shutdown() { + switchLb.shutdown(); + } + + /** + * This timer will be invoked periodically, according to configuration, and it will look for any + * outlier subchannels. + */ + class DetectionTimer implements Runnable { + + OutlierDetectionLoadBalancerConfig config; + + DetectionTimer(OutlierDetectionLoadBalancerConfig config) { + this.config = config; + } + + @Override + public void run() { + detectionTimerStartNanos = timeProvider.currentTimeNanos(); + + trackerMap.swapCounters(); + + for (OutlierEjectionAlgorithm algo : OutlierEjectionAlgorithm.forConfig(config)) { + algo.ejectOutliers(trackerMap, detectionTimerStartNanos); + } + + trackerMap.maybeUnejectOutliers(detectionTimerStartNanos); + } + } + + /** + * This child helper wraps the provided helper so that it can hand out wrapped {@link + * OutlierDetectionSubchannel}s and manage the address info map. + */ + class ChildHelper extends ForwardingLoadBalancerHelper { + + private Helper delegate; + + ChildHelper(Helper delegate) { + this.delegate = delegate; + } + + @Override + protected Helper delegate() { + return delegate; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + // Subchannels are wrapped so that we can monitor call results and to trigger failures when + // we decide to eject the subchannel. + OutlierDetectionSubchannel subchannel = new OutlierDetectionSubchannel( + delegate.createSubchannel(args)); + + // If the subchannel is associated with a single address that is also already in the map + // the subchannel will be added to the map and be included in outlier detection. + List addressGroups = args.getAddresses(); + if (hasSingleAddress(addressGroups) + && trackerMap.containsKey(addressGroups.get(0).getAddresses().get(0))) { + AddressTracker tracker = trackerMap.get(addressGroups.get(0).getAddresses().get(0)); + tracker.addSubchannel(subchannel); + + // If this address has already been ejected, we need to immediately eject this Subchannel. + if (tracker.ejectionTimeNanos != null) { + subchannel.eject(); + } + } + + return subchannel; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new OutlierDetectionPicker(newPicker)); + } + } + + class OutlierDetectionSubchannel extends ForwardingSubchannel { + + private final Subchannel delegate; + private AddressTracker addressTracker; + private boolean ejected; + private ConnectivityStateInfo lastSubchannelState; + private SubchannelStateListener subchannelStateListener; + + OutlierDetectionSubchannel(Subchannel delegate) { + this.delegate = delegate; + } + + @Override + public void start(SubchannelStateListener listener) { + subchannelStateListener = listener; + super.start(new OutlierDetectionSubchannelStateListener(listener)); + } + + @Override + public Attributes getAttributes() { + if (addressTracker != null) { + return delegate.getAttributes().toBuilder().set(ADDRESS_TRACKER_ATTR_KEY, addressTracker) + .build(); + } else { + return delegate.getAttributes(); + } + } + + @Override + public void updateAddresses(List addressGroups) { + // Outlier detection only supports subchannels with a single address, but the list of + // addressGroups associated with a subchannel can change at any time, so we need to react to + // changes in the address list plurality. + + // No change in address plurality, we replace the single one with a new one. + if (hasSingleAddress(getAllAddresses()) && hasSingleAddress(addressGroups)) { + // Remove the current subchannel from the old address it is associated with in the map. + if (trackerMap.containsValue(addressTracker)) { + addressTracker.removeSubchannel(this); + } + + // If the map has an entry for the new address, we associate this subchannel with it. + SocketAddress address = addressGroups.get(0).getAddresses().get(0); + if (trackerMap.containsKey(address)) { + trackerMap.get(address).addSubchannel(this); + } + } else if (hasSingleAddress(getAllAddresses()) && !hasSingleAddress(addressGroups)) { + // We go from a single address to having multiple, making this subchannel uneligible for + // outlier detection. Remove it from all trackers and reset the call counters of all the + // associated trackers. + // Remove the current subchannel from the old address it is associated with in the map. + if (trackerMap.containsKey(getAddresses().getAddresses().get(0))) { + AddressTracker tracker = trackerMap.get(getAddresses().getAddresses().get(0)); + tracker.removeSubchannel(this); + tracker.resetCallCounters(); + } + } else if (!hasSingleAddress(getAllAddresses()) && hasSingleAddress(addressGroups)) { + // We go from, previously uneligble, multiple address mode to a single address. If the map + // has an entry for the new address, we associate this subchannel with it. + SocketAddress address = addressGroups.get(0).getAddresses().get(0); + if (trackerMap.containsKey(address)) { + AddressTracker tracker = trackerMap.get(address); + tracker.addSubchannel(this); + } + } + + // We could also have multiple addressGroups and get an update for multiple new ones. This is + // a no-op as we will just continue to ignore multiple address subchannels. + + delegate.updateAddresses(addressGroups); + } + + /** + * If the {@link Subchannel} is considered for outlier detection the associated {@link + * AddressTracker} should be set. + */ + void setAddressTracker(AddressTracker addressTracker) { + this.addressTracker = addressTracker; + } + + void clearAddressTracker() { + this.addressTracker = null; + } + + void eject() { + ejected = true; + subchannelStateListener.onSubchannelState( + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + } + + void uneject() { + ejected = false; + if (lastSubchannelState != null) { + subchannelStateListener.onSubchannelState(lastSubchannelState); + } + } + + boolean isEjected() { + return ejected; + } + + @Override + protected Subchannel delegate() { + return delegate; + } + + /** + * Wraps the actual listener so that state changes from the actual one can be intercepted. + */ + class OutlierDetectionSubchannelStateListener implements SubchannelStateListener { + + private final SubchannelStateListener delegate; + + OutlierDetectionSubchannelStateListener(SubchannelStateListener delegate) { + this.delegate = delegate; + } + + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + lastSubchannelState = newState; + if (!ejected) { + delegate.onSubchannelState(newState); + } + } + } + } + + + /** + * This picker delegates the actual picking logic to a wrapped delegate, but associates a {@link + * ClientStreamTracer} with each pick to track the results of each subchannel stream. + */ + class OutlierDetectionPicker extends SubchannelPicker { + + private final SubchannelPicker delegate; + + OutlierDetectionPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult pickResult = delegate.pickSubchannel(args); + + Subchannel subchannel = pickResult.getSubchannel(); + if (subchannel != null) { + return PickResult.withSubchannel(subchannel, + new ResultCountingClientStreamTracerFactory( + subchannel.getAttributes().get(ADDRESS_TRACKER_ATTR_KEY))); + } + + return pickResult; + } + + /** + * Builds instances of {@link ResultCountingClientStreamTracer}. + */ + class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { + + private final AddressTracker tracker; + + ResultCountingClientStreamTracerFactory(AddressTracker tracker) { + this.tracker = tracker; + } + + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new ResultCountingClientStreamTracer(tracker); + } + } + + /** + * Counts the results (successful/unsuccessful) of a particular {@link + * OutlierDetectionSubchannel}s streams and increments the counter in the associated {@link + * AddressTracker}. + */ + class ResultCountingClientStreamTracer extends ClientStreamTracer { + + AddressTracker tracker; + + public ResultCountingClientStreamTracer(AddressTracker tracker) { + this.tracker = tracker; + } + + @Override + public void streamClosed(Status status) { + tracker.incrementCallCount(status.isOk()); + } + } + } + + /** + * Tracks additional information about a set of equivalent addresses needed for outlier + * detection. + */ + static class AddressTracker { + + private OutlierDetectionLoadBalancerConfig config; + // Marked as volatile to assure that when the inactive counter is swapped in as the new active + // one, all threads see the change and don't hold on to a reference to the now inactive counter. + private volatile CallCounter activeCallCounter = new CallCounter(); + private CallCounter inactiveCallCounter = new CallCounter(); + private Long ejectionTimeNanos; + private int ejectionTimeMultiplier; + private final Set subchannels = new HashSet<>(); + + AddressTracker(OutlierDetectionLoadBalancerConfig config) { + this.config = config; + } + + void setConfig(OutlierDetectionLoadBalancerConfig config) { + this.config = config; + } + + /** + * Adds a subchannel to the tracker, while assuring that the subchannel ejection status is + * updated to match the tracker's if needed. + */ + boolean addSubchannel(OutlierDetectionSubchannel subchannel) { + // Make sure that the subchannel is in the same ejection state as the new tracker it is + // associated with. + if (subchannelsEjected() && !subchannel.isEjected()) { + subchannel.eject(); + } else if (!subchannelsEjected() && subchannel.isEjected()) { + subchannel.uneject(); + } + subchannel.setAddressTracker(this); + return subchannels.add(subchannel); + } + + boolean removeSubchannel(OutlierDetectionSubchannel subchannel) { + subchannel.clearAddressTracker(); + return subchannels.remove(subchannel); + } + + boolean containsSubchannel(OutlierDetectionSubchannel subchannel) { + return subchannels.contains(subchannel); + } + + @VisibleForTesting + Set getSubchannels() { + return ImmutableSet.copyOf(subchannels); + } + + void incrementCallCount(boolean success) { + // If neither algorithm is configured, no point in incrementing counters. + if (config.successRateEjection == null && config.failurePercentageEjection == null) { + return; + } + + if (success) { + activeCallCounter.successCount.getAndIncrement(); + } else { + activeCallCounter.failureCount.getAndIncrement(); + } + } + + @VisibleForTesting + long activeVolume() { + return activeCallCounter.successCount.get() + activeCallCounter.failureCount.get(); + } + + long inactiveVolume() { + return inactiveCallCounter.successCount.get() + inactiveCallCounter.failureCount.get(); + } + + double successRate() { + return ((double) inactiveCallCounter.successCount.get()) / inactiveVolume(); + } + + double failureRate() { + return ((double)inactiveCallCounter.failureCount.get()) / inactiveVolume(); + } + + void resetCallCounters() { + activeCallCounter.reset(); + inactiveCallCounter.reset(); + } + + void decrementEjectionTimeMultiplier() { + // The multiplier should not go negative. + ejectionTimeMultiplier = ejectionTimeMultiplier == 0 ? 0 : ejectionTimeMultiplier - 1; + } + + void resetEjectionTimeMultiplier() { + ejectionTimeMultiplier = 0; + } + + /** + * Swaps the active and inactive counters. + * + *

Note that this method is not thread safe as the swap is not done atomically. This is + * expected to only be called from the timer that is scheduled at a fixed delay, assuring that + * only one timer is active at a time. + */ + void swapCounters() { + inactiveCallCounter.reset(); + CallCounter tempCounter = activeCallCounter; + activeCallCounter = inactiveCallCounter; + inactiveCallCounter = tempCounter; + } + + void ejectSubchannels(long ejectionTimeNanos) { + this.ejectionTimeNanos = ejectionTimeNanos; + ejectionTimeMultiplier++; + for (OutlierDetectionSubchannel subchannel : subchannels) { + subchannel.eject(); + } + } + + /** + * Uneject a currently ejected address. + */ + void unejectSubchannels() { + checkState(ejectionTimeNanos != null, "not currently ejected"); + ejectionTimeNanos = null; + for (OutlierDetectionSubchannel subchannel : subchannels) { + subchannel.uneject(); + } + } + + boolean subchannelsEjected() { + return ejectionTimeNanos != null; + } + + public boolean maxEjectionTimeElapsed(long currentTimeNanos) { + // The instant in time beyond which the address should no longer be ejected. Also making sure + // we honor any maximum ejection time setting. + long maxEjectionDurationSecs + = Math.max(config.baseEjectionTimeNanos, config.maxEjectionTimeNanos); + long maxEjectionTimeNanos = + ejectionTimeNanos + Math.min( + config.baseEjectionTimeNanos * ejectionTimeMultiplier, + maxEjectionDurationSecs); + + return currentTimeNanos > maxEjectionTimeNanos; + } + + /** Tracks both successful and failed call counts. */ + private static class CallCounter { + AtomicLong successCount = new AtomicLong(); + AtomicLong failureCount = new AtomicLong(); + + void reset() { + successCount.set(0); + failureCount.set(0); + } + } + } + + /** + * Maintains a mapping from addresses to their trackers. + */ + static class AddressTrackerMap extends ForwardingMap { + private final Map trackerMap; + + AddressTrackerMap() { + trackerMap = new HashMap<>(); + } + + @Override + protected Map delegate() { + return trackerMap; + } + + void updateTrackerConfigs(OutlierDetectionLoadBalancerConfig config) { + for (AddressTracker tracker: trackerMap.values()) { + tracker.setConfig(config); + } + } + + /** Adds a new tracker for every given address. */ + void putNewTrackers(OutlierDetectionLoadBalancerConfig config, + Collection addresses) { + for (SocketAddress address : addresses) { + if (!trackerMap.containsKey(address)) { + trackerMap.put(address, new AddressTracker(config)); + } + } + } + + /** Resets the call counters for all the trackers in the map. */ + void resetCallCounters() { + for (AddressTracker tracker : trackerMap.values()) { + tracker.resetCallCounters(); + } + } + + /** + * When OD gets disabled we need to uneject any subchannels that may have been ejected and + * to reset the ejection time multiplier. + */ + void cancelTracking() { + for (AddressTracker tracker : trackerMap.values()) { + if (tracker.subchannelsEjected()) { + tracker.unejectSubchannels(); + } + tracker.resetEjectionTimeMultiplier(); + } + } + + /** Swaps the active and inactive counters for each tracker. */ + void swapCounters() { + for (AddressTracker tracker : trackerMap.values()) { + tracker.swapCounters(); + } + } + + /** + * At the end of a timer run we need to decrement the ejection time multiplier for trackers + * that don't have ejected subchannels and uneject ones that have spent the maximum ejection + * time allowed. + */ + void maybeUnejectOutliers(Long detectionTimerStartNanos) { + for (AddressTracker tracker : trackerMap.values()) { + if (!tracker.subchannelsEjected()) { + tracker.decrementEjectionTimeMultiplier(); + } + + if (tracker.subchannelsEjected() && tracker.maxEjectionTimeElapsed( + detectionTimerStartNanos)) { + tracker.unejectSubchannels(); + } + } + } + + /** + * How many percent of the addresses have been ejected. + */ + double ejectionPercentage() { + if (trackerMap.isEmpty()) { + return 0; + } + int totalAddresses = 0; + int ejectedAddresses = 0; + for (AddressTracker tracker : trackerMap.values()) { + totalAddresses++; + if (tracker.subchannelsEjected()) { + ejectedAddresses++; + } + } + return ((double)ejectedAddresses / totalAddresses) * 100; + } + } + + + /** + * Implementations provide different ways of ejecting outlier addresses.. + */ + interface OutlierEjectionAlgorithm { + + /** Eject any outlier addresses. */ + void ejectOutliers(AddressTrackerMap trackerMap, long ejectionTimeNanos); + + /** Builds a list of algorithms that are enabled in the given config. */ + @Nullable + static List forConfig(OutlierDetectionLoadBalancerConfig config) { + ImmutableList.Builder algoListBuilder = ImmutableList.builder(); + if (config.successRateEjection != null) { + algoListBuilder.add(new SuccessRateOutlierEjectionAlgorithm(config)); + } + if (config.failurePercentageEjection != null) { + algoListBuilder.add(new FailurePercentageOutlierEjectionAlgorithm(config)); + } + return algoListBuilder.build(); + } + } + + /** + * This algorithm ejects addresses that don't maintain a required rate of successful calls. The + * required rate is not fixed, but is based on the mean and standard deviation of the success + * rates of all of the addresses. + */ + static class SuccessRateOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + + private final OutlierDetectionLoadBalancerConfig config; + + SuccessRateOutlierEjectionAlgorithm(OutlierDetectionLoadBalancerConfig config) { + checkArgument(config.successRateEjection != null, "success rate ejection config is null"); + this.config = config; + } + + @Override + public void ejectOutliers(AddressTrackerMap trackerMap, long ejectionTimeNanos) { + + // Only consider addresses that have the minimum request volume specified in the config. + List trackersWithVolume = trackersWithVolume(trackerMap, + config.successRateEjection.requestVolume); + // If we don't have enough addresses with significant volume then there's nothing to do. + if (trackersWithVolume.size() < config.successRateEjection.minimumHosts + || trackersWithVolume.size() == 0) { + return; + } + + // Calculate mean and standard deviation of the fractions of successful calls. + List successRates = new ArrayList<>(); + for (AddressTracker tracker : trackersWithVolume) { + successRates.add(tracker.successRate()); + } + double mean = mean(successRates); + double stdev = standardDeviation(successRates, mean); + + double requiredSuccessRate = + mean - stdev * (config.successRateEjection.stdevFactor / 1000f); + + for (AddressTracker tracker : trackersWithVolume) { + // If we are above or equal to the max ejection percentage, don't eject any more. This will + // allow the total ejections to go one above the max, but at the same time it assures at + // least one ejection, which the spec calls for. This behavior matches what Envoy proxy + // does. + if (trackerMap.ejectionPercentage() >= config.maxEjectionPercent) { + return; + } + + // If success rate is below the threshold, eject the address. + if (tracker.successRate() < requiredSuccessRate) { + // Only eject some addresses based on the enforcement percentage. + if (new Random().nextInt(100) < config.successRateEjection.enforcementPercentage) { + tracker.ejectSubchannels(ejectionTimeNanos); + } + } + } + } + + /** Calculates the mean of the given values. */ + @VisibleForTesting + static double mean(Collection values) { + double totalValue = 0; + for (double value : values) { + totalValue += value; + } + + return totalValue / values.size(); + } + + /** Calculates the standard deviation for the given values and their mean. */ + @VisibleForTesting + static double standardDeviation(Collection values, double mean) { + double squaredDifferenceSum = 0; + for (double value : values) { + double difference = value - mean; + squaredDifferenceSum += difference * difference; + } + double variance = squaredDifferenceSum / values.size(); + + return Math.sqrt(variance); + } + } + + static class FailurePercentageOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + + private final OutlierDetectionLoadBalancerConfig config; + + FailurePercentageOutlierEjectionAlgorithm(OutlierDetectionLoadBalancerConfig config) { + this.config = config; + } + + @Override + public void ejectOutliers(AddressTrackerMap trackerMap, long ejectionTimeNanos) { + + // Only consider addresses that have the minimum request volume specified in the config. + List trackersWithVolume = trackersWithVolume(trackerMap, + config.failurePercentageEjection.requestVolume); + // If we don't have enough addresses with significant volume then there's nothing to do. + if (trackersWithVolume.size() < config.failurePercentageEjection.minimumHosts + || trackersWithVolume.size() == 0) { + return; + } + + // If this address does not have enough volume to be considered, skip to the next one. + for (AddressTracker tracker : trackersWithVolume) { + // If we are above or equal to the max ejection percentage, don't eject any more. This will + // allow the total ejections to go one above the max, but at the same time it assures at + // least one ejection, which the spec calls for. This behavior matches what Envoy proxy + // does. + if (trackerMap.ejectionPercentage() >= config.maxEjectionPercent) { + return; + } + + if (tracker.inactiveVolume() < config.failurePercentageEjection.requestVolume) { + continue; + } + + // If the failure rate is above the threshold, we should eject... + double maxFailureRate = ((double)config.failurePercentageEjection.threshold) / 100; + if (tracker.failureRate() > maxFailureRate) { + // ...but only enforce this based on the enforcement percentage. + if (new Random().nextInt(100) < config.failurePercentageEjection.enforcementPercentage) { + tracker.ejectSubchannels(ejectionTimeNanos); + } + } + } + } + } + + /** Returns only the trackers that have the minimum configured volume to be considered. */ + private static List trackersWithVolume(AddressTrackerMap trackerMap, + int volume) { + List trackersWithVolume = new ArrayList<>(); + for (AddressTracker tracker : trackerMap.values()) { + if (tracker.inactiveVolume() >= volume) { + trackersWithVolume.add(tracker); + } + } + return trackersWithVolume; + } + + /** Counts how many addresses are in a given address group. */ + private static boolean hasSingleAddress(List addressGroups) { + int addressCount = 0; + for (EquivalentAddressGroup addressGroup : addressGroups) { + addressCount += addressGroup.getAddresses().size(); + if (addressCount > 1) { + return false; + } + } + return true; + } + + /** + * The configuration for {@link OutlierDetectionLoadBalancer}. + */ + public static final class OutlierDetectionLoadBalancerConfig { + + public final Long intervalNanos; + public final Long baseEjectionTimeNanos; + public final Long maxEjectionTimeNanos; + public final Integer maxEjectionPercent; + public final SuccessRateEjection successRateEjection; + public final FailurePercentageEjection failurePercentageEjection; + public final PolicySelection childPolicy; + + private OutlierDetectionLoadBalancerConfig(Long intervalNanos, + Long baseEjectionTimeNanos, + Long maxEjectionTimeNanos, + Integer maxEjectionPercent, + SuccessRateEjection successRateEjection, + FailurePercentageEjection failurePercentageEjection, + PolicySelection childPolicy) { + this.intervalNanos = intervalNanos; + this.baseEjectionTimeNanos = baseEjectionTimeNanos; + this.maxEjectionTimeNanos = maxEjectionTimeNanos; + this.maxEjectionPercent = maxEjectionPercent; + this.successRateEjection = successRateEjection; + this.failurePercentageEjection = failurePercentageEjection; + this.childPolicy = childPolicy; + } + + /** Builds a new {@link OutlierDetectionLoadBalancerConfig}. */ + public static class Builder { + Long intervalNanos = 10_000_000_000L; // 10s + Long baseEjectionTimeNanos = 30_000_000_000L; // 30s + Long maxEjectionTimeNanos = 30_000_000_000L; // 30s + Integer maxEjectionPercent = 10; + SuccessRateEjection successRateEjection; + FailurePercentageEjection failurePercentageEjection; + PolicySelection childPolicy; + + /** The interval between outlier detection sweeps. */ + public Builder setIntervalNanos(Long intervalNanos) { + checkArgument(intervalNanos != null); + this.intervalNanos = intervalNanos; + return this; + } + + /** The base time an address is ejected for. */ + public Builder setBaseEjectionTimeNanos(Long baseEjectionTimeNanos) { + checkArgument(baseEjectionTimeNanos != null); + this.baseEjectionTimeNanos = baseEjectionTimeNanos; + return this; + } + + /** The longest time an address can be ejected. */ + public Builder setMaxEjectionTimeNanos(Long maxEjectionTimeNanos) { + checkArgument(maxEjectionTimeNanos != null); + this.maxEjectionTimeNanos = maxEjectionTimeNanos; + return this; + } + + /** The algorithm agnostic maximum percentage of addresses that can be ejected. */ + public Builder setMaxEjectionPercent(Integer maxEjectionPercent) { + checkArgument(maxEjectionPercent != null); + this.maxEjectionPercent = maxEjectionPercent; + return this; + } + + /** Set to enable success rate ejection. */ + public Builder setSuccessRateEjection( + SuccessRateEjection successRateEjection) { + this.successRateEjection = successRateEjection; + return this; + } + + /** Set to enable failure percentage ejection. */ + public Builder setFailurePercentageEjection( + FailurePercentageEjection failurePercentageEjection) { + this.failurePercentageEjection = failurePercentageEjection; + return this; + } + + /** Sets the child policy the {@link OutlierDetectionLoadBalancer} delegates to. */ + public Builder setChildPolicy(PolicySelection childPolicy) { + checkState(childPolicy != null); + this.childPolicy = childPolicy; + return this; + } + + /** Builds a new instance of {@link OutlierDetectionLoadBalancerConfig}. */ + public OutlierDetectionLoadBalancerConfig build() { + checkState(childPolicy != null); + return new OutlierDetectionLoadBalancerConfig(intervalNanos, baseEjectionTimeNanos, + maxEjectionTimeNanos, maxEjectionPercent, successRateEjection, + failurePercentageEjection, childPolicy); + } + } + + /** The configuration for success rate ejection. */ + public static class SuccessRateEjection { + + public final Integer stdevFactor; + public final Integer enforcementPercentage; + public final Integer minimumHosts; + public final Integer requestVolume; + + SuccessRateEjection(Integer stdevFactor, Integer enforcementPercentage, Integer minimumHosts, + Integer requestVolume) { + this.stdevFactor = stdevFactor; + this.enforcementPercentage = enforcementPercentage; + this.minimumHosts = minimumHosts; + this.requestVolume = requestVolume; + } + + /** Builds new instances of {@link SuccessRateEjection}. */ + public static final class Builder { + + Integer stdevFactor = 1900; + Integer enforcementPercentage = 100; + Integer minimumHosts = 5; + Integer requestVolume = 100; + + /** The product of this and the standard deviation of success rates determine the ejection + * threshold. + */ + public Builder setStdevFactor(Integer stdevFactor) { + checkArgument(stdevFactor != null); + this.stdevFactor = stdevFactor; + return this; + } + + /** Only eject this percentage of outliers. */ + public Builder setEnforcementPercentage(Integer enforcementPercentage) { + checkArgument(enforcementPercentage != null); + checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); + this.enforcementPercentage = enforcementPercentage; + return this; + } + + /** The minimum amount of hosts needed for success rate ejection. */ + public Builder setMinimumHosts(Integer minimumHosts) { + checkArgument(minimumHosts != null); + checkArgument(minimumHosts >= 0); + this.minimumHosts = minimumHosts; + return this; + } + + /** The minimum address request volume to be considered for success rate ejection. */ + public Builder setRequestVolume(Integer requestVolume) { + checkArgument(requestVolume != null); + checkArgument(requestVolume >= 0); + this.requestVolume = requestVolume; + return this; + } + + /** Builds a new instance of {@link SuccessRateEjection}. */ + public SuccessRateEjection build() { + return new SuccessRateEjection(stdevFactor, enforcementPercentage, minimumHosts, + requestVolume); + } + } + } + + /** The configuration for failure percentage ejection. */ + public static class FailurePercentageEjection { + public final Integer threshold; + public final Integer enforcementPercentage; + public final Integer minimumHosts; + public final Integer requestVolume; + + FailurePercentageEjection(Integer threshold, Integer enforcementPercentage, + Integer minimumHosts, Integer requestVolume) { + this.threshold = threshold; + this.enforcementPercentage = enforcementPercentage; + this.minimumHosts = minimumHosts; + this.requestVolume = requestVolume; + } + + /** For building new {@link FailurePercentageEjection} instances. */ + public static class Builder { + Integer threshold = 85; + Integer enforcementPercentage = 100; + Integer minimumHosts = 5; + Integer requestVolume = 50; + + /** The failure percentage that will result in an address being considered an outlier. */ + public Builder setThreshold(Integer threshold) { + checkArgument(threshold != null); + checkArgument(threshold >= 0 && threshold <= 100); + this.threshold = threshold; + return this; + } + + /** Only eject this percentage of outliers. */ + public Builder setEnforcementPercentage(Integer enforcementPercentage) { + checkArgument(enforcementPercentage != null); + checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); + this.enforcementPercentage = enforcementPercentage; + return this; + } + + /** The minimum amount of host for failure percentage ejection to be enabled. */ + public Builder setMinimumHosts(Integer minimumHosts) { + checkArgument(minimumHosts != null); + checkArgument(minimumHosts >= 0); + this.minimumHosts = minimumHosts; + return this; + } + + /** + * The request volume required for an address to be considered for failure percentage + * ejection. + */ + public Builder setRequestVolume(Integer requestVolume) { + checkArgument(requestVolume != null); + checkArgument(requestVolume >= 0); + this.requestVolume = requestVolume; + return this; + } + + /** Builds a new instance of {@link FailurePercentageEjection}. */ + public FailurePercentageEjection build() { + return new FailurePercentageEjection(threshold, enforcementPercentage, minimumHosts, + requestVolume); + } + } + } + + /** Determine if any outlier detection algorithms are enabled in the config. */ + boolean outlierDetectionEnabled() { + return successRateEjection != null || failurePercentageEjection != null; + } + } +} diff --git a/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java b/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java new file mode 100644 index 00000000000..e52c7414653 --- /dev/null +++ b/core/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java @@ -0,0 +1,160 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.TimeProvider; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.FailurePercentageEjection; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.SuccessRateEjection; +import java.util.List; +import java.util.Map; + +@Internal +public final class OutlierDetectionLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new OutlierDetectionLoadBalancer(helper, TimeProvider.SYSTEM_TIME_PROVIDER); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "outlier_detection_experimental"; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + // Common configuration. + Long intervalNanos = JsonUtil.getStringAsDuration(rawConfig, "interval"); + Long baseEjectionTimeNanos = JsonUtil.getStringAsDuration(rawConfig, "baseEjectionTime"); + Long maxEjectionTimeNanos = JsonUtil.getStringAsDuration(rawConfig, "maxEjectionTime"); + Integer maxEjectionPercentage = JsonUtil.getNumberAsInteger(rawConfig, + "maxEjectionPercentage"); + + OutlierDetectionLoadBalancerConfig.Builder configBuilder + = new OutlierDetectionLoadBalancerConfig.Builder(); + if (intervalNanos != null) { + configBuilder.setIntervalNanos(intervalNanos); + } + if (baseEjectionTimeNanos != null) { + configBuilder.setBaseEjectionTimeNanos(baseEjectionTimeNanos); + } + if (maxEjectionTimeNanos != null) { + configBuilder.setMaxEjectionTimeNanos(maxEjectionTimeNanos); + } + if (maxEjectionPercentage != null) { + configBuilder.setMaxEjectionPercent(maxEjectionPercentage); + } + + // Success rate ejection specific configuration. + Map rawSuccessRateEjection = JsonUtil.getObject(rawConfig, "successRateEjection"); + if (rawSuccessRateEjection != null) { + SuccessRateEjection.Builder successRateEjectionBuilder = new SuccessRateEjection.Builder(); + + Integer stdevFactor = JsonUtil.getNumberAsInteger(rawSuccessRateEjection, "stdevFactor"); + Integer enforcementPercentage = JsonUtil.getNumberAsInteger(rawSuccessRateEjection, + "enforcementPercentage"); + Integer minimumHosts = JsonUtil.getNumberAsInteger(rawSuccessRateEjection, "minimumHosts"); + Integer requestVolume = JsonUtil.getNumberAsInteger(rawSuccessRateEjection, "requestVolume"); + + if (stdevFactor != null) { + successRateEjectionBuilder.setStdevFactor(stdevFactor); + } + if (enforcementPercentage != null) { + successRateEjectionBuilder.setEnforcementPercentage(enforcementPercentage); + } + if (minimumHosts != null) { + successRateEjectionBuilder.setMinimumHosts(minimumHosts); + } + if (requestVolume != null) { + successRateEjectionBuilder.setRequestVolume(requestVolume); + } + + configBuilder.setSuccessRateEjection(successRateEjectionBuilder.build()); + } + + // Failure percentage ejection specific configuration. + Map rawFailurePercentageEjection = JsonUtil.getObject(rawConfig, + "failurePercentageEjection"); + if (rawFailurePercentageEjection != null) { + FailurePercentageEjection.Builder failurePercentageEjectionBuilder + = new FailurePercentageEjection.Builder(); + + Integer threshold = JsonUtil.getNumberAsInteger(rawFailurePercentageEjection, "threshold"); + Integer enforcementPercentage = JsonUtil.getNumberAsInteger(rawFailurePercentageEjection, + "enforcementPercentage"); + Integer minimumHosts = JsonUtil.getNumberAsInteger(rawFailurePercentageEjection, + "minimumHosts"); + Integer requestVolume = JsonUtil.getNumberAsInteger(rawFailurePercentageEjection, + "requestVolume"); + + if (threshold != null) { + failurePercentageEjectionBuilder.setThreshold(threshold); + } + if (enforcementPercentage != null) { + failurePercentageEjectionBuilder.setEnforcementPercentage(enforcementPercentage); + } + if (minimumHosts != null) { + failurePercentageEjectionBuilder.setMinimumHosts(minimumHosts); + } + if (requestVolume != null) { + failurePercentageEjectionBuilder.setRequestVolume(requestVolume); + } + + configBuilder.setFailurePercentageEjection(failurePercentageEjectionBuilder.build()); + } + + // Child load balancer configuration. + List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(rawConfig, "childPolicy")); + if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "No child policy in outlier_detection_experimental LB policy: " + + rawConfig)); + } + ConfigOrError selectedConfig = + ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, + LoadBalancerRegistry.getDefaultRegistry()); + if (selectedConfig.getError() != null) { + return selectedConfig; + } + configBuilder.setChildPolicy((PolicySelection) selectedConfig.getConfig()); + + return ConfigOrError.fromConfig(configBuilder.build()); + } +} diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 4a6a1ff5611..b715f756144 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -69,7 +69,14 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } + List servers = resolvedAddresses.getAddresses(); Set currentAddrs = subchannels.keySet(); Map latestAddrs = stripAttrs(servers); @@ -126,6 +133,8 @@ public void onSubchannelState(ConnectivityStateInfo state) { for (Subchannel removedSubchannel : removedSubchannels) { shutdownSubchannel(removedSubchannel); } + + return true; } @Override diff --git a/core/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/core/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index cb200d5f044..d68a57c4eb3 100644 --- a/core/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/core/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -1,2 +1,3 @@ io.grpc.internal.PickFirstLoadBalancerProvider io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider +io.grpc.util.OutlierDetectionLoadBalancerProvider diff --git a/core/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java b/core/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java index 152fdf2252a..b1d80d53b8b 100644 --- a/core/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java +++ b/core/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java @@ -79,7 +79,8 @@ protected ManagedClientTransport newClientTransport(InternalServer server) { eagAttrs(), schedulerPool, testServer.streamTracerFactories, - testServer.serverListener); + testServer.serverListener, + false); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 65fc89be231..9f6c4922aa5 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -28,11 +28,13 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.common.util.concurrent.SettableFuture; import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.internal.AbstractServerStream.TransportState; import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; import java.io.ByteArrayInputStream; @@ -108,6 +110,43 @@ public void messagesAvailable(MessageProducer producer) { assertNull("no message expected", streamListenerMessageQueue.poll()); } + @Test + public void noHalfCloseListenerOnCancellation() throws Exception { + final Queue streamListenerMessageQueue = new LinkedList<>(); + final SettableFuture closedFuture = SettableFuture.create(); + + stream.transportState().setListener(new ServerStreamListenerBase() { + @Override + public void messagesAvailable(StreamListener.MessageProducer producer) { + InputStream message; + while ((message = producer.next()) != null) { + streamListenerMessageQueue.add(message); + } + } + + @Override + public void halfClosed() { + if (streamListenerMessageQueue.isEmpty()) { + throw new StatusRuntimeException(Status.INTERNAL.withDescription( + "Half close without request")); + } + } + + @Override + public void closed(Status status) { + closedFuture.set(status); + } + }); + + ReadableBuffer buffer = mock(ReadableBuffer.class); + when(buffer.readableBytes()).thenReturn(1); + stream.transportState().inboundDataReceived(buffer, true); + Status cancel = Status.CANCELLED.withDescription("DEADLINE EXCEEDED"); + stream.transportState().transportReportStatus(cancel); + assertEquals(cancel, closedFuture.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + verify(buffer).close(); + } + @Test public void queuedBytesInDeframerShouldNotBlockComplete() throws Exception { final SettableFuture closedFuture = SettableFuture.create(); diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index cd522181311..a1c00d7dca2 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -408,12 +408,13 @@ public void serverStartInterrupted() throws Exception { } assumeTrue("transport is not using InetSocketAddress", port != -1); server.shutdown(); + assertTrue(serverListener.waitForShutdown(TIMEOUT_MS, TimeUnit.MILLISECONDS)); server = newServer(port, Arrays.asList(serverStreamTracerFactory)); boolean success; Thread.currentThread().interrupt(); try { - server.start(serverListener); + server.start(serverListener = new MockServerListener()); success = true; } catch (Exception ex) { success = false; @@ -1145,7 +1146,7 @@ public void earlyServerClose_serverFailure() throws Exception { public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() throws Exception { server.start(serverListener); client = newClientTransport(server); - runIfNotNull(client.start(mockClientTransportListener)); + startTransport(client, mockClientTransportListener); MockServerTransportListener serverTransportListener = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; @@ -1503,6 +1504,36 @@ private int verifyMessageCountAndClose(BlockingQueue messageQueue, return count; } + @Test + public void messageProducerOnlyProducesRequestedMessages() throws Exception { + server.start(serverListener); + client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + serverTransport = serverTransportListener.transport; + + // Start an RPC. + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); + StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); + + // Have the client send two messages. + clientStream.writeMessage(methodDescriptor.streamRequest("MESSAGE")); + clientStream.writeMessage(methodDescriptor.streamRequest("MESSAGE")); + clientStream.flush(); + + doPingPong(serverListener); + + // Verify server only receives one message if that's all it requests. + serverStreamCreation.stream.request(1); + verifyMessageCountAndClose(serverStreamCreation.listener.messageQueue, 1); + } + @Test public void interactionsAfterServerStreamCloseAreNoops() throws Exception { server.start(serverListener); diff --git a/core/src/test/java/io/grpc/internal/AtomicBackoffTest.java b/core/src/test/java/io/grpc/internal/AtomicBackoffTest.java index 780db5258ee..3c0277a7c79 100644 --- a/core/src/test/java/io/grpc/internal/AtomicBackoffTest.java +++ b/core/src/test/java/io/grpc/internal/AtomicBackoffTest.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import org.junit.Test; import org.junit.runner.RunWith; @@ -25,9 +26,13 @@ /** Unit tests for {@link AtomicBackoff}. */ @RunWith(JUnit4.class) public class AtomicBackoffTest { - @Test(expected = IllegalArgumentException.class) + @Test public void mustBePositive() { - new AtomicBackoff("test", 0); + try { + new AtomicBackoff("test", 0); + fail(); + } catch (IllegalArgumentException expected) { + } } @Test diff --git a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java index bf7c63818cf..ad886c31142 100644 --- a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java +++ b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java @@ -21,8 +21,8 @@ import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -97,9 +97,8 @@ public class AutoConfiguredLoadBalancerFactoryTest { @Before public void setUp() { - when(testLbBalancer.canHandleEmptyAddressListFromNameResolution()).thenCallRealMethod(); - assertThat(testLbBalancer.canHandleEmptyAddressListFromNameResolution()).isFalse(); - when(testLbBalancer2.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); + when(testLbBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(true); + when(testLbBalancer2.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(true); defaultRegistry.register(testLbBalancerProvider); defaultRegistry.register(testLbBalancerProvider2); } @@ -171,7 +170,7 @@ public void shutdown() { } @Test - public void handleResolvedAddressGroups_keepOldBalancer() { + public void acceptResolvedAddresses_keepOldBalancer() { final List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); Helper helper = new TestHelper() { @@ -184,19 +183,19 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); LoadBalancer oldDelegate = lb.getDelegate(); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) .setLoadBalancingPolicyConfig(null) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isSameInstanceAs(oldDelegate); } @Test - public void handleResolvedAddressGroups_shutsDownOldBalancer() throws Exception { + public void acceptResolvedAddresses_shutsDownOldBalancer() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"round_robin\": { } } ] }"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); @@ -226,13 +225,13 @@ public void shutdown() { }; lb.setDelegate(testlb); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegateProvider().getClass().getName()).isEqualTo( "io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider"); assertTrue(shutdown.get()); @@ -240,7 +239,7 @@ public void shutdown() { @Test @SuppressWarnings("unchecked") - public void handleResolvedAddressGroups_propagateLbConfigToDelegate() throws Exception { + public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); @@ -251,20 +250,19 @@ public void handleResolvedAddressGroups_propagateLbConfigToDelegate() throws Exc Helper helper = new TestHelper(); AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); verify(testLbBalancerProvider).newLoadBalancer(same(helper)); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isSameInstanceAs(testLbBalancer); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(testLbBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(testLbBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactlyElementsIn(servers).inOrder(); - verify(testLbBalancer, atLeast(0)).canHandleEmptyAddressListFromNameResolution(); ArgumentCaptor> lbConfigCaptor = ArgumentCaptor.forClass(Map.class); verify(testLbBalancerProvider).parseLoadBalancingPolicyConfig(lbConfigCaptor.capture()); assertThat(lbConfigCaptor.getValue()).containsExactly("setting1", "high"); @@ -274,7 +272,7 @@ public void handleResolvedAddressGroups_propagateLbConfigToDelegate() throws Exc parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"low\" } } ] }"); lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); - handleResult = lb.tryHandleResolvedAddresses( + addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -282,8 +280,8 @@ public void handleResolvedAddressGroups_propagateLbConfigToDelegate() throws Exc resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(testLbBalancer, times(2)).handleResolvedAddresses(resultCaptor.capture()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + verify(testLbBalancer, times(2)).acceptResolvedAddresses(resultCaptor.capture()); + assertThat(addressesAccepted).isTrue(); assertThat(resultCaptor.getValue().getAddresses()).containsExactlyElementsIn(servers).inOrder(); verify(testLbBalancerProvider, times(2)) .parseLoadBalancingPolicyConfig(lbConfigCaptor.capture()); @@ -294,7 +292,7 @@ public void handleResolvedAddressGroups_propagateLbConfigToDelegate() throws Exc } @Test - public void handleResolvedAddressGroups_propagateAddrsToDelegate() throws Exception { + public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); @@ -305,56 +303,58 @@ public void handleResolvedAddressGroups_propagateAddrsToDelegate() throws Except List servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(8080){})); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); verify(testLbBalancerProvider).newLoadBalancer(same(helper)); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isSameInstanceAs(testLbBalancer); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(testLbBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(testLbBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactlyElementsIn(servers).inOrder(); servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(9090){})); - handleResult = lb.tryHandleResolvedAddresses( + addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); - verify(testLbBalancer, times(2)).handleResolvedAddresses(resultCaptor.capture()); + assertThat(addressesAccepted).isTrue(); + verify(testLbBalancer, times(2)).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactlyElementsIn(servers).inOrder(); } @Test - public void handleResolvedAddressGroups_delegateDoNotAcceptEmptyAddressList_nothing() + public void acceptResolvedAddresses_delegateDoNotAcceptEmptyAddressList_nothing() throws Exception { + + // The test LB will NOT accept the addresses we give them. + when(testLbBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(false); + Helper helper = new TestHelper(); AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); ConfigOrError lbConfig = lbf.parseLoadBalancerPolicy(serviceConfig); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfig.getConfig()) .build()); - assertThat(testLbBalancer.canHandleEmptyAddressListFromNameResolution()).isFalse(); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(handleResult.getDescription()).startsWith("NameResolver returned no usable address"); + assertThat(addressesAccepted).isFalse(); assertThat(lb.getDelegate()).isSameInstanceAs(testLbBalancer); } @Test - public void handleResolvedAddressGroups_delegateAcceptsEmptyAddressList() + public void acceptResolvedAddresses_delegateAcceptsEmptyAddressList() throws Exception { Helper helper = new TestHelper(); AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); @@ -363,25 +363,24 @@ public void handleResolvedAddressGroups_delegateAcceptsEmptyAddressList() parseConfig("{\"loadBalancingConfig\": [ {\"test_lb2\": { \"setting1\": \"high\" } } ] }"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isSameInstanceAs(testLbBalancer2); - assertThat(testLbBalancer2.canHandleEmptyAddressListFromNameResolution()).isTrue(); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(testLbBalancer2).handleResolvedAddresses(resultCaptor.capture()); + verify(testLbBalancer2).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).isEmpty(); assertThat(resultCaptor.getValue().getLoadBalancingPolicyConfig()) .isEqualTo(nextParsedConfigOrError2.get().getConfig()); } @Test - public void handleResolvedAddressGroups_useSelectedLbPolicy() throws Exception { + public void acceptResolvedAddresses_useSelectedLbPolicy() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [{\"round_robin\": {}}]}"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); @@ -399,18 +398,18 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate().getClass().getName()) .isEqualTo("io.grpc.util.RoundRobinLoadBalancer"); } @Test - public void handleResolvedAddressGroups_noLbPolicySelected_defaultToPickFirst() { + public void acceptResolvedAddresses_noLbPolicySelected_defaultToPickFirst() { final List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); Helper helper = new TestHelper() { @@ -421,27 +420,27 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isInstanceOf(PickFirstLoadBalancer.class); } @Test - public void handleResolvedAddressGroups_noLbPolicySelected_defaultToCustomDefault() { + public void acceptResolvedAddresses_noLbPolicySelected_defaultToCustomDefault() { AutoConfiguredLoadBalancer lb = new AutoConfiguredLoadBalancerFactory("test_lb") .newLoadBalancer(new TestHelper()); List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); assertThat(lb.getDelegate()).isSameInstanceAs(testLbBalancer); } @@ -458,13 +457,13 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = new AutoConfiguredLoadBalancerFactory(GrpcUtil.DEFAULT_LB_POLICY).newLoadBalancer(helper); - Status handleResult = lb.tryHandleResolvedAddresses( + boolean addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); verifyNoMoreInteractions(channelLogger); ConfigOrError testLbParsedConfig = ConfigOrError.fromConfig("foo"); @@ -472,13 +471,13 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - handleResult = lb.tryHandleResolvedAddresses( + addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); verify(channelLogger).log( eq(ChannelLogLevel.INFO), eq("Load balancer changed from {0} to {1}"), @@ -495,12 +494,12 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { nextParsedConfigOrError.set(testLbParsedConfig); serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - handleResult = lb.tryHandleResolvedAddresses( + addressesAccepted = lb.tryAcceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) .build()); - assertThat(handleResult.getCode()).isEqualTo(Status.Code.OK); + assertThat(addressesAccepted).isTrue(); verify(channelLogger).log( eq(ChannelLogLevel.DEBUG), eq("Load-balancing config: {0}"), @@ -643,14 +642,13 @@ protected LoadBalancer delegate() { @Override @Deprecated - public void handleResolvedAddressGroups( - List servers, Attributes attributes) { - delegate().handleResolvedAddressGroups(servers, attributes); + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + delegate().acceptResolvedAddresses(resolvedAddresses); } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - delegate().handleResolvedAddresses(resolvedAddresses); + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return delegate().acceptResolvedAddresses(resolvedAddresses); } @Override diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 2f0ce1070b1..bb64bbae188 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -41,6 +42,7 @@ import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StringMarshaller; import java.net.SocketAddress; import java.util.concurrent.Executor; @@ -142,12 +144,13 @@ public void parameterPropagation_base() { transport.newStream(method, origHeaders, callOptions, tracers); - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(RequestInfo.class); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); RequestInfo info = infoCaptor.getValue(); assertSame(transportAttrs, info.getTransportAttrs()); assertSame(method, info.getMethodDescriptor()); + assertSame(callOptions, info.getCallOptions()); assertSame(AUTHORITY, info.getAuthority()); assertSame(SecurityLevel.NONE, info.getSecurityLevel()); } @@ -166,9 +169,9 @@ public void parameterPropagation_overrideByCallOptions() { callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), tracers); - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(RequestInfo.class); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), - same(anotherExecutor), any(CallCredentials.MetadataApplier.class)); + same(mockExecutor), any(CallCredentials.MetadataApplier.class)); RequestInfo info = infoCaptor.getValue(); assertSame(transportAttrs, info.getTransportAttrs()); assertSame(method, info.getMethodDescriptor()); @@ -186,7 +189,7 @@ public void parameterPropagation_transportSetSecurityLevel() { transport.newStream(method, origHeaders, callOptions, tracers); - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(RequestInfo.class); verify(mockCreds).applyRequestMetadata( infoCaptor.capture(), same(mockExecutor), any(io.grpc.CallCredentials.MetadataApplier.class)); @@ -210,9 +213,9 @@ public void parameterPropagation_callOptionsSetAuthority() { callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), tracers); - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(RequestInfo.class); verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(anotherExecutor), + infoCaptor.capture(), same(mockExecutor), any(io.grpc.CallCredentials.MetadataApplier.class)); RequestInfo info = infoCaptor.getValue(); assertSame(method, info.getMethodDescriptor()); @@ -264,7 +267,7 @@ public void applyMetadata_inline() { @Test public void fail_inline() { - final Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); + final Status error = Status.UNAVAILABLE.withDescription("channel not secure for creds"); when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); doAnswer(new Answer() { @Override @@ -290,6 +293,38 @@ public Void answer(InvocationOnMock invocation) throws Throwable { verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } + // If the creds return an error that is inappropriate to directly propagate from the control plane + // to the call, it should be converted to an INTERNAL error. + @Test + public void fail_inline_inappropriate_error() { + final Status error = Status.NOT_FOUND.withDescription("channel not secure for creds"); + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + CallCredentials.MetadataApplier applier = + (CallCredentials.MetadataApplier) invocation.getArguments()[2]; + applier.fail(error); + return null; + } + }).when(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), any(CallCredentials.MetadataApplier.class)); + + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); + + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + assertThat(stream.getError().getCode()).isEqualTo(Code.INTERNAL); + assertThat(stream.getError().getDescription()).contains("Inappropriate"); + assertThat(stream.getError().getCause()).isNull(); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); + } + @Test public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); @@ -298,7 +333,8 @@ public void applyMetadata_delayed() { DelayedStream stream = (DelayedStream) transport.newStream( method, origHeaders, callOptions, tracers); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream( @@ -324,7 +360,8 @@ public void applyMetadata_delayed() { @Test public void delayedShutdown_shutdownShutdownNowThenApply() { transport.newStream(method, origHeaders, callOptions, tracers); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); @@ -345,7 +382,8 @@ public void delayedShutdown_shutdownShutdownNowThenApply() { @Test public void delayedShutdown_shutdownThenApplyThenShutdownNow() { transport.newStream(method, origHeaders, callOptions, tracers); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); @@ -373,7 +411,8 @@ public void delayedShutdown_shutdownMulti() { transport.newStream(method, origHeaders, callOptions, tracers); transport.newStream(method, origHeaders, callOptions, tracers); transport.newStream(method, origHeaders, callOptions, tracers); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); applierCaptor.getAllValues().get(1).apply(headers); @@ -401,11 +440,12 @@ public void fail_delayed() { DelayedStream stream = (DelayedStream) transport.newStream( method, origHeaders, callOptions, tracers); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); + Status error = Status.UNAVAILABLE.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); verify(mockTransport, never()).newStream( diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index e409f2f9df4..d19c60abe90 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -395,7 +395,7 @@ public void callOptionsPropagatedToTransport() { @Test public void methodInfoDeadlinePropagatedToStream() { - ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); CallOptions callOptions = baseCallOptions.withDeadline(Deadline.after(2000, SECONDS)); // Case: config Deadline expires later than CallOptions Deadline @@ -786,7 +786,7 @@ public void deadlineExceededBeforeCallStarted() { verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class)); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); assertThat(statusCaptor.getValue().getDescription()) - .startsWith("ClientCall started after deadline exceeded"); + .startsWith("ClientCall started after CallOptions deadline was exceeded"); verifyNoInteractions(clientStreamProvider); } diff --git a/core/src/test/java/io/grpc/internal/ConfigSelectingClientCallTest.java b/core/src/test/java/io/grpc/internal/ConfigSelectingClientCallTest.java index 9b3f8ad3b23..85fe7d30b85 100644 --- a/core/src/test/java/io/grpc/internal/ConfigSelectingClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/ConfigSelectingClientCallTest.java @@ -121,6 +121,31 @@ public void selectionErrorPropagatedToListener() { InternalConfigSelector configSelector = new InternalConfigSelector() { @Override public Result selectConfig(PickSubchannelArgs args) { + return Result.forError(Status.DEADLINE_EXCEEDED); + } + }; + + ClientCall configSelectingClientCall = new ConfigSelectingClientCall<>( + configSelector, + channel, + MoreExecutors.directExecutor(), + method, + CallOptions.DEFAULT); + configSelectingClientCall.start(callListener, new Metadata()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(callListener).onClose(statusCaptor.capture(), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.DEADLINE_EXCEEDED); + + // The call should not delegate to null and fail methods with NPE. + configSelectingClientCall.request(1); + } + + @Test + public void selectionErrorPropagatedToListener_inappropriateStatus() { + InternalConfigSelector configSelector = new InternalConfigSelector() { + @Override + public Result selectConfig(PickSubchannelArgs args) { + // This status code is considered inappropriate to propagate from the control plane... return Result.forError(Status.FAILED_PRECONDITION); } }; @@ -132,9 +157,10 @@ public Result selectConfig(PickSubchannelArgs args) { method, CallOptions.DEFAULT); configSelectingClientCall.start(callListener, new Metadata()); - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); verify(callListener).onClose(statusCaptor.capture(), any(Metadata.class)); - assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.FAILED_PRECONDITION); + // ... so it should be represented as an internal error to highlight the control plane bug. + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.INTERNAL); // The call should not delegate to null and fail methods with NPE. configSelectingClientCall.request(1); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index 290e2b9de65..45682b3a385 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -21,19 +21,24 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.ClientCall; import io.grpc.ClientCall.Listener; +import io.grpc.Context; import io.grpc.Deadline; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingTestUtil; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.StatusException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -63,12 +68,13 @@ public class DelayedClientCallTest { public void allMethodsForwarded() throws Exception { DelayedClientCall delayedClientCall = new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); - delayedClientCall.setCall(mockRealCall); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); ForwardingTestUtil.testMethodsForwarded( ClientCall.class, mockRealCall, delayedClientCall, - Arrays.asList(ClientCall.class.getMethod("toString")), + Arrays.asList(ClientCall.class.getMethod("toString"), + ClientCall.class.getMethod("start", Listener.class, Metadata.class)), new ForwardingTestUtil.ArgumentProvider() { @Override public Object get(Method method, int argPos, Class clazz) { @@ -101,8 +107,9 @@ public void listenerEventsPropagated() { DelayedClientCall delayedClientCall = new DelayedClientCall<>( callExecutor, fakeClock.getScheduledExecutorService(), Deadline.after(10, SECONDS)); delayedClientCall.start(listener, new Metadata()); - delayedClientCall.setCall(mockRealCall); - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); Listener realCallListener = listenerCaptor.getValue(); Metadata metadata = new Metadata(); @@ -119,4 +126,110 @@ public void listenerEventsPropagated() { verify(listener).onClose(statusCaptor.capture(), eq(trailer)); assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.DATA_LOSS); } + + @Test + public void setCallThenStart() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + callMeMaybe(delayedClientCall.setCall(mockRealCall)); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + realCallListener.onMessage(1); + verify(listener).onMessage(1); + } + + @Test + public void startThenSetCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + realCallListener.onMessage(1); + verify(listener).onMessage(1); + } + + @Test + @SuppressWarnings("unchecked") + public void cancelThenSetCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED)); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNull(); + verify(mockRealCall, never()).start(any(Listener.class), any(Metadata.class)); + verify(mockRealCall, never()).request(1); + verify(mockRealCall, never()).cancel(any(), any()); + verify(listener).onClose(any(), any()); + } + + @Test + @SuppressWarnings("unchecked") + public void setCallThenCancel() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED)); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + verify(mockRealCall).request(1); + verify(mockRealCall).cancel(any(), any()); + realCallListener.onClose(Status.CANCELLED, null); + verify(listener).onClose(Status.CANCELLED, null); + } + + @Test + public void delayedCallsRunUnderContext() throws Exception { + Context.Key contextKey = Context.key("foo"); + Object goldenValue = new Object(); + DelayedClientCall delayedClientCall = + Context.current().withValue(contextKey, goldenValue).call(() -> + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null)); + AtomicReference readyContext = new AtomicReference<>(); + delayedClientCall.start(new ClientCall.Listener() { + @Override public void onReady() { + readyContext.set(Context.current()); + } + }, new Metadata()); + AtomicReference startContext = new AtomicReference<>(); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override public void start(Listener listener, Metadata metadata) { + startContext.set(Context.current()); + listener.onReady(); // Delayed until call finishes draining + assertThat(readyContext.get()).isNull(); + super.start(listener, metadata); + } + }); + assertThat(r).isNotNull(); + r.run(); + assertThat(contextKey.get(startContext.get())).isEqualTo(goldenValue); + assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); + } + + private void callMeMaybe(Runnable r) { + if (r != null) { + r.run(); + } + } } diff --git a/core/src/test/java/io/grpc/internal/FakeClock.java b/core/src/test/java/io/grpc/internal/FakeClock.java index d708af5f25d..9cc9178f1ff 100644 --- a/core/src/test/java/io/grpc/internal/FakeClock.java +++ b/core/src/test/java/io/grpc/internal/FakeClock.java @@ -88,13 +88,21 @@ public long currentTimeNanos() { public class ScheduledTask extends AbstractFuture implements ScheduledFuture { public final Runnable command; - public final long dueTimeNanos; + public long dueTimeNanos; - ScheduledTask(long dueTimeNanos, Runnable command) { - this.dueTimeNanos = dueTimeNanos; + ScheduledTask(Runnable command) { this.command = command; } + void run() { + command.run(); + set(null); + } + + void setDueTimeNanos(long dueTimeNanos) { + this.dueTimeNanos = dueTimeNanos; + } + @Override public boolean cancel(boolean mayInterruptIfRunning) { scheduledTasks.remove(this); dueTasks.remove(this); @@ -116,10 +124,6 @@ public class ScheduledTask extends AbstractFuture implements ScheduledFutu } } - void complete() { - set(null); - } - @Override public String toString() { return "[due=" + dueTimeNanos + ", task=" + command + "]"; @@ -132,24 +136,33 @@ private class ScheduledExecutorImpl implements ScheduledExecutorService { throw new UnsupportedOperationException(); } - @Override public ScheduledFuture schedule(Runnable cmd, long delay, TimeUnit unit) { - ScheduledTask task = new ScheduledTask(currentTimeNanos + unit.toNanos(delay), cmd); + private void schedule(ScheduledTask task, long delay, TimeUnit unit) { + task.setDueTimeNanos(currentTimeNanos + unit.toNanos(delay)); if (delay > 0) { scheduledTasks.add(task); } else { dueTasks.add(task); } + } + + @Override public ScheduledFuture schedule(Runnable cmd, long delay, TimeUnit unit) { + ScheduledTask task = new ScheduledTask(cmd); + schedule(task, delay, unit); return task; } @Override public ScheduledFuture scheduleAtFixedRate( - Runnable command, long initialDelay, long period, TimeUnit unit) { - throw new UnsupportedOperationException(); + Runnable cmd, long initialDelay, long period, TimeUnit unit) { + ScheduledTask task = new ScheduleAtFixedRateTask(cmd, period, unit); + schedule(task, initialDelay, unit); + return task; } @Override public ScheduledFuture scheduleWithFixedDelay( - Runnable command, long initialDelay, long delay, TimeUnit unit) { - throw new UnsupportedOperationException(); + Runnable cmd, long initialDelay, long delay, TimeUnit unit) { + ScheduledTask task = new ScheduleWithFixedDelayTask(cmd, delay, unit); + schedule(task, initialDelay, unit); + return task; } @Override public boolean awaitTermination(long timeout, TimeUnit unit) { @@ -206,6 +219,41 @@ private class ScheduledExecutorImpl implements ScheduledExecutorService { // Since it is being enqueued immediately, no point in tracing the future for cancellation. Future unused = schedule(command, 0, TimeUnit.NANOSECONDS); } + + class ScheduleAtFixedRateTask extends ScheduledTask { + final long periodNanos; + + public ScheduleAtFixedRateTask(Runnable command, long period, TimeUnit unit) { + super(command); + this.periodNanos = unit.toNanos(period); + } + + @Override void run() { + long startTimeNanos = currentTimeNanos; + command.run(); + if (!isCancelled()) { + schedule(this, startTimeNanos + periodNanos - currentTimeNanos, TimeUnit.NANOSECONDS); + } + } + } + + class ScheduleWithFixedDelayTask extends ScheduledTask { + + final long delayNanos; + + ScheduleWithFixedDelayTask(Runnable command, long delay, TimeUnit unit) { + super(command); + this.delayNanos = unit.toNanos(delay); + } + + @Override + void run() { + command.run(); + if (!isCancelled()) { + schedule(this, delayNanos, TimeUnit.NANOSECONDS); + } + } + } } /** @@ -258,8 +306,7 @@ public int runDueTasks() { } ScheduledTask task; while ((task = dueTasks.poll()) != null) { - task.command.run(); - task.complete(); + task.run(); count++; } } @@ -357,7 +404,7 @@ public int numPendingTasks(TaskFilter filter) { public long currentTimeMillis() { // Normally millis and nanos are of different epochs. Add an offset to simulate that. - return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos + 123456789L); + return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos + 1234567890123456789L); } /** diff --git a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java index b778f25e5de..8ce45bc77cf 100644 --- a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java @@ -36,13 +36,10 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -/** - * Tests for {@link ForwardingReadableBuffer}. - */ +/** Tests for {@link ForwardingReadableBuffer}. */ @RunWith(JUnit4.class) public class ForwardingReadableBufferTest { - @Rule - public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private ReadableBuffer delegate; private ForwardingReadableBuffer buffer; @@ -55,10 +52,7 @@ public void setUp() { @Test public void allMethodsForwarded() throws Exception { ForwardingTestUtil.testMethodsForwarded( - ReadableBuffer.class, - delegate, - buffer, - Collections.emptyList()); + ReadableBuffer.class, delegate, buffer, Collections.emptyList()); } @Test @@ -99,7 +93,7 @@ public void readBytes() { @Test public void readBytes_overload1() { - ByteBuffer dest = mock(ByteBuffer.class); + ByteBuffer dest = ByteBuffer.allocate(0); buffer.readBytes(dest); verify(delegate).readBytes(dest); diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 7e3f6e7db4e..bd2864ecc9d 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -27,19 +28,26 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import com.google.common.collect.Lists; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.LoadBalancer.PickResult; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.ArrayList; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; /** Unit tests for {@link GrpcUtil}. */ @RunWith(JUnit4.class) @@ -51,6 +59,11 @@ public class GrpcUtilTest { @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Captor + private ArgumentCaptor statusCaptor; + @Test public void http2ErrorForCode() { @@ -258,6 +271,69 @@ public void getTransportFromPickResult_errorPickResult_failFast() { verify(listener).closed(eq(status), eq(RpcProgress.PROCESSED), any(Metadata.class)); } + /* Status codes that a control plane should not be returned get replaced by INTERNAL. */ + @Test + public void getTransportFromPickResult_errorPickResult_noInappropriateControlPlaneStatus() { + + // These are NOT appropriate for a control plane to return. + ArrayList inappropriateStatus = Lists.newArrayList( + Status.INVALID_ARGUMENT.withDescription("bad one").withCause(new RuntimeException()), + Status.NOT_FOUND.withDescription("not here").withCause(new RuntimeException()), + Status.ALREADY_EXISTS.withDescription("not again").withCause(new RuntimeException()), + Status.FAILED_PRECONDITION.withDescription("naah").withCause(new RuntimeException()), + Status.ABORTED.withDescription("nope").withCause(new RuntimeException()), + Status.OUT_OF_RANGE.withDescription("outta range").withCause(new RuntimeException()), + Status.DATA_LOSS.withDescription("lost").withCause(new RuntimeException())); + + for (Status status : inappropriateStatus) { + PickResult pickResult = PickResult.withError(status); + ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, false); + + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); + ClientStreamListener listener = mock(ClientStreamListener.class); + stream.start(listener); + + verify(listener).closed(statusCaptor.capture(), eq(RpcProgress.PROCESSED), + any(Metadata.class)); + Status usedStatus = statusCaptor.getValue(); + assertThat(usedStatus.getCode()).isEqualTo(Code.INTERNAL); + assertThat(usedStatus.getDescription()).contains("Inappropriate status"); + assertThat(usedStatus.getCause()).isInstanceOf(RuntimeException.class); + } + } + + /* Status codes a control plane can return are not replaced. */ + @Test + public void getTransportFromPickResult_errorPickResult_appropriateControlPlaneStatus() { + + // These ARE appropriate for a control plane to return. + ArrayList inappropriateStatus = Lists.newArrayList( + Status.CANCELLED, + Status.UNKNOWN, + Status.DEADLINE_EXCEEDED, + Status.PERMISSION_DENIED, + Status.RESOURCE_EXHAUSTED, + Status.UNIMPLEMENTED, + Status.INTERNAL, + Status.UNAVAILABLE, + Status.UNAUTHENTICATED); + + for (Status status : inappropriateStatus) { + PickResult pickResult = PickResult.withError(status); + ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, false); + + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); + ClientStreamListener listener = mock(ClientStreamListener.class); + stream.start(listener); + + verify(listener).closed(eq(status), eq(RpcProgress.PROCESSED), any(Metadata.class)); + } + } + @Test public void getTransportFromPickResult_dropPickResult_waitForReady() { Status status = Status.UNAVAILABLE; diff --git a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java similarity index 99% rename from netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java rename to core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java index 8dfeb990e2b..c58ed6ea160 100644 --- a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java +++ b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index bc5b7cde651..286c48ebd62 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -35,9 +35,11 @@ import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; +import io.grpc.InternalGlobalInterceptors; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import io.grpc.NameResolver; +import io.grpc.StaticTestingClassLoader; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; @@ -47,12 +49,14 @@ import java.net.SocketAddress; import java.net.URI; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -78,6 +82,14 @@ public ClientCall interceptCall( return next.newCall(method, callOptions); } }; + private static final ClientInterceptor DUMMY_USER_INTERCEPTOR1 = + new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @@ -90,7 +102,12 @@ public ClientCall interceptCall( private ManagedChannelImplBuilder builder; private ManagedChannelImplBuilder directAddressBuilder; private final FakeClock clock = new FakeClock(); - + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader( + getClass().getClassLoader(), + Pattern.compile( + "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + + "io\\.grpc\\.internal\\.[^.]+")); @Before public void setUp() throws Exception { @@ -447,6 +464,86 @@ public void getEffectiveInterceptors_disableBoth() { assertThat(effectiveInterceptors).containsExactly(DUMMY_USER_INTERCEPTOR); } + @Test + public void getEffectiveInterceptors_callsGetGlobalInterceptors() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderCallsGet implements Runnable { + + @Override + public void run() { + ManagedChannelImplBuilder builder = + new ManagedChannelImplBuilder( + DUMMY_TARGET, + new UnsupportedClientTransportFactoryBuilder(), + new FixedPortProvider(DUMMY_PORT)); + List effectiveInterceptors = builder.getEffectiveInterceptors(); + assertThat(effectiveInterceptors).hasSize(2); + try { + InternalGlobalInterceptors.setInterceptorsTracers( + Arrays.asList(DUMMY_USER_INTERCEPTOR), + Collections.emptyList(), + Collections.emptyList()); + fail("exception expected"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains("Set cannot be called after any get call"); + } + } + } + + @Test + public void getEffectiveInterceptors_callsSetGlobalInterceptors() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderCallsSet implements Runnable { + + @Override + public void run() { + InternalGlobalInterceptors.setInterceptorsTracers( + Arrays.asList(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1), + Collections.emptyList(), + Collections.emptyList()); + ManagedChannelImplBuilder builder = + new ManagedChannelImplBuilder( + DUMMY_TARGET, + new UnsupportedClientTransportFactoryBuilder(), + new FixedPortProvider(DUMMY_PORT)); + List effectiveInterceptors = builder.getEffectiveInterceptors(); + assertThat(effectiveInterceptors) + .containsExactly(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1); + } + } + + @Test + public void getEffectiveInterceptors_setEmptyGlobalInterceptors() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassLoaderCallsSetEmpty.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderCallsSetEmpty implements Runnable { + + @Override + public void run() { + InternalGlobalInterceptors.setInterceptorsTracers( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + ManagedChannelImplBuilder builder = + new ManagedChannelImplBuilder( + DUMMY_TARGET, + new UnsupportedClientTransportFactoryBuilder(), + new FixedPortProvider(DUMMY_PORT)); + List effectiveInterceptors = builder.getEffectiveInterceptors(); + assertThat(effectiveInterceptors).isEmpty(); + } + } + @Test public void idleTimeout() { assertEquals(ManagedChannelImplBuilder.IDLE_MODE_DEFAULT_TIMEOUT_MILLIS, diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index 30e137cba22..00eb154cf82 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -154,6 +154,7 @@ public String getPolicyName() { @Before @SuppressWarnings("deprecation") // For NameResolver.Listener public void setUp() { + when(mockLoadBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(true); LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); when(mockNameResolver.getServiceAuthority()).thenReturn(AUTHORITY); when(mockNameResolverFactory @@ -220,7 +221,7 @@ public void newCallExitsIdleness() throws Exception { ArgumentCaptor resolvedAddressCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resolvedAddressCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); assertThat(resolvedAddressCaptor.getValue().getAddresses()) .containsExactlyElementsIn(servers); } @@ -324,7 +325,7 @@ public void realTransportsHoldsOffIdleness() throws Exception { call.start(mockCallListener, new Metadata()); // Verify that we have exited the idle mode - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); deliverResolutionResult(); Helper helper = helperCaptor.getValue(); @@ -372,7 +373,7 @@ public void enterIdleWhileRealTransportInProgress() { call.start(mockCallListener, new Metadata()); // Verify that we have exited the idle mode - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); deliverResolutionResult(); Helper helper = helperCaptor.getValue(); @@ -411,7 +412,7 @@ public void enterIdleWhileRealTransportInProgress() { public void updateSubchannelAddresses_newAddressConnects() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); // Create LB - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); deliverResolutionResult(); Helper helper = helperCaptor.getValue(); @@ -435,7 +436,7 @@ public void updateSubchannelAddresses_newAddressConnects() { public void updateSubchannelAddresses_existingAddressDoesNotConnect() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); // Create LB - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); deliverResolutionResult(); Helper helper = helperCaptor.getValue(); @@ -460,7 +461,7 @@ public void oobTransportDoesNotAffectIdleness() { call.start(mockCallListener, new Metadata()); // Verify that we have exited the idle mode - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); Helper helper = helperCaptor.getValue(); deliverResolutionResult(); @@ -509,7 +510,7 @@ public void oobTransportDoesNotAffectIdleness() { public void updateOobChannelAddresses_newAddressConnects() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); // Create LB - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); deliverResolutionResult(); Helper helper = helperCaptor.getValue(); @@ -533,7 +534,7 @@ public void updateOobChannelAddresses_newAddressConnects() { public void updateOobChannelAddresses_existingAddressDoesNotConnect() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); // Create LB - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); Helper helper = helperCaptor.getValue(); deliverResolutionResult(); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index f47954e2215..6c2a398fe5f 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -140,9 +140,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.logging.Handler; -import java.util.logging.Level; -import java.util.logging.LogRecord; import javax.annotation.Nullable; import org.junit.After; import org.junit.Assert; @@ -284,7 +281,6 @@ public String getPolicyName() { private boolean requestConnection = true; private BlockingQueue transports; private boolean panicExpected; - private final List logs = new ArrayList<>(); @Captor private ArgumentCaptor resolvedAddressCaptor; @@ -319,7 +315,7 @@ public void run() { assertEquals(numExpectedTasks, timer.numPendingTasks()); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); } @@ -327,7 +323,7 @@ public void run() { @Before public void setUp() throws Exception { - when(mockLoadBalancer.canHandleEmptyAddressListFromNameResolution()).thenCallRealMethod(); + when(mockLoadBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(true); LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); expectedUri = new URI(TARGET); transports = TestUtils.captureTransports(mockTransportFactory); @@ -336,22 +332,6 @@ public void setUp() throws Exception { when(executorPool.getObject()).thenReturn(executor.getScheduledExecutorService()); when(balancerRpcExecutorPool.getObject()) .thenReturn(balancerRpcExecutor.getScheduledExecutorService()); - Handler handler = new Handler() { - @Override - public void publish(LogRecord record) { - logs.add(record); - } - - @Override - public void flush() { - } - - @Override - public void close() throws SecurityException { - } - }; - ManagedChannelImpl.logger.addHandler(handler); - ManagedChannelImpl.logger.setLevel(Level.ALL); channelBuilder = new ManagedChannelImplBuilder(TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); @@ -422,7 +402,8 @@ public void createSubchannel_resolverOverrideAuthority() { Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); requestConnectionSafely(helper, subchannel); - ArgumentCaptor transportOptionCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor transportOptionCaptor = + ArgumentCaptor.forClass(ClientTransportOptions.class); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class)); @@ -447,7 +428,8 @@ public void createSubchannel_channelBuilderOverrideAuthority() { final Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); requestConnectionSafely(helper, subchannel); - ArgumentCaptor transportOptionCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor transportOptionCaptor = + ArgumentCaptor.forClass(ClientTransportOptions.class); verify(mockTransportFactory) .newClientTransport( any(SocketAddress.class), transportOptionCaptor.capture(), any(ChannelLogger.class)); @@ -514,7 +496,7 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, headers); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); // Make the transport available @@ -540,7 +522,7 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); verify(mockTransport).newStream( same(method), same(headers), callOptionsCaptor.capture(), ArgumentMatchers.any()); @@ -596,7 +578,7 @@ public ClientCall interceptCall( ClientCall call = channel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, headers); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); // Make the transport available @@ -619,7 +601,7 @@ public ClientCall interceptCall( updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); verify(mockTransport).newStream( same(method), same(headers), callOptionsCaptor.capture(), ArgumentMatchers.any()); @@ -948,7 +930,7 @@ public void noMoreCallbackAfterLoadBalancerShutdown() { FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - verify(mockLoadBalancer).handleResolvedAddresses(resolvedAddressCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); assertThat(resolvedAddressCaptor.getValue().getAddresses()).containsExactly(addressGroup); SubchannelStateListener stateListener1 = mock(SubchannelStateListener.class); @@ -1163,8 +1145,9 @@ public void nameResolutionFailed_delayedTransportShutdownCancelsBackoff() { } @Test - public void nameResolverReturnsEmptySubLists_becomeErrorByDefault() throws Exception { - String errorDescription = "NameResolver returned no usable address"; + public void nameResolverReturnsEmptySubLists_resolutionRetry() throws Exception { + // The mock LB is set to reject the addresses. + when(mockLoadBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(false); // Pass a FakeNameResolverFactory with an empty list and LB config FakeNameResolverFactory nameResolverFactory = @@ -1177,21 +1160,12 @@ public void nameResolverReturnsEmptySubLists_becomeErrorByDefault() throws Excep channelBuilder.nameResolverFactory(nameResolverFactory); createChannel(); - // LoadBalancer received the error - verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - verify(mockLoadBalancer).handleNameResolutionError(statusCaptor.capture()); - Status status = statusCaptor.getValue(); - assertSame(Status.Code.UNAVAILABLE, status.getCode()); - assertThat(status.getDescription()).startsWith(errorDescription); - // A resolution retry has been scheduled assertEquals(1, timer.numPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER)); } @Test public void nameResolverReturnsEmptySubLists_optionallyAllowed() throws Exception { - when(mockLoadBalancer.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); - // Pass a FakeNameResolverFactory with an empty list and LB config FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); @@ -1213,7 +1187,7 @@ public void nameResolverReturnsEmptySubLists_optionallyAllowed() throws Exceptio verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).isEmpty(); assertThat(resultCaptor.getValue().getLoadBalancingPolicyConfig()).isEqualTo(parsedLbConfig); @@ -1234,7 +1208,7 @@ public void loadBalancerThrowsInHandleResolvedAddresses() { createChannel(); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - doThrow(ex).when(mockLoadBalancer).handleResolvedAddresses(any(ResolvedAddresses.class)); + doThrow(ex).when(mockLoadBalancer).acceptResolvedAddresses(any(ResolvedAddresses.class)); // NameResolver returns addresses. nameResolverFactory.allResolved(); @@ -1296,7 +1270,7 @@ public void firstResolvedServerFailedToConnect() throws Exception { // Simulate name resolution results EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs); - inOrder.verify(mockLoadBalancer).handleResolvedAddresses(resolvedAddressCaptor.capture()); + inOrder.verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); assertThat(resolvedAddressCaptor.getValue().getAddresses()).containsExactly(addressGroup); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); @@ -1446,7 +1420,7 @@ public void allServersFailedToConnect() throws Exception { // Simulate name resolution results EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(resolvedAddrs); - inOrder.verify(mockLoadBalancer).handleResolvedAddresses(resolvedAddressCaptor.capture()); + inOrder.verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); assertThat(resolvedAddressCaptor.getValue().getAddresses()).containsExactly(addressGroup); Subchannel subchannel = @@ -1591,103 +1565,6 @@ public void run() { timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS); } - @Test - public void subchannelConnectionBroken_noLbRefreshingResolver_logWarningAndTriggeRefresh() { - FakeNameResolverFactory nameResolverFactory = - new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) - .build(); - channelBuilder.nameResolverFactory(nameResolverFactory); - createChannel(); - FakeNameResolverFactory.FakeNameResolver resolver = - Iterables.getOnlyElement(nameResolverFactory.resolvers); - assertThat(resolver.refreshCalled).isEqualTo(0); - - Subchannel subchannel = - createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); - InternalSubchannel internalSubchannel = - (InternalSubchannel) subchannel.getInternalSubchannel(); - internalSubchannel.obtainActiveTransport(); - MockClientTransportInfo transportInfo = transports.poll(); - - // Break subchannel connection - transportInfo.listener.transportShutdown(Status.UNAVAILABLE.withDescription("unreachable")); - LogRecord log = Iterables.getOnlyElement(logs); - assertThat(log.getLevel()).isEqualTo(Level.WARNING); - assertThat(log.getMessage()).isEqualTo( - "LoadBalancer should call Helper.refreshNameResolution() to refresh name resolution if " - + "subchannel state becomes TRANSIENT_FAILURE or IDLE. This will no longer happen " - + "automatically in the future releases"); - assertThat(resolver.refreshCalled).isEqualTo(1); - } - - @Test - public void subchannelConnectionBroken_ResolverRefreshedByLb() { - FakeNameResolverFactory nameResolverFactory = - new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) - .build(); - channelBuilder.nameResolverFactory(nameResolverFactory); - createChannel(); - FakeNameResolverFactory.FakeNameResolver resolver = - Iterables.getOnlyElement(nameResolverFactory.resolvers); - assertThat(resolver.refreshCalled).isEqualTo(0); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); - verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); - helper = helperCaptor.getValue(); - - SubchannelStateListener listener = new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - // Normal LoadBalancer should refresh name resolution when some subchannel enters - // TRANSIENT_FAILURE or IDLE - if (newState.getState() == TRANSIENT_FAILURE || newState.getState() == IDLE) { - helper.refreshNameResolution(); - } - } - }; - Subchannel subchannel = - createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, listener); - InternalSubchannel internalSubchannel = - (InternalSubchannel) subchannel.getInternalSubchannel(); - internalSubchannel.obtainActiveTransport(); - MockClientTransportInfo transportInfo = transports.poll(); - - // Break subchannel connection and simulate load balancer refreshing name resolution - transportInfo.listener.transportShutdown(Status.UNAVAILABLE.withDescription("unreachable")); - assertThat(logs).isEmpty(); - assertThat(resolver.refreshCalled).isEqualTo(1); - } - - @Test - public void subchannelConnectionBroken_ignoreRefreshNameResolutionCheck_noRefresh() { - FakeNameResolverFactory nameResolverFactory = - new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) - .build(); - channelBuilder.nameResolverFactory(nameResolverFactory); - createChannel(); - FakeNameResolverFactory.FakeNameResolver resolver = - Iterables.getOnlyElement(nameResolverFactory.resolvers); - assertThat(resolver.refreshCalled).isEqualTo(0); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); - verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); - helper = helperCaptor.getValue(); - helper.ignoreRefreshNameResolutionCheck(); - - Subchannel subchannel = - createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); - InternalSubchannel internalSubchannel = - (InternalSubchannel) subchannel.getInternalSubchannel(); - internalSubchannel.obtainActiveTransport(); - MockClientTransportInfo transportInfo = transports.poll(); - - // Break subchannel connection - transportInfo.listener.transportShutdown(Status.UNAVAILABLE.withDescription("unreachable")); - assertThat(logs).isEmpty(); - assertThat(resolver.refreshCalled).isEqualTo(0); - } - @Test public void subchannelStringableBeforeStart() { createChannel(); @@ -2395,10 +2272,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(RequestInfo.class); + ArgumentCaptor executorArgumentCaptor = ArgumentCaptor.forClass(Executor.class); + ArgumentCaptor applierCaptor = + ArgumentCaptor.forClass(CallCredentials.MetadataApplier.class); verify(creds).applyRequestMetadata(infoCaptor.capture(), - same(executor.getScheduledExecutorService()), applierCaptor.capture()); + executorArgumentCaptor.capture(), applierCaptor.capture()); + assertSame(offloadExecutor, + ((ManagedChannelImpl.ExecutorHolder) executorArgumentCaptor.getValue()).getExecutor()); assertEquals("testValue", testKey.get(credsApplyContexts.poll())); assertEquals(AUTHORITY, infoCaptor.getValue().getAuthority()); assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); @@ -2423,7 +2304,9 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { call.start(mockCallListener, new Metadata()); verify(creds, times(2)).applyRequestMetadata(infoCaptor.capture(), - same(executor.getScheduledExecutorService()), applierCaptor.capture()); + executorArgumentCaptor.capture(), applierCaptor.capture()); + assertSame(offloadExecutor, + ((ManagedChannelImpl.ExecutorHolder) executorArgumentCaptor.getValue()).getExecutor()); assertEquals("testValue", testKey.get(credsApplyContexts.poll())); assertEquals(AUTHORITY, infoCaptor.getValue().getAuthority()); assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); @@ -2555,7 +2438,7 @@ public void getState_withRequestConnect() { // call getState() with requestConnection = true assertEquals(IDLE, channel.getState(true)); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); @@ -2799,6 +2682,40 @@ public void run() { panicExpected = true; } + @Test + public void panic_atStart() { + final RuntimeException panicReason = new RuntimeException("Simulated NR exception"); + final NameResolver failingResolver = new NameResolver() { + @Override public String getServiceAuthority() { + return "fake-authority"; + } + + @Override public void start(Listener2 listener) { + throw panicReason; + } + + @Override public void shutdown() {} + }; + channelBuilder.nameResolverFactory(new NameResolver.Factory() { + @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return failingResolver; + } + + @Override public String getDefaultScheme() { + return "fakescheme"; + } + }); + createChannel(); + + // RPCs fail immediately + ClientCall call = + channel.newCall(method, CallOptions.DEFAULT.withoutWaitForReady()); + call.start(mockCallListener, new Metadata()); + executor.runDueTasks(); + verifyCallListenerClosed(mockCallListener, Status.Code.INTERNAL, panicReason); + panicExpected = true; + } + private void verifyPanicMode(Throwable cause) { panicExpected = true; @SuppressWarnings("unchecked") @@ -2818,7 +2735,7 @@ private void verifyPanicMode(Throwable cause) { private void verifyCallListenerClosed( ClientCall.Listener listener, Status.Code code, Throwable cause) { - ArgumentCaptor captor = ArgumentCaptor.forClass(null); + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(listener).onClose(captor.capture(), any(Metadata.class)); Status rpcStatus = captor.getValue(); assertEquals(code, rpcStatus.getCode()); @@ -3429,7 +3346,7 @@ public void channelTracing_oobChannelCreationEvents() throws Exception { public void channelsAndSubchannels_instrumented_state() throws Exception { createChannel(); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); @@ -3733,7 +3650,7 @@ public double nextDouble() { ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); - verify(mockLoadBalancer).handleResolvedAddresses( + verify(mockLoadBalancer).acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(nameResolverFactory.servers) .build()); @@ -3839,7 +3756,7 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); - verify(mockLoadBalancer).handleResolvedAddresses( + verify(mockLoadBalancer).acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(nameResolverFactory.servers) .build()); @@ -4166,7 +4083,7 @@ public void disableServiceConfigLookUp_noDefaultConfig() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); assertThat(resultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY)).isNull(); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); @@ -4204,7 +4121,7 @@ public void disableServiceConfigLookUp_withDefaultConfig() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); assertThat(resultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY)).isNull(); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); @@ -4234,7 +4151,7 @@ public void enableServiceConfigLookUp_noDefaultConfig() throws Exception { createChannel(); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); @@ -4250,7 +4167,7 @@ public void enableServiceConfigLookUp_noDefaultConfig() throws Exception { nameResolverFactory.allResolved(); resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer, times(2)).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer, times(2)).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); } finally { @@ -4284,7 +4201,7 @@ public void enableServiceConfigLookUp_withDefaultConfig() throws Exception { createChannel(); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); } finally { @@ -4312,7 +4229,7 @@ public void enableServiceConfigLookUp_resolverReturnsNoConfig_withDefaultConfig( createChannel(); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); } finally { @@ -4338,7 +4255,7 @@ public void enableServiceConfigLookUp_resolverReturnsNoConfig_noDefaultConfig() createChannel(); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); } finally { @@ -4418,7 +4335,7 @@ public void healthCheckingConfigPropagated() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); assertThat(resultCaptor.getValue().getAttributes() .get(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG)) .containsExactly("serviceName", "service1"); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java index 709f6274de4..c25a0808584 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java @@ -163,7 +163,7 @@ public void managedChannelServiceConfig_parseMethodConfig() { .put("backoffMultiplier", 1.5D) .put("perAttemptRecvTimeout", "2.5s") .put("retryableStatusCodes", ImmutableList.of("UNAVAILABLE")) - .build()); + .buildOrThrow()); Map defaultMethodConfig = ImmutableMap.of( "name", ImmutableList.of(ImmutableMap.of()), "timeout", "4.321s"); @@ -225,7 +225,7 @@ public void retryConfig_emptyRetriableStatusCodesAllowedWithPerAttemptRecvTimeou .put("backoffMultiplier", 1.5D) .put("perAttemptRecvTimeout", "2.5s") .put("retryableStatusCodes", ImmutableList.of()) - .build(); + .buildOrThrow(); Map methodConfig = ImmutableMap.of( "name", ImmutableList.of(ImmutableMap.of()), "retryPolicy", retryPolicy); Map rawServiceConfig = @@ -242,7 +242,7 @@ public void retryConfig_PerAttemptRecvTimeoutUnsetAllowedIfRetryableStatusCodesN .put("maxBackoff", "10s") .put("backoffMultiplier", 1.5D) .put("retryableStatusCodes", ImmutableList.of("UNAVAILABLE")) - .build(); + .buildOrThrow(); Map methodConfig = ImmutableMap.of( "name", ImmutableList.of(ImmutableMap.of()), "retryPolicy", retryPolicy); Map rawServiceConfig = @@ -259,7 +259,7 @@ public void retryConfig_emptyRetriableStatusCodesNotAllowedWithPerAttemptRecvTim .put("maxBackoff", "10s") .put("backoffMultiplier", 1.5D) .put("retryableStatusCodes", ImmutableList.of()) - .build(); + .buildOrThrow(); Map methodConfig = ImmutableMap.of( "name", ImmutableList.of(ImmutableMap.of()), "retryPolicy", retryPolicy); Map rawServiceConfig = @@ -285,7 +285,7 @@ public void retryConfig_AllowPerAttemptRecvTimeoutZero() { .put("backoffMultiplier", 1.5D) .put("perAttemptRecvTimeout", "0s") .put("retryableStatusCodes", ImmutableList.of()) - .build(); + .buildOrThrow(); Map methodConfig = ImmutableMap.of( "name", ImmutableList.of(ImmutableMap.of()), "retryPolicy", retryPolicy); Map rawServiceConfig = diff --git a/netty/src/test/java/io/grpc/netty/MaxConnectionIdleManagerTest.java b/core/src/test/java/io/grpc/internal/MaxConnectionIdleManagerTest.java similarity index 62% rename from netty/src/test/java/io/grpc/netty/MaxConnectionIdleManagerTest.java rename to core/src/test/java/io/grpc/internal/MaxConnectionIdleManagerTest.java index d2ae98980d0..53566054a64 100644 --- a/netty/src/test/java/io/grpc/netty/MaxConnectionIdleManagerTest.java +++ b/core/src/test/java/io/grpc/internal/MaxConnectionIdleManagerTest.java @@ -14,17 +14,11 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import io.grpc.internal.FakeClock; -import io.grpc.netty.MaxConnectionIdleManager.Ticker; -import io.netty.channel.ChannelHandlerContext; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,7 +30,7 @@ @RunWith(JUnit4.class) public class MaxConnectionIdleManagerTest { private final FakeClock fakeClock = new FakeClock(); - private final Ticker ticker = new Ticker() { + private final MaxConnectionIdleManager.Ticker ticker = new MaxConnectionIdleManager.Ticker() { @Override public long nanoTime() { return fakeClock.getTicker().read(); @@ -44,7 +38,7 @@ public long nanoTime() { }; @Mock - private ChannelHandlerContext ctx; + private Runnable closure; @Before public void setUp() { @@ -54,21 +48,21 @@ public void setUp() { @Test public void maxIdleReached() { MaxConnectionIdleManager maxConnectionIdleManager = - spy(new TestMaxConnectionIdleManager(123L, ticker)); + new MaxConnectionIdleManager(123L, ticker); - maxConnectionIdleManager.start(ctx, fakeClock.getScheduledExecutorService()); + maxConnectionIdleManager.start(closure, fakeClock.getScheduledExecutorService()); maxConnectionIdleManager.onTransportIdle(); fakeClock.forwardNanos(123L); - verify(maxConnectionIdleManager).close(eq(ctx)); + verify(closure).run(); } @Test public void maxIdleNotReachedAndReached() { MaxConnectionIdleManager maxConnectionIdleManager = - spy(new TestMaxConnectionIdleManager(123L, ticker)); + new MaxConnectionIdleManager(123L, ticker); - maxConnectionIdleManager.start(ctx, fakeClock.getScheduledExecutorService()); + maxConnectionIdleManager.start(closure, fakeClock.getScheduledExecutorService()); maxConnectionIdleManager.onTransportIdle(); fakeClock.forwardNanos(100L); // max idle not reached @@ -79,35 +73,25 @@ public void maxIdleNotReachedAndReached() { maxConnectionIdleManager.onTransportActive(); fakeClock.forwardNanos(100L); - verify(maxConnectionIdleManager, never()).close(any(ChannelHandlerContext.class)); + verify(closure, never()).run(); // max idle reached maxConnectionIdleManager.onTransportIdle(); fakeClock.forwardNanos(123L); - verify(maxConnectionIdleManager).close(eq(ctx)); + verify(closure).run(); } @Test public void shutdownThenMaxIdleReached() { MaxConnectionIdleManager maxConnectionIdleManager = - spy(new TestMaxConnectionIdleManager(123L, ticker)); + new MaxConnectionIdleManager(123L, ticker); - maxConnectionIdleManager.start(ctx, fakeClock.getScheduledExecutorService()); + maxConnectionIdleManager.start(closure, fakeClock.getScheduledExecutorService()); maxConnectionIdleManager.onTransportIdle(); maxConnectionIdleManager.onTransportTermination(); fakeClock.forwardNanos(123L); - verify(maxConnectionIdleManager, never()).close(any(ChannelHandlerContext.class)); - } - - private static class TestMaxConnectionIdleManager extends MaxConnectionIdleManager { - TestMaxConnectionIdleManager(long maxConnectionIdleInNanos, Ticker ticker) { - super(maxConnectionIdleInNanos, ticker); - } - - @Override - void close(ChannelHandlerContext ctx) { - } + verify(closure, never()).run(); } } diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 5edf64ef85d..98ed0691458 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -49,6 +49,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.zip.GZIPOutputStream; import org.junit.Before; @@ -494,6 +495,8 @@ public void sizeEnforcingInputStream_markReset() throws IOException { } /** + * Verify stats were published through the tracer. + * * @param transportStats the transport level stats counters * @param clock the fakeClock to verify timestamp * @param sizes in the format {wire0, uncompressed0, wire1, uncompressed1, ...} @@ -507,7 +510,7 @@ private static void checkStats( for (int i = 0; i < count; i++) { assertEquals("inboundMessage(" + i + ")", tracer.nextInboundEvent()); assertEquals( - String.format("inboundMessageRead(%d, %d, -1)", i, sizes[i * 2]), + String.format(Locale.US, "inboundMessageRead(%d, %d, -1)", i, sizes[i * 2]), tracer.nextInboundEvent()); expectedWireSize += sizes[i * 2]; expectedUncompressedSize += sizes[i * 2 + 1]; diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 91698acced7..07f717cb81d 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -34,6 +34,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; +import java.util.Locale; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -372,6 +373,8 @@ private static void writeKnownLength(MessageFramer framer, byte[] bytes) { } /** + * Verify stats were published through the tracer. + * * @param sizes in the format {wire0, uncompressed0, wire1, uncompressed1, ...} */ private void checkStats(long... sizes) { @@ -382,7 +385,8 @@ private void checkStats(long... sizes) { for (int i = 0; i < count; i++) { assertEquals("outboundMessage(" + i + ")", tracer.nextOutboundEvent()); assertEquals( - String.format("outboundMessageSent(%d, %d, %d)", i, sizes[i * 2], sizes[i * 2 + 1]), + String.format( + Locale.US, "outboundMessageSent(%d, %d, %d)", i, sizes[i * 2], sizes[i * 2 + 1]), tracer.nextOutboundEvent()); expectedWireSize += sizes[i * 2]; expectedUncompressedSize += sizes[i * 2 + 1]; diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index 720341da0cb..8bace289584 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -50,6 +50,7 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import java.net.SocketAddress; import java.util.List; @@ -89,6 +90,8 @@ public void uncaughtException(Thread t, Throwable e) { @Captor private ArgumentCaptor pickerCaptor; @Captor + private ArgumentCaptor connectivityStateCaptor; + @Captor private ArgumentCaptor createArgsCaptor; @Captor private ArgumentCaptor stateListenerCaptor; @@ -121,7 +124,7 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); verify(mockHelper).createSubchannel(createArgsCaptor.capture()); @@ -139,7 +142,7 @@ public void pickAfterResolved() throws Exception { @Test public void requestConnectionPicker() throws Exception { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); InOrder inOrder = inOrder(mockHelper, mockSubchannel); @@ -164,7 +167,7 @@ public void requestConnectionPicker() throws Exception { @Test public void refreshNameResolutionAfterSubchannelConnectionBroken() { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); verify(mockHelper).createSubchannel(any(CreateSubchannelArgs.class)); @@ -196,11 +199,11 @@ public void refreshNameResolutionAfterSubchannelConnectionBroken() { @Test public void pickAfterResolvedAndUnchanged() throws Exception { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); verify(mockSubchannel).start(any(SubchannelStateListener.class)); verify(mockSubchannel).requestConnection(); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); verify(mockSubchannel).updateAddresses(eq(servers)); verifyNoMoreInteractions(mockSubchannel); @@ -223,7 +226,7 @@ public void pickAfterResolvedAndChanged() throws Exception { InOrder inOrder = inOrder(mockHelper, mockSubchannel); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); verify(mockSubchannel).start(any(SubchannelStateListener.class)); @@ -233,7 +236,7 @@ public void pickAfterResolvedAndChanged() throws Exception { verify(mockSubchannel).requestConnection(); assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); inOrder.verify(mockSubchannel).updateAddresses(eq(newServers)); @@ -245,7 +248,7 @@ public void pickAfterResolvedAndChanged() throws Exception { public void pickAfterStateChangeAfterResolution() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); CreateSubchannelArgs args = createArgsCaptor.getValue(); @@ -288,6 +291,21 @@ public void nameResolutionError() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void nameResolutionError_emptyAddressList() throws Exception { + servers.clear(); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + verify(mockHelper).updateBalancingState(connectivityStateCaptor.capture(), + pickerCaptor.capture()); + PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); + assertThat(pickResult.getSubchannel()).isNull(); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).contains("returned no usable address"); + verify(mockSubchannel, never()).requestConnection(); + verifyNoMoreInteractions(mockHelper); + } + @Test public void nameResolutionSuccessAfterError() throws Exception { InOrder inOrder = inOrder(mockHelper); @@ -297,7 +315,7 @@ public void nameResolutionSuccessAfterError() throws Exception { .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); verify(mockSubchannel, never()).requestConnection(); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); CreateSubchannelArgs args = createArgsCaptor.getValue(); @@ -318,7 +336,7 @@ public void nameResolutionSuccessAfterError() throws Exception { @Test public void nameResolutionErrorWithStateChanges() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); verify(mockSubchannel).start(stateListenerCaptor.capture()); @@ -358,7 +376,7 @@ public void requestConnection() { loadBalancer.requestConnection(); verify(mockSubchannel, never()).requestConnection(); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); verify(mockSubchannel).requestConnection(); diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java index 0947f65da12..2bc5a8a3760 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java @@ -19,6 +19,7 @@ import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; @@ -32,9 +33,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.InvalidMarkException; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,9 +45,6 @@ public class ReadableBuffersTest { private static final byte[] MSG_BYTES = "hello".getBytes(UTF_8); - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void empty_returnsEmptyBuffer() { ReadableBuffer buffer = ReadableBuffers.empty(); @@ -216,8 +212,7 @@ public void bufferInputStream_markDiscardedAfterDetached() throws IOException { InputStream inputStream = ReadableBuffers.openStream(buffer, true); inputStream.mark(5); ((Detachable) inputStream).detach(); - thrown.expect(InvalidMarkException.class); - inputStream.reset(); + assertThrows(InvalidMarkException.class, () -> inputStream.reset()); } @Test diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 025c4f23e80..7597c10d354 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -43,6 +43,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.ClientStreamTracer; import io.grpc.Codec; @@ -60,6 +61,8 @@ import io.grpc.internal.StreamListener.MessageProducer; import java.io.InputStream; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.concurrent.Executor; @@ -268,10 +271,14 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg3"); retriableStream.request(456); - inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -279,6 +286,7 @@ public Void answer(InvocationOnMock in) { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); + inOrder.verify(retriableStreamRecorder).newSubstream(1); assertEquals(1, fakeClock.numPendingTasks()); // send more messages during backoff @@ -290,7 +298,6 @@ public Void answer(InvocationOnMock in) { assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); - inOrder.verify(retriableStreamRecorder).newSubstream(1); inOrder.verify(mockStream2).setAuthority(AUTHORITY); inOrder.verify(mockStream2).setCompressor(COMPRESSOR); inOrder.verify(mockStream2).setDecompressorRegistry(DECOMPRESSOR_REGISTRY); @@ -304,12 +311,19 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); - inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); @@ -319,12 +333,16 @@ public Void answer(InvocationOnMock in) { // mockStream1 is closed so it is not in the drainedSubstreams verifyNoMoreInteractions(mockStream1); - inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround // retry2 doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(2); sublistenerCaptor2.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + inOrder.verify(retriableStreamRecorder).newSubstream(2); assertEquals(1, fakeClock.numPendingTasks()); // send more messages during backoff @@ -339,7 +357,6 @@ public Void answer(InvocationOnMock in) { assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); - inOrder.verify(retriableStreamRecorder).newSubstream(2); inOrder.verify(mockStream3).setAuthority(AUTHORITY); inOrder.verify(mockStream3).setCompressor(COMPRESSOR); inOrder.verify(mockStream3).setDecompressorRegistry(DECOMPRESSOR_REGISTRY); @@ -353,12 +370,19 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - inOrder.verify(mockStream3, times(7)).writeMessage(any(InputStream.class)); + for (int i = 0; i < 7; i++) { + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround + } inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); @@ -771,6 +795,8 @@ public boolean isReady() { public void cancelWhileDraining() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); ClientStream mockStream1 = mock(ClientStream.class); ClientStream mockStream2 = mock( @@ -797,7 +823,7 @@ public void request(int numMessages) { Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); - inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(mockStream2).request(3); inOrder.verify(retriableStreamRecorder).postCommit(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); @@ -805,6 +831,7 @@ public void request(int numMessages) { assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); assertThat(statusCaptor.getValue().getDescription()) .isEqualTo("Stream thrown away because RetriableStream committed"); + sublistenerCaptor2.getValue().closed(Status.CANCELLED, PROCESSED, new Metadata()); verify(masterListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); @@ -827,6 +854,8 @@ public void start(ClientStreamListener listener) { Status.CANCELLED.withDescription("cancelled while retry start")); } })); + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); @@ -839,13 +868,14 @@ public void start(ClientStreamListener listener) { Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); - inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(retriableStreamRecorder).postCommit(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); inOrder.verify(mockStream2).cancel(statusCaptor.capture()); assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); assertThat(statusCaptor.getValue().getDescription()) .isEqualTo("Stream thrown away because RetriableStream committed"); + sublistenerCaptor2.getValue().closed(Status.CANCELLED, PROCESSED, new Metadata()); verify(masterListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); @@ -970,6 +1000,78 @@ public void messageAvailable() { verify(masterListener).messagesAvailable(messageProducer); } + @Test + public void inboundMessagesClosedOnCancel() throws Exception { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + retriableStream.request(1); + retriableStream.cancel(Status.CANCELLED.withDescription("on purpose")); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener = sublistenerCaptor1.getValue(); + listener.headersRead(new Metadata()); + InputStream is = mock(InputStream.class); + listener.messagesAvailable(new FakeMessageProducer(is)); + verify(masterListener, never()).messagesAvailable(any(MessageProducer.class)); + verify(is).close(); + } + + @Test + public void notAdd0PrevRetryAttemptsToRespHeaders() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor.capture()); + + sublistenerCaptor.getValue().headersRead(new Metadata()); + + ArgumentCaptor metadataCaptor = + ArgumentCaptor.forClass(Metadata.class); + verify(masterListener).headersRead(metadataCaptor.capture()); + assertEquals(null, metadataCaptor.getValue().get(GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + @Test + public void addPrevRetryAttemptsToRespHeaders() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry + ClientStream mockStream2 = mock(ClientStream.class); + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream2).start(sublistenerCaptor2.capture()); + Metadata headers = new Metadata(); + headers.put(GRPC_PREVIOUS_RPC_ATTEMPTS, "3"); + sublistenerCaptor2.getValue().headersRead(headers); + + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(masterListener).headersRead(metadataCaptor.capture()); + Iterable iterable = metadataCaptor.getValue().getAll(GRPC_PREVIOUS_RPC_ATTEMPTS); + assertEquals(1, Iterables.size(iterable)); + assertEquals("1", iterable.iterator().next()); + } + @Test public void closedWhileDraining() { ClientStream mockStream1 = mock(ClientStream.class); @@ -1100,7 +1202,6 @@ public void perRpcBufferLimitExceededDuringBackoff() { sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - // bufferSizeTracer.outboundWireSize() quits immediately while backoff b/c substream1 is closed assertEquals(1, fakeClock.numPendingTasks()); bufferSizeTracer.outboundWireSize(2); verify(retriableStreamRecorder, never()).postCommit(); @@ -1111,8 +1212,6 @@ public void perRpcBufferLimitExceededDuringBackoff() { // bufferLimitExceeded bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); - verify(retriableStreamRecorder, never()).postCommit(); - bufferSizeTracer.outboundWireSize(2); verify(retriableStreamRecorder).postCommit(); verifyNoMoreInteractions(mockStream1); @@ -1711,12 +1810,12 @@ public void transparentRetry_unlimitedTimesOnMiscarried() { assertEquals(0, fakeClock.numPendingTasks()); ArgumentCaptor sublistenerCaptor = sublistenerCaptor3; - for (int i = 0; i < 9999; i++) { + for (int i = 0; i < 999; i++) { ClientStream mockStream = mock(ClientStream.class); doReturn(mockStream).when(retriableStreamRecorder).newSubstream(0); sublistenerCaptor.getValue() .closed(Status.fromCode(NON_RETRIABLE_STATUS_CODE), MISCARRIED, new Metadata()); - if (i == 9998) { + if (i == 998) { verify(retriableStreamRecorder).postCommit(); verify(masterListener) .closed(any(Status.class), any(RpcProgress.class), any(Metadata.class)); @@ -1958,10 +2057,14 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg3"); hedgingStream.request(456); - inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -1984,10 +2087,14 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); @@ -1998,9 +2105,13 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg2 after hedge2 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); @@ -2022,12 +2133,19 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream3).isReady(); @@ -2036,8 +2154,11 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround // hedge3 receives nonFatalStatus sublistenerCaptor3.getValue().closed( @@ -2047,7 +2168,9 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 fails"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).flush(); // Memory leak workaround // the hedge mockStream4 starts fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2067,12 +2190,19 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); - inOrder.verify(mockStream4, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream4).flush(); // Memory leak workaround + inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream4).flush(); // Memory leak workaround inOrder.verify(mockStream4).request(345); inOrder.verify(mockStream4, times(2)).flush(); inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream4).flush(); // Memory leak workaround inOrder.verify(mockStream4).request(456); - inOrder.verify(mockStream4, times(4)).writeMessage(any(InputStream.class)); + for (int i = 0; i < 4; i++) { + inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream4).flush(); // Memory leak workaround + } inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream4).isReady(); @@ -2096,6 +2226,10 @@ public Void answer(InvocationOnMock in) { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue().closed( + Status.CANCELLED, PROCESSED, new Metadata()); + sublistenerCaptor4.getValue().closed( + Status.CANCELLED, PROCESSED, new Metadata()); inOrder.verify(masterListener).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); inOrder.verifyNoMoreInteractions(); @@ -2103,7 +2237,8 @@ public Void answer(InvocationOnMock in) { insight = new InsightBuilder(); hedgingStream.appendTimeoutInsight(insight); assertThat(insight.toString()).isEqualTo( - "[closed=[UNAVAILABLE, INTERNAL], committed=[remote_addr=2.2.2.2:81]]"); + "[closed=[UNAVAILABLE, INTERNAL, CANCELLED, CANCELLED], " + + "committed=[remote_addr=2.2.2.2:81]]"); } @Test @@ -2190,6 +2325,7 @@ public void hedging_maxAttempts() { hedgingStream.sendMessage("msg1 after commit"); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); Metadata heders = new Metadata(); @@ -2369,6 +2505,7 @@ public void hedging_pushback_positive() { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor3.getValue().closed(Status.CANCELLED, PROCESSED, metadata); inOrder.verify(masterListener).closed(fatal, PROCESSED, metadata); inOrder.verifyNoMoreInteractions(); } @@ -2411,6 +2548,8 @@ public void hedging_cancelled() { assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue().closed(Status.CANCELLED, PROCESSED, new Metadata()); + sublistenerCaptor2.getValue().closed(Status.CANCELLED, PROCESSED, new Metadata()); inOrder.verify(masterListener).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); inOrder.verifyNoMoreInteractions(); @@ -2547,6 +2686,8 @@ public void hedging_transparentRetry() { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue().closed(Status.CANCELLED, PROCESSED, metadata); + sublistenerCaptor4.getValue().closed(Status.CANCELLED, PROCESSED, metadata); verify(masterListener).closed(status, REFUSED, metadata); } @@ -2587,6 +2728,9 @@ public void hedging_transparentRetryNotAllowed() { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue() + .closed(Status.CANCELLED, REFUSED, new Metadata()); + //master listener close should wait until all substreams are closed verify(masterListener).closed(status, REFUSED, metadata); } @@ -2665,4 +2809,22 @@ private interface RetriableStreamRecorder { Status prestart(); } + + private static final class FakeMessageProducer implements MessageProducer { + private final Iterator iterator; + + public FakeMessageProducer(InputStream... iss) { + this.iterator = Arrays.asList(iss).iterator(); + } + + @Override + @Nullable + public InputStream next() { + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } + } } diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 9c25f474804..4818c6c2017 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -33,6 +34,7 @@ import static org.mockito.Mockito.when; import com.google.common.io.CharStreams; +import io.grpc.Attributes; import io.grpc.CompressorRegistry; import io.grpc.Context; import io.grpc.DecompressorRegistry; @@ -41,6 +43,7 @@ import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.SecurityLevel; import io.grpc.ServerCall; import io.grpc.Status; import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl; @@ -187,7 +190,6 @@ public void sendMessage() { call.sendMessage(1234L); verify(stream).writeMessage(isA(InputStream.class)); - verify(stream).flush(); } @Test @@ -352,6 +354,23 @@ public void getNullAuthority() { verify(stream).getAuthority(); } + @Test + public void getSecurityLevel() { + Attributes attributes = Attributes.newBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.INTEGRITY).build(); + when(stream.getAttributes()).thenReturn(attributes); + assertEquals(SecurityLevel.INTEGRITY, call.getSecurityLevel()); + verify(stream).getAttributes(); + } + + @Test + public void getNullSecurityLevel() { + when(stream.getAttributes()).thenReturn(null); + assertEquals(SecurityLevel.NONE, call.getSecurityLevel()); + verify(stream).getAttributes(); + } + + @Test public void setMessageCompression() { call.setMessageCompression(true); @@ -406,7 +425,7 @@ public void streamListener_closedCancelled() { verify(callListener).onCancel(); assertTrue(context.isCancelled()); - assertNull(context.cancellationCause()); + assertNotNull(context.cancellationCause()); } @Test diff --git a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java index ad8cf41598a..ce601c5f837 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java @@ -18,11 +18,20 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import io.grpc.InternalGlobalInterceptors; import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerStreamTracer; +import io.grpc.StaticTestingClassLoader; import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.regex.Pattern; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -38,6 +47,21 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata throw new UnsupportedOperationException(); } }; + private static final ServerInterceptor DUMMY_TEST_INTERCEPTOR = + new ServerInterceptor() { + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + throw new UnsupportedOperationException(); + } + }; + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader( + getClass().getClassLoader(), + Pattern.compile( + "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + + "io\\.grpc\\.internal\\.[^.]+")); private ServerImplBuilder builder; @@ -101,4 +125,77 @@ public void getTracerFactories_disableBoth() { List factories = builder.getTracerFactories(); assertThat(factories).containsExactly(DUMMY_USER_TRACER); } + + @Test + public void getTracerFactories_callsGet() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + public static final class StaticTestingClassLoaderCallsGet implements Runnable { + @Override + public void run() { + ServerImplBuilder builder = + new ServerImplBuilder( + streamTracerFactories -> { + throw new UnsupportedOperationException(); + }); + assertThat(builder.getTracerFactories()).hasSize(2); + assertThat(builder.interceptors).hasSize(0); + try { + InternalGlobalInterceptors.setInterceptorsTracers( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + fail("exception expected"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains("Set cannot be called after any get call"); + } + } + } + + @Test + public void getTracerFactories_callsSet() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + public static final class StaticTestingClassLoaderCallsSet implements Runnable { + @Override + public void run() { + InternalGlobalInterceptors.setInterceptorsTracers( + Collections.emptyList(), + Arrays.asList(DUMMY_TEST_INTERCEPTOR), + Arrays.asList(DUMMY_USER_TRACER)); + ServerImplBuilder builder = + new ServerImplBuilder( + streamTracerFactories -> { + throw new UnsupportedOperationException(); + }); + assertThat(builder.getTracerFactories()).containsExactly(DUMMY_USER_TRACER); + assertThat(builder.interceptors).containsExactly(DUMMY_TEST_INTERCEPTOR); + } + } + + @Test + public void getEffectiveInterceptors_setEmpty() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassLoaderCallsSetEmpty.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderCallsSetEmpty implements Runnable { + + @Override + public void run() { + InternalGlobalInterceptors.setInterceptorsTracers( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + ServerImplBuilder builder = + new ServerImplBuilder( + streamTracerFactories -> { + throw new UnsupportedOperationException(); + }); + assertThat(builder.getTracerFactories()).isEmpty(); + assertThat(builder.interceptors).isEmpty(); + } + } } diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 0f5c510f97c..d3c07787b60 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -130,7 +130,7 @@ public class ServerImplTest { private static final Context.Key SERVER_TRACER_ADDED_KEY = Context.key("tracer-added"); private static final Context.CancellableContext SERVER_CONTEXT = Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation(); - private static final FakeClock.TaskFilter CONTEXT_CLOSER_TASK_FITLER = + private static final FakeClock.TaskFilter CONTEXT_CLOSER_TASK_FILTER = new FakeClock.TaskFilter() { @Override public boolean shouldAccept(Runnable runnable) { @@ -1085,7 +1085,7 @@ private void checkContext() { assertTrue(onHalfCloseCalled.get()); streamListener.closed(Status.CANCELLED); - assertEquals(1, executor.numPendingTasks(CONTEXT_CLOSER_TASK_FITLER)); + assertEquals(1, executor.numPendingTasks(CONTEXT_CLOSER_TASK_FILTER)); assertEquals(2, executor.runDueTasks()); assertTrue(onCancelCalled.get()); @@ -1179,10 +1179,11 @@ public void testStreamClose_clientCancelTriggersImmediateCancellation() throws E assertFalse(callReference.get().isCancelled()); assertFalse(context.get().isCancelled()); streamListener.closed(Status.CANCELLED); - assertEquals(1, executor.numPendingTasks(CONTEXT_CLOSER_TASK_FITLER)); + assertEquals(1, executor.numPendingTasks(CONTEXT_CLOSER_TASK_FILTER)); assertEquals(2, executor.runDueTasks()); assertTrue(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); + assertThat(context.get().cancellationCause()).isNotNull(); assertTrue(contextCancelled.get()); } @@ -1208,6 +1209,7 @@ public void testStreamClose_clientOkTriggersDelayedCancellation() throws Excepti assertEquals(1, executor.runDueTasks()); assertFalse(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); + assertThat(context.get().cancellationCause()).isNull(); assertTrue(contextCancelled.get()); } @@ -1228,6 +1230,7 @@ public void testStreamClose_deadlineExceededTriggersImmediateCancellation() thro assertTrue(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); + assertThat(context.get().cancellationCause()).isNotNull(); assertTrue(contextCancelled.get()); } diff --git a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java index 16c6f3bf302..24317e80692 100644 --- a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java +++ b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java @@ -186,13 +186,13 @@ public void run() { } assertEquals(numExpectedTasks, timer.numPendingTasks()); - ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); } @Before public void setUp() throws Exception { - when(mockLoadBalancer.canHandleEmptyAddressListFromNameResolution()).thenCallRealMethod(); + mockLoadBalancer.setAcceptAddresses(true); LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); expectedUri = new URI(TARGET); when(mockTransportFactory.getScheduledExecutorService()) @@ -236,7 +236,7 @@ public void cleanUp() { public void emptyAddresses_validConfig_firstResolution_lbNeedsAddress() throws Exception { FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.emptyList()) + .setServers(Collections.emptyList()) .build(); channelBuilder.nameResolverFactory(nameResolverFactory); @@ -268,7 +268,7 @@ public void emptyAddresses_validConfig_2ndResolution_lbNeedsAddress() throws Exc ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("12"); @@ -280,29 +280,23 @@ public void emptyAddresses_validConfig_2ndResolution_lbNeedsAddress() throws Exc nameResolverFactory.servers.clear(); // 2nd resolution + mockLoadBalancer.setAcceptAddresses(false); nameResolverFactory.allResolved(); // 2nd service config without addresses - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(mockLoadBalancer, never()).handleResolvedAddresses(any(ResolvedAddresses.class)); - verify(mockLoadBalancer).handleNameResolutionError(statusCaptor.capture()); - assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(statusCaptor.getValue().getDescription()) - .contains("NameResolver returned no usable address."); - assertThat(channel.getState(true)).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - assertWithMessage("Empty address should schedule NameResolver retry") - .that(getNameResolverRefresh()) - .isNotNull(); + verify(mockLoadBalancer).acceptResolvedAddresses(any(ResolvedAddresses.class)); + + // A resolution retry has been scheduled + assertEquals(1, timer.numPendingTasks(NAME_RESOLVER_REFRESH_TASK_FILTER)); } @Test public void emptyAddresses_validConfig_lbDoesNotNeedAddress() throws Exception { FakeNameResolverFactory nameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.emptyList()) + .setServers(Collections.emptyList()) .build(); channelBuilder.nameResolverFactory(nameResolverFactory); - when(mockLoadBalancer.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); Map rawServiceConfig = parseJson("{\"loadBalancingConfig\": [{\"mock_lb\": {\"check\": \"val\"}}]}"); @@ -312,11 +306,11 @@ public void emptyAddresses_validConfig_lbDoesNotNeedAddress() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).isEmpty(); - assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("val");; + assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("val"); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); assertThat(channel.getState(false)).isNotEqualTo(ConnectivityState.TRANSIENT_FAILURE); @@ -338,7 +332,7 @@ public void validConfig_lbDoesNotNeedAddress() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("foo"); @@ -360,7 +354,7 @@ public void noConfig_noDefaultConfig() { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isNull(); @@ -386,7 +380,7 @@ public void noConfig_usingDefaultConfig() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("foo"); @@ -431,7 +425,7 @@ public void invalidConfig_withDefaultConfig() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); @@ -462,7 +456,7 @@ public void invalidConfig_2ndResolution() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("1st raw config"); @@ -477,7 +471,7 @@ public void invalidConfig_2ndResolution() throws Exception { nextLbPolicyConfigError.set(Status.UNKNOWN); nameResolverFactory.allResolved(); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses newResolvedAddress = resultCaptor.getValue(); // should use previous service config because new service config is invalid. assertThat(newResolvedAddress.getLoadBalancingPolicyConfig()).isEqualTo("1st raw config"); @@ -510,7 +504,7 @@ public void validConfig_thenNoConfig_withDefaultConfig() throws Exception { ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); - verify(mockLoadBalancer).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses resolvedAddresses = resultCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).containsExactly(addressGroup); // should use previous service config because new resolution result is no config. @@ -522,7 +516,7 @@ public void validConfig_thenNoConfig_withDefaultConfig() throws Exception { nameResolverFactory.nextRawServiceConfig.set(null); nameResolverFactory.allResolved(); - verify(mockLoadBalancer, times(2)).handleResolvedAddresses(resultCaptor.capture()); + verify(mockLoadBalancer, times(2)).acceptResolvedAddresses(resultCaptor.capture()); ResolvedAddresses newResolvedAddress = resultCaptor.getValue(); assertThat(newResolvedAddress.getLoadBalancingPolicyConfig()).isEqualTo("mate"); assertThat(newResolvedAddress.getAttributes().get(InternalConfigSelector.KEY)) @@ -658,6 +652,8 @@ private FakeClock.ScheduledTask getNameResolverRefresh() { private static class FakeLoadBalancer extends LoadBalancer { + private boolean acceptAddresses = true; + @Nullable private Helper helper; @@ -665,6 +661,15 @@ public void setHelper(Helper helper) { this.helper = helper; } + @Override + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return acceptAddresses; + } + + public void setAcceptAddresses(boolean acceptAddresses) { + this.acceptAddresses = acceptAddresses; + } + @Override public void handleNameResolutionError(final Status error) { helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, diff --git a/core/src/test/java/io/grpc/util/ForwardingLoadBalancerTest.java b/core/src/test/java/io/grpc/util/ForwardingLoadBalancerTest.java index 01226466845..f9b53400cea 100644 --- a/core/src/test/java/io/grpc/util/ForwardingLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/ForwardingLoadBalancerTest.java @@ -20,8 +20,8 @@ import io.grpc.ForwardingTestUtil; import io.grpc.LoadBalancer; -import java.lang.reflect.Method; -import java.util.Collections; +import io.grpc.LoadBalancer.ResolvedAddresses; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -44,6 +44,7 @@ public void allMethodsForwarded() throws Exception { LoadBalancer.class, mockDelegate, new TestBalancer(), - Collections.emptyList()); + Arrays.asList( + LoadBalancer.class.getMethod("acceptResolvedAddresses", ResolvedAddresses.class))); } } diff --git a/core/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java b/core/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java index 2131cd725b6..6e89176e9c9 100644 --- a/core/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java @@ -479,7 +479,7 @@ public int hashCode() { @Test public void transientFailureOnInitialResolutionError() { gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); - ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); assertThat(picker.pickSubchannel(mock(PickSubchannelArgs.class)).getStatus().getCode()) diff --git a/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java b/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java new file mode 100644 index 00000000000..5a27e6f176f --- /dev/null +++ b/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java @@ -0,0 +1,141 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.grpc.InternalServiceProviders; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.SynchronizationContext; +import io.grpc.internal.JsonParser; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; +import java.io.IOException; +import java.lang.Thread.UncaughtExceptionHandler; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link OutlierDetectionLoadBalancerProvider}. + */ +@RunWith(JUnit4.class) +public class OutlierDetectionLoadBalancerProviderTest { + + private final SynchronizationContext syncContext = new SynchronizationContext( + new UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final OutlierDetectionLoadBalancerProvider provider + = new OutlierDetectionLoadBalancerProvider(); + + @Test + public void provided() { + for (LoadBalancerProvider current : InternalServiceProviders.getCandidatesViaServiceLoader( + LoadBalancerProvider.class, getClass().getClassLoader())) { + if (current instanceof OutlierDetectionLoadBalancerProvider) { + return; + } + } + fail("OutlierDetectionLoadBalancerProvider not registered"); + } + + @Test + public void providesLoadBalancer() { + Helper helper = mock(Helper.class); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn(mock(ScheduledExecutorService.class)); + assertThat(provider.newLoadBalancer(helper)) + .isInstanceOf(OutlierDetectionLoadBalancer.class); + } + + @Test + public void parseLoadBalancingConfig_defaults() throws IOException { + String lbConfig = + "{ \"successRateEjection\" : {}, " + + "\"failurePercentageEjection\" : {}, " + + "\"childPolicy\" : [{\"round_robin\" : {}}]}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + OutlierDetectionLoadBalancerConfig config + = (OutlierDetectionLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.successRateEjection).isNotNull(); + assertThat(config.failurePercentageEjection).isNotNull(); + assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + } + + @Test + public void parseLoadBalancingConfig_valuesSet() throws IOException { + String lbConfig = + "{\"interval\" : \"100s\"," + + " \"baseEjectionTime\" : \"100s\"," + + " \"maxEjectionTime\" : \"100s\"," + + " \"maxEjectionPercentage\" : 100," + + " \"successRateEjection\" : {" + + " \"stdevFactor\" : 100," + + " \"enforcementPercentage\" : 100," + + " \"minimumHosts\" : 100," + + " \"requestVolume\" : 100" + + " }," + + " \"failurePercentageEjection\" : {" + + " \"threshold\" : 100," + + " \"enforcementPercentage\" : 100," + + " \"minimumHosts\" : 100," + + " \"requestVolume\" : 100" + + " }," + + "\"childPolicy\" : [{\"round_robin\" : {}}]}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + OutlierDetectionLoadBalancerConfig config + = (OutlierDetectionLoadBalancerConfig) configOrError.getConfig(); + + assertThat(config.intervalNanos).isEqualTo(100_000_000_000L); + assertThat(config.baseEjectionTimeNanos).isEqualTo(100_000_000_000L); + assertThat(config.maxEjectionTimeNanos).isEqualTo(100_000_000_000L); + assertThat(config.maxEjectionPercent).isEqualTo(100); + + assertThat(config.successRateEjection).isNotNull(); + assertThat(config.successRateEjection.stdevFactor).isEqualTo(100); + assertThat(config.successRateEjection.enforcementPercentage).isEqualTo(100); + assertThat(config.successRateEjection.minimumHosts).isEqualTo(100); + assertThat(config.successRateEjection.requestVolume).isEqualTo(100); + + assertThat(config.failurePercentageEjection).isNotNull(); + assertThat(config.failurePercentageEjection.threshold).isEqualTo(100); + assertThat(config.failurePercentageEjection.enforcementPercentage).isEqualTo(100); + assertThat(config.failurePercentageEjection.minimumHosts).isEqualTo(100); + assertThat(config.failurePercentageEjection.requestVolume).isEqualTo(100); + + assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + } + + @SuppressWarnings("unchecked") + private static Map parseJsonObject(String json) throws IOException { + return (Map) JsonParser.parse(json); + } +} diff --git a/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java new file mode 100644 index 00000000000..ccf3d40cdb6 --- /dev/null +++ b/core/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -0,0 +1,1163 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.ConnectivityState.READY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import io.grpc.ClientStreamTracer; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.FakeClock.ScheduledTask; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.TestUtils.StandardLoadBalancerProvider; +import io.grpc.util.OutlierDetectionLoadBalancer.AddressTracker; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.FailurePercentageEjection; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.SuccessRateEjection; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionSubchannel; +import io.grpc.util.OutlierDetectionLoadBalancer.SuccessRateOutlierEjectionAlgorithm; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +/** + * Unit tests for {@link OutlierDetectionLoadBalancer}. + */ +@RunWith(JUnit4.class) +public class OutlierDetectionLoadBalancerTest { + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock + private LoadBalancer mockChildLb; + @Mock + private Helper mockHelper; + @Mock + private SocketAddress mockSocketAddress; + + @Captor + private ArgumentCaptor connectivityStateCaptor; + @Captor + private ArgumentCaptor errorPickerCaptor; + @Captor + private ArgumentCaptor pickerCaptor; + @Captor + private ArgumentCaptor stateCaptor; + + private final LoadBalancerProvider mockChildLbProvider = new StandardLoadBalancerProvider( + "foo_policy") { + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return mockChildLb; + } + }; + private final LoadBalancerProvider fakeLbProvider = new StandardLoadBalancerProvider( + "fake_policy") { + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new FakeLoadBalancer(helper); + } + }; + private final LoadBalancerProvider roundRobinLbProvider = new StandardLoadBalancerProvider( + "round_robin") { + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new RoundRobinLoadBalancer(helper); + } + }; + + private final FakeClock fakeClock = new FakeClock(); + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private OutlierDetectionLoadBalancer loadBalancer; + + private final List servers = Lists.newArrayList(); + private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + private final Map subchannelStateListeners + = Maps.newLinkedHashMap(); + + private Subchannel subchannel1; + private Subchannel subchannel2; + private Subchannel subchannel3; + private Subchannel subchannel4; + private Subchannel subchannel5; + + @Before + public void setUp() { + for (int i = 0; i < 5; i++) { + SocketAddress addr = new FakeSocketAddress("server" + i); + EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); + servers.add(eag); + Subchannel sc = mock(Subchannel.class); + subchannels.put(Arrays.asList(eag), sc); + } + + Iterator subchannelIterator = subchannels.values().iterator(); + subchannel1 = subchannelIterator.next(); + subchannel2 = subchannelIterator.next(); + subchannel3 = subchannelIterator.next(); + subchannel4 = subchannelIterator.next(); + subchannel5 = subchannelIterator.next(); + + when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); + when(mockHelper.getScheduledExecutorService()).thenReturn( + fakeClock.getScheduledExecutorService()); + when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))).then( + new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put(subchannel, + (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); + + loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTimeProvider()); + } + + @Test + public void handleNameResolutionError_noChildLb() { + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + + verify(mockHelper).updateBalancingState(connectivityStateCaptor.capture(), + errorPickerCaptor.capture()); + assertThat(connectivityStateCaptor.getValue()).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); + } + + @Test + public void handleNameResolutionError_withChildLb() { + loadBalancer.acceptResolvedAddresses(buildResolvedAddress( + new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(), + new EquivalentAddressGroup(mockSocketAddress))); + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + + verify(mockChildLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + } + + /** + * {@code shutdown()} is simply delegated. + */ + @Test + public void shutdown() { + loadBalancer.acceptResolvedAddresses(buildResolvedAddress( + new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(), + new EquivalentAddressGroup(mockSocketAddress))); + loadBalancer.shutdown(); + verify(mockChildLb).shutdown(); + } + + /** + * Base case for accepting new resolved addresses. + */ + @Test + public void acceptResolvedAddresses() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, + new EquivalentAddressGroup(mockSocketAddress)); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + // Handling of resolved addresses is delegated + verify(mockChildLb).handleResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) + .build()); + + // There is a single pending task to run the outlier detection algorithm + assertThat(fakeClock.getPendingTasks()).hasSize(1); + + // The task is scheduled to run after a delay set in the config. + ScheduledTask task = fakeClock.getPendingTasks().iterator().next(); + assertThat(task.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(config.intervalNanos); + } + + /** + * Outlier detection first enabled, then removed. + */ + @Test + public void acceptResolvedAddresses_outlierDetectionDisabled() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, + new EquivalentAddressGroup(mockSocketAddress)); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + fakeClock.forwardTime(15, TimeUnit.SECONDS); + + // There is a single pending task to run the outlier detection algorithm + assertThat(fakeClock.getPendingTasks()).hasSize(1); + + config = new OutlierDetectionLoadBalancerConfig.Builder().setChildPolicy( + new PolicySelection(mockChildLbProvider, null)).build(); + loadBalancer.acceptResolvedAddresses( + buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress))); + + // Pending task should be gone since OD is disabled. + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + } + + /** + * Tests different scenarios when the timer interval in the config changes. + */ + @Test + public void acceptResolvedAddresses_intervalUpdate() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, + new EquivalentAddressGroup(mockSocketAddress)); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + // Config update has doubled the interval + config = new OutlierDetectionLoadBalancerConfig.Builder() + .setIntervalNanos(config.intervalNanos * 2) + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses( + buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress))); + + // If the timer has not run yet the task is just rescheduled to run after the new delay. + assertThat(fakeClock.getPendingTasks()).hasSize(1); + ScheduledTask task = fakeClock.getPendingTasks().iterator().next(); + assertThat(task.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(config.intervalNanos); + assertThat(task.dueTimeNanos).isEqualTo(config.intervalNanos); + + // The new interval time has passed. The next task due time should have been pushed back another + // interval. + forwardTime(config); + assertThat(fakeClock.getPendingTasks()).hasSize(1); + task = fakeClock.getPendingTasks().iterator().next(); + assertThat(task.dueTimeNanos).isEqualTo(config.intervalNanos + config.intervalNanos + 1); + + // Some time passes and a second update comes down, but now the timer has had a chance to run, + // the new delay to timer start should consider when the timer last ran and if the interval is + // not changing in the config, the next task due time should remain unchanged. + fakeClock.forwardTime(4, TimeUnit.SECONDS); + task = fakeClock.getPendingTasks().iterator().next(); + loadBalancer.acceptResolvedAddresses( + buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress))); + assertThat(task.dueTimeNanos).isEqualTo(config.intervalNanos + config.intervalNanos + 1); + } + + /** + * Confirm basic picking works by delegating to round_robin. + */ + @Test + public void delegatePick() throws Exception { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); + + // Make one of the subchannels READY. + final Subchannel readySubchannel = subchannels.values().iterator().next(); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(3)).updateBalancingState(stateCaptor.capture(), + pickerCaptor.capture()); + + // Make sure that we can pick the single READY subchannel. + SubchannelPicker picker = pickerCaptor.getAllValues().get(2); + PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate()).isEqualTo( + readySubchannel); + } + + /** + * The success rate algorithm leaves a healthy set of addresses alone. + */ + @Test + public void successRateNoOutliers() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder().setMinimumHosts(3).setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // No outliers, no ejections. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The success rate algorithm ejects the outlier. + */ + @Test + public void successRateOneOutlier() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + } + + /** + * The success rate algorithm ejects the outlier, but then the config changes so that similar + * behavior no longer gets ejected. + */ + @Test + public void successRateOneOutlier_configChange() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + + // New config sets enforcement percentage to 0. + config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setEnforcementPercentage(0).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // Since we brought enforcement percentage to 0, no additional ejection should have happened. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + } + + /** + * The success rate algorithm ejects the outlier but after some time it should get unejected + * if it stops being an outlier.. + */ + @Test + public void successRateOneOutlier_unejected() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + fakeClock.forwardTime(config.intervalNanos + 1, TimeUnit.NANOSECONDS); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + + // Now we produce more load, but the subchannel start working and is no longer an outlier. + generateLoad(ImmutableMap.of(), 8); + + // Move forward in time to a point where the detection timer has fired. + fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); + + // No subchannels should remain ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The success rate algorithm ignores addresses without enough volume. + */ + @Test + public void successRateOneOutlier_notEnoughVolume() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(20).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + // We produce an outlier, but don't give it enough calls to reach the minimum volume. + generateLoad( + ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), + ImmutableMap.of(subchannel1, 19), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The address should not have been ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The success rate algorithm does not apply if we don't have enough addresses that have the + * required volume. + */ + @Test + public void successRateOneOutlier_notEnoughAddressesWithVolume() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(5) + .setRequestVolume(20).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad( + ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), + // subchannel2 has only 19 calls which results in success rate not triggering. + ImmutableMap.of(subchannel2, 19), + 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // No subchannels should have been ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The enforcementPercentage configuration should be honored. + */ + @Test + public void successRateOneOutlier_enforcementPercentage() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setEnforcementPercentage(0) + .build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // There is one outlier, but because enforcementPercentage is 0, nothing should be ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * Two outliers get ejected. + */ + @Test + public void successRateTwoOutliers() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setStdevFactor(1).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of( + subchannel1, Status.DEADLINE_EXCEEDED, + subchannel2, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0), + servers.get(1).getAddresses().get(0))); + } + + /** + * Three outliers, second one ejected even if ejecting it goes above the max ejection percentage, + * as this matches Envoy behavior. The third one should not get ejected. + */ + @Test + public void successRateThreeOutliers_maxEjectionPercentage() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(30) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setStdevFactor(1).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of( + subchannel1, Status.DEADLINE_EXCEEDED, + subchannel2, Status.DEADLINE_EXCEEDED, + subchannel3, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + int totalEjected = 0; + for (EquivalentAddressGroup addressGroup: servers) { + totalEjected += + loadBalancer.trackerMap.get(addressGroup.getAddresses().get(0)).subchannelsEjected() ? 1 + : 0; + } + + assertThat(totalEjected).isEqualTo(2); + } + + + /** + * The success rate algorithm leaves a healthy set of addresses alone. + */ + @Test + public void failurePercentageNoOutliers() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + // By default all calls will return OK. + generateLoad(ImmutableMap.of(), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // No outliers, no ejections. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The success rate algorithm ejects the outlier. + */ + @Test + public void failurePercentageOneOutlier() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + } + + /** + * The failure percentage algorithm ignores addresses without enough volume.. + */ + @Test + public void failurePercentageOneOutlier_notEnoughVolume() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(100).build()) // We won't produce this much volume... + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // We should see no ejections. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The failure percentage algorithm does not apply if we don't have enough addresses that have the + * required volume. + */ + @Test + public void failurePercentageOneOutlier_notEnoughAddressesWithVolume() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(5) + .setRequestVolume(20).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad( + ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), + // subchannel2 has only 19 calls which results in failure percentage not triggering. + ImmutableMap.of(subchannel2, 19), + 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // No subchannels should have been ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** + * The enforcementPercentage configuration should be honored. + */ + @Test + public void failurePercentageOneOutlier_enforcementPercentage() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setEnforcementPercentage(0) + .build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // There is one outlier, but because enforcementPercentage is 0, nothing should be ejected. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** Success rate detects two outliers and error percentage three. */ + @Test + public void successRateAndFailurePercentageThreeOutliers() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(100) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setStdevFactor(1).build()) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setThreshold(0) + .setMinimumHosts(3) + .setRequestVolume(1) + .build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + // Three subchannels with problems, but one only has a single call that failed. + // This is not enough for success rate to catch, but failure percentage is + // configured with a 0 tolerance threshold. + generateLoad( + ImmutableMap.of( + subchannel1, Status.DEADLINE_EXCEEDED, + subchannel2, Status.DEADLINE_EXCEEDED, + subchannel3, Status.DEADLINE_EXCEEDED), + ImmutableMap.of(subchannel3, 1), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // Should see thee ejected, success rate cathes the first two, error percentage the + // same two plus the subchannel with the single failure. + assertEjectedSubchannels(ImmutableSet.of( + servers.get(0).getAddresses().get(0), + servers.get(1).getAddresses().get(0), + servers.get(2).getAddresses().get(0))); + } + + /** + * When the address a subchannel is associated with changes it should get tracked under the new + * address and its ejection state should match what the address has. + */ + @Test + public void subchannelUpdateAddress_singleReplaced() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + EquivalentAddressGroup oldAddressGroup = servers.get(0); + AddressTracker oldAddressTracker = loadBalancer.trackerMap.get( + oldAddressGroup.getAddresses().get(0)); + EquivalentAddressGroup newAddressGroup = servers.get(1); + AddressTracker newAddressTracker = loadBalancer.trackerMap.get( + newAddressGroup.getAddresses().get(0)); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(oldAddressGroup.getAddresses().get(0))); + + // The ejected subchannel gets updated with another address in the map that is not ejected + OutlierDetectionSubchannel subchannel = oldAddressTracker.getSubchannels() + .iterator().next(); + subchannel.updateAddresses(ImmutableList.of(newAddressGroup)); + + // The replaced address should no longer have the subchannel associated with it. + assertThat(oldAddressTracker.getSubchannels()).doesNotContain(subchannel); + + // The new address should instead have the subchannel. + assertThat(newAddressTracker.getSubchannels()).contains(subchannel); + + // Since the new address is not ejected, the ejected subchannel moving over to it should also + // become unejected. + assertThat(subchannel.isEjected()).isFalse(); + } + + /** + * If a single address gets replaced by multiple, the subchannel becomes uneligible for outlier + * detection. + */ + @Test + public void subchannelUpdateAddress_singleReplacedWithMultiple() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + EquivalentAddressGroup oldAddressGroup = servers.get(0); + AddressTracker oldAddressTracker = loadBalancer.trackerMap.get( + oldAddressGroup.getAddresses().get(0)); + EquivalentAddressGroup newAddress1 = servers.get(1); + EquivalentAddressGroup newAddress2 = servers.get(2); + + OutlierDetectionSubchannel subchannel = oldAddressTracker.getSubchannels() + .iterator().next(); + + // The subchannel gets updated with two new addresses + ImmutableList addressUpdate + = ImmutableList.of(newAddress1, newAddress2); + subchannel.updateAddresses(addressUpdate); + when(subchannel1.getAllAddresses()).thenReturn(addressUpdate); + + // The replaced address should no longer be tracked. + assertThat(oldAddressTracker.getSubchannels()).doesNotContain(subchannel); + + // The old tracker should also have its call counters cleared. + assertThat(oldAddressTracker.activeVolume()).isEqualTo(0); + assertThat(oldAddressTracker.inactiveVolume()).isEqualTo(0); + } + + /** + * A subchannel with multiple addresses will again become eligible for outlier detection if it + * receives an update with a single address. + */ + @Test + public void subchannelUpdateAddress_multipleReplacedWithSingle() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 6); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + EquivalentAddressGroup oldAddressGroup = servers.get(0); + AddressTracker oldAddressTracker = loadBalancer.trackerMap.get( + oldAddressGroup.getAddresses().get(0)); + EquivalentAddressGroup newAddressGroup1 = servers.get(1); + AddressTracker newAddressTracker1 = loadBalancer.trackerMap.get( + newAddressGroup1.getAddresses().get(0)); + EquivalentAddressGroup newAddressGroup2 = servers.get(2); + + // The old subchannel was returning errors and should be ejected. + assertEjectedSubchannels(ImmutableSet.of(oldAddressGroup.getAddresses().get(0))); + + OutlierDetectionSubchannel subchannel = oldAddressTracker.getSubchannels() + .iterator().next(); + + // The subchannel gets updated with two new addresses + ImmutableList addressUpdate + = ImmutableList.of(newAddressGroup1, newAddressGroup2); + subchannel.updateAddresses(addressUpdate); + when(subchannel1.getAllAddresses()).thenReturn(addressUpdate); + + // The replaced address should no longer be tracked. + assertThat(oldAddressTracker.getSubchannels()).doesNotContain(subchannel); + + // The old tracker should also have its call counters cleared. + assertThat(oldAddressTracker.activeVolume()).isEqualTo(0); + assertThat(oldAddressTracker.inactiveVolume()).isEqualTo(0); + + // Another update takes the subchannel back to a single address. + addressUpdate = ImmutableList.of(newAddressGroup1); + subchannel.updateAddresses(addressUpdate); + when(subchannel1.getAllAddresses()).thenReturn(addressUpdate); + + // The subchannel is now associated with the single new address. + assertThat(newAddressTracker1.getSubchannels()).contains(subchannel); + + // The previously ejected subchannel should become unejected as it is now associated with an + // unejected address. + assertThat(subchannel.isEjected()).isFalse(); + } + + /** Both algorithms configured, but no outliers. */ + @Test + public void successRateAndFailurePercentage_noOutliers() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // No outliers, no ejections. + assertEjectedSubchannels(ImmutableSet.of()); + } + + /** Both algorithms configured, success rate detects an outlier. */ + @Test + public void successRateAndFailurePercentage_successRateOutlier() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setEnforcementPercentage(0).build()) // Configured, but not enforcing. + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + } + + /** Both algorithms configured, error percentage detects an outlier. */ + @Test + public void successRateAndFailurePercentage_errorPercentageOutlier() { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setMaxEjectionPercent(50) + .setSuccessRateEjection( + new SuccessRateEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10) + .setEnforcementPercentage(0).build()) + .setFailurePercentageEjection( + new FailurePercentageEjection.Builder() + .setMinimumHosts(3) + .setRequestVolume(10).build()) // Configured, but not enforcing. + .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); + + generateLoad(ImmutableMap.of(subchannel1, Status.DEADLINE_EXCEEDED), 7); + + // Move forward in time to a point where the detection timer has fired. + forwardTime(config); + + // The one subchannel that was returning errors should be ejected. + assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); + } + + @Test + public void mathChecksOut() { + ImmutableList values = ImmutableList.of(600d, 470d, 170d, 430d, 300d); + double mean = SuccessRateOutlierEjectionAlgorithm.mean(values); + double stdev = SuccessRateOutlierEjectionAlgorithm.standardDeviation(values, mean); + + assertThat(mean).isEqualTo(394); + assertThat(stdev).isEqualTo(147.32277488562318); + } + + private static class FakeSocketAddress extends SocketAddress { + + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } + } + + private ResolvedAddresses buildResolvedAddress(OutlierDetectionLoadBalancerConfig config, + EquivalentAddressGroup... servers) { + return ResolvedAddresses.newBuilder().setAddresses(ImmutableList.copyOf(servers)) + .setLoadBalancingPolicyConfig(config).build(); + } + + private ResolvedAddresses buildResolvedAddress(OutlierDetectionLoadBalancerConfig config, + List servers) { + return ResolvedAddresses.newBuilder().setAddresses(ImmutableList.copyOf(servers)) + .setLoadBalancingPolicyConfig(config).build(); + } + + private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + subchannelStateListeners.get(subchannel).onSubchannelState(newState); + } + + private void generateLoad(Map statusMap, int expectedStateChanges) { + generateLoad(statusMap, null, expectedStateChanges); + } + + // Generates 100 calls, 20 each across the subchannels. Default status is OK. + private void generateLoad(Map statusMap, + Map maxCallsMap, int expectedStateChanges) { + deliverSubchannelState(subchannel1, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel2, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel3, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel4, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel5, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(expectedStateChanges)).updateBalancingState(stateCaptor.capture(), + pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getAllValues() + .get(pickerCaptor.getAllValues().size() - 1); + + HashMap callCountMap = new HashMap<>(); + for (int i = 0; i < 100; i++) { + PickResult pickResult = picker + .pickSubchannel(mock(PickSubchannelArgs.class)); + ClientStreamTracer clientStreamTracer = pickResult.getStreamTracerFactory() + .newClientStreamTracer(null, null); + + Subchannel subchannel = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); + + int maxCalls = + maxCallsMap != null && maxCallsMap.containsKey(subchannel) + ? maxCallsMap.get(subchannel) : Integer.MAX_VALUE; + int calls = callCountMap.containsKey(subchannel) ? callCountMap.get(subchannel) : 0; + if (calls < maxCalls) { + callCountMap.put(subchannel, ++calls); + clientStreamTracer.streamClosed( + statusMap.containsKey(subchannel) ? statusMap.get(subchannel) : Status.OK); + } + } + } + + // Forwards time past the moment when the timer will fire. + private void forwardTime(OutlierDetectionLoadBalancerConfig config) { + fakeClock.forwardTime(config.intervalNanos + 1, TimeUnit.NANOSECONDS); + } + + // Asserts that the given addresses are ejected and the rest are not. + void assertEjectedSubchannels(Set addresses) { + for (Entry entry : loadBalancer.trackerMap.entrySet()) { + assertWithMessage("not ejected: " + entry.getKey()) + .that(entry.getValue().subchannelsEjected()) + .isEqualTo(addresses.contains(entry.getKey())); + } + } + + /** Round robin like fake load balancer. */ + private static final class FakeLoadBalancer extends LoadBalancer { + private final Helper helper; + + List subchannelList; + int lastPickIndex = -1; + + FakeLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + subchannelList = new ArrayList<>(); + for (EquivalentAddressGroup eag: resolvedAddresses.getAddresses()) { + Subchannel subchannel = helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(eag).build()); + subchannelList.add(subchannel); + subchannel.start(mock(SubchannelStateListener.class)); + deliverSubchannelState(READY); + } + return true; + } + + @Override + public void handleNameResolutionError(Status error) { + } + + @Override + public void shutdown() { + } + + void deliverSubchannelState(ConnectivityState state) { + SubchannelPicker picker = new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + if (lastPickIndex < 0 || lastPickIndex > subchannelList.size() - 1) { + lastPickIndex = 0; + } + return PickResult.withSubchannel(subchannelList.get(lastPickIndex++)); + } + }; + helper.updateBalancingState(state, picker); + } + } +} diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index dcb85ad916d..d4c07e3d50e 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; @@ -48,7 +49,6 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; @@ -147,8 +147,9 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -198,9 +199,10 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -220,8 +222,9 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); verify(newSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); @@ -239,25 +242,16 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - // test going from non-empty to empty - loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes(affinity) - .build()); - - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); - verifyNoMoreInteractions(mockHelper); } @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); Ref subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); @@ -295,9 +289,10 @@ public void pickAfterStateChange() throws Exception { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); @@ -314,9 +309,10 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -352,9 +348,10 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -419,8 +416,9 @@ public void nameResolutionErrorWithNoChannels() throws Exception { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -448,9 +446,10 @@ public void subchannelStateIsolation() throws Exception { Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -486,10 +485,14 @@ public void subchannelStateIsolation() throws Exception { assertThat(pickers.hasNext()).isFalse(); } - @Test(expected = IllegalArgumentException.class) + @Test public void readyPicker_emptyList() { // ready picker list must be non-empty - new ReadyPicker(Collections.emptyList(), 0); + try { + new ReadyPicker(Collections.emptyList(), 0); + fail(); + } catch (IllegalArgumentException expected) { + } } @Test @@ -517,6 +520,15 @@ public void internalPickerComparisons() { assertFalse(ready1.isEquivalentTo(emptyOk1)); } + @Test + public void emptyAddresses() { + assertThat(loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(affinity) + .build())).isFalse(); + } + private static List getList(SubchannelPicker picker) { return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : Collections.emptyList(); diff --git a/cronet/README.md b/cronet/README.md index fb65ca23e5a..7852cacd4df 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.44.1' +implementation 'io.grpc:grpc-cronet:1.53.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/cronet/build.gradle b/cronet/build.gradle index d66eaf3a182..61551cb4d10 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ -38,14 +38,14 @@ android { dependencies { api project(':grpc-core'), - libraries.cronet_api + libraries.cronet.api implementation libraries.guava testImplementation project(':grpc-testing') - testImplementation libraries.cronet_embedded + testImplementation libraries.cronet.embedded testImplementation libraries.junit - testImplementation libraries.mockito + testImplementation libraries.mockito.core testImplementation (libraries.robolectric) { // Unreleased change: https://github.com/robolectric/robolectric/pull/5432 exclude group: 'com.google.auto.service', module: 'auto-service' diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index fcc42603cb2..4b3ba3b97da 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.44.1' -implementation 'io.grpc:grpc-okhttp:1.44.1' +implementation 'io.grpc:grpc-android:1.53.0' +implementation 'io.grpc:grpc-okhttp:1.53.0' ``` You also need permission to access the device's network state in your diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel index e7f00381ad1..2a5ed52b35c 100644 --- a/examples/BUILD.bazel +++ b/examples/BUILD.bazel @@ -135,3 +135,40 @@ java_binary( ":examples", ], ) + +java_binary( + name = "load-balance-client", + testonly = 1, + main_class = "io.grpc.examples.loadbalance.LoadBalanceClient", + runtime_deps = [ + ":examples", + ], +) + +java_binary( + name = "load-balance-server", + testonly = 1, + main_class = "io.grpc.examples.loadbalance.LoadBalanceServer", + runtime_deps = [ + ":examples", + ], +) + +java_binary( + name = "name-resolve-client", + testonly = 1, + main_class = "io.grpc.examples.nameresolve.NameResolveClient", + runtime_deps = [ + ":examples", + ], +) + +java_binary( + name = "name-resolve-server", + testonly = 1, + main_class = "io.grpc.examples.nameresolve.NameResolveServer", + runtime_deps = [ + ":examples", + ], +) + diff --git a/examples/WORKSPACE b/examples/WORKSPACE index bd139ddd406..671176a8469 100644 --- a/examples/WORKSPACE +++ b/examples/WORKSPACE @@ -16,9 +16,9 @@ local_repository( http_archive( name = "rules_jvm_external", - sha256 = "cd1a77b7b02e8e008439ca76fd34f5b07aecb8c752961f9640dea15e9e5ba1ca", - strip_prefix = "rules_jvm_external-4.2", - url = "https://github.com/bazelbuild/rules_jvm_external/archive/4.2.zip", + sha256 = "c21ce8b8c4ccac87c809c317def87644cdc3a9dd650c74f41698d761c95175f3", + strip_prefix = "rules_jvm_external-1498ac6ccd3ea9cdb84afed65aa257c57abf3e0a", + url = "https://github.com/bazelbuild/rules_jvm_external/archive/1498ac6ccd3ea9cdb84afed65aa257c57abf3e0a.zip", ) load("@rules_jvm_external//:defs.bzl", "maven_install") diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 65803334c58..d548ef4ac17 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -32,9 +32,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.19.2' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.53.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' - testImplementation 'junit:junit:4.12' + testImplementation 'junit:junit:4.13.2' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.53.0' // CURRENT_GRPC_VERSION } diff --git a/examples/android/clientcache/build.gradle b/examples/android/clientcache/build.gradle index 8a94a30191e..4e6acca775c 100644 --- a/examples/android/clientcache/build.gradle +++ b/examples/android/clientcache/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.2.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.18" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.1" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 12b6899d881..3fab269de3c 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.19.2' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.53.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/helloworld/app/src/main/AndroidManifest.xml b/examples/android/helloworld/app/src/main/AndroidManifest.xml index eee4057cd0c..0beff2dc840 100644 --- a/examples/android/helloworld/app/src/main/AndroidManifest.xml +++ b/examples/android/helloworld/app/src/main/AndroidManifest.xml @@ -11,7 +11,8 @@ android:theme="@style/Base.V7.Theme.AppCompat.Light" > + android:label="@string/app_name" + android:exported="true"> diff --git a/examples/android/helloworld/build.gradle b/examples/android/helloworld/build.gradle index 8a94a30191e..4e6acca775c 100644 --- a/examples/android/helloworld/build.gradle +++ b/examples/android/helloworld/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.2.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.18" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.1" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index f7e9c1a4bb6..cbf99e2c602 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.19.2' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.53.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/build.gradle b/examples/android/routeguide/build.gradle index b1083bb867a..d0cd50e82c3 100644 --- a/examples/android/routeguide/build.gradle +++ b/examples/android/routeguide/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.2.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.18" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.1" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 9638eb38a08..2cb260a2916 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -31,9 +31,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.19.2' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.53.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.53.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.53.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/build.gradle b/examples/android/strictmode/build.gradle index 8a94a30191e..4e6acca775c 100644 --- a/examples/android/strictmode/build.gradle +++ b/examples/android/strictmode/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.2.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.18" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.1" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/build.gradle b/examples/build.gradle index 8a5599f0a65..23534abdb5c 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -1,8 +1,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.18' + id 'com.google.protobuf' version '0.9.1' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -22,8 +21,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion dependencies { @@ -37,7 +36,7 @@ dependencies { runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation "io.grpc:grpc-testing:${grpcVersion}" - testImplementation "junit:junit:4.12" + testImplementation "junit:junit:4.13.2" testImplementation "org.mockito:mockito-core:3.4.0" } @@ -140,6 +139,34 @@ task manualFlowControlServer(type: CreateStartScripts) { classpath = startScripts.classpath } +task loadBalanceServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.loadbalance.LoadBalanceServer' + applicationName = 'load-balance-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task loadBalanceClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.loadbalance.LoadBalanceClient' + applicationName = 'load-balance-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task nameResolveServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.nameresolve.NameResolveServer' + applicationName = 'name-resolve-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task nameResolveClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.nameresolve.NameResolveClient' + applicationName = 'name-resolve-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + applicationDistribution.into('bin') { from(routeGuideServer) from(routeGuideClient) @@ -152,5 +179,9 @@ applicationDistribution.into('bin') { from(compressingHelloWorldClient) from(manualFlowControlClient) from(manualFlowControlServer) + from(loadBalanceServer) + from(loadBalanceClient) + from(nameResolveServer) + from(nameResolveClient) fileMode = 0755 } diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index b29c159e574..6eecd867b9f 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' dependencies { // grpc-alts transitively depends on grpc-netty-shaded, grpc-protobuf, and grpc-stub diff --git a/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsClient.java b/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsClient.java index 96351808b25..991ad777309 100644 --- a/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsClient.java +++ b/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsClient.java @@ -16,7 +16,8 @@ package io.grpc.examples.alts; -import io.grpc.alts.AltsChannelBuilder; +import io.grpc.alts.AltsChannelCredentials; +import io.grpc.Grpc; import io.grpc.ManagedChannel; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -81,7 +82,8 @@ private void parseArgs(String[] args) { private void run(String[] args) throws InterruptedException { parseArgs(args); ExecutorService executor = Executors.newFixedThreadPool(1); - ManagedChannel channel = AltsChannelBuilder.forTarget(serverAddress).executor(executor).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, AltsChannelCredentials.create()) + .executor(executor).build(); try { GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel); HelloReply resp = stub.sayHello(HelloRequest.newBuilder().setName("Waldo").build()); diff --git a/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsServer.java b/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsServer.java index fd662069cfd..6bc5226bf59 100644 --- a/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsServer.java +++ b/examples/example-alts/src/main/java/io/grpc/examples/alts/HelloWorldAltsServer.java @@ -16,7 +16,8 @@ package io.grpc.examples.alts; -import io.grpc.alts.AltsServerBuilder; +import io.grpc.alts.AltsServerCredentials; +import io.grpc.Grpc; import io.grpc.Server; import io.grpc.examples.helloworld.GreeterGrpc.GreeterImplBase; import io.grpc.examples.helloworld.HelloReply; @@ -82,7 +83,7 @@ private void parseArgs(String[] args) { private void start(String[] args) throws IOException, InterruptedException { parseArgs(args); server = - AltsServerBuilder.forPort(port) + Grpc.newServerBuilderForPort(port, AltsServerCredentials.create()) .addService(this) .executor(Executors.newFixedThreadPool(1)) .build(); diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 40bf8968b79..c323b34e211 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 5d56b44ab1e..ad4a3fd3767 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,14 +6,14 @@ jar - 1.45.0-SNAPSHOT + 1.53.0 example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.45.0-SNAPSHOT - 3.19.2 + 1.53.0 + 3.21.7 1.7 1.7 @@ -73,7 +73,7 @@ junit junit - 4.12 + 4.13.2 test diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index abf4a4e5af6..5964c22ac1a 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,8 +21,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" @@ -31,7 +31,7 @@ dependencies { compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" - testImplementation 'junit:junit:4.12' + testImplementation 'junit:junit:4.13.2' testImplementation "io.grpc:grpc-testing:${grpcVersion}" } diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 53d399e2460..02ab0243b6c 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,14 +6,14 @@ jar - 1.45.0-SNAPSHOT + 1.53.0 example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.45.0-SNAPSHOT - 3.19.2 + 1.53.0 + 3.21.7 1.7 1.7 @@ -58,7 +58,7 @@ junit junit - 4.12 + 4.13.2 test diff --git a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java index a6f2175914e..3c63296d7fa 100644 --- a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java +++ b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java @@ -16,6 +16,8 @@ package io.grpc.examples.hostname; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; @@ -49,7 +51,7 @@ public static void main(String[] args) throws IOException, InterruptedException hostname = args[1]; } HealthStatusManager health = new HealthStatusManager(); - final Server server = ServerBuilder.forPort(port) + final Server server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new HostnameGreeter(hostname)) .addService(ProtoReflectionService.newInstance()) .addService(health.getHealthService()) diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index 1e4acd3deb9..c5d1ce617b4 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -22,8 +22,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion dependencies { @@ -37,7 +37,7 @@ dependencies { runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation "io.grpc:grpc-testing:${grpcVersion}" - testImplementation "junit:junit:4.12" + testImplementation "junit:junit:4.13.2" testImplementation "org.mockito:mockito-core:3.4.0" } diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index d78e4e6a630..337077d6b1e 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,15 +7,15 @@ jar - 1.45.0-SNAPSHOT + 1.53.0 example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.45.0-SNAPSHOT - 3.19.2 - 3.19.2 + 1.53.0 + 3.21.7 + 3.21.7 1.7 1.7 @@ -71,7 +71,7 @@ junit junit - 4.12 + 4.13.2 test diff --git a/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthClient.java b/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthClient.java index f6ea4c57e45..c5769625807 100644 --- a/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthClient.java +++ b/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthClient.java @@ -17,8 +17,9 @@ package io.grpc.examples.jwtauth; import io.grpc.CallCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; @@ -42,12 +43,9 @@ public class AuthClient { AuthClient(CallCredentials callCredentials, String host, int port) { this( callCredentials, - ManagedChannelBuilder - .forAddress(host, port) - // Channels are secure by default (via SSL/TLS). For this example we disable TLS - // to avoid needing certificates, but it is recommended to use a secure channel - // while passing credentials. - .usePlaintext() + // For this example we use plaintext to avoid needing certificates, but it is + // recommended to use TlsChannelCredentials. + Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()) .build()); } diff --git a/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthServer.java b/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthServer.java index 90e7dff1458..208645e4fad 100644 --- a/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthServer.java +++ b/examples/example-jwt-auth/src/main/java/io/grpc/examples/jwtauth/AuthServer.java @@ -16,8 +16,9 @@ package io.grpc.examples.jwtauth; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; @@ -41,7 +42,7 @@ public AuthServer(int port) { } private void start() throws IOException { - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .intercept(new JwtServerInterceptor()) // add the JwtServerInterceptor .build() diff --git a/examples/example-orca/README.md b/examples/example-orca/README.md new file mode 100644 index 00000000000..7461b3ba397 --- /dev/null +++ b/examples/example-orca/README.md @@ -0,0 +1,44 @@ +gRPC ORCA Example +================ + +The ORCA example consists of a Hello World client and a Hello World server. Out-of-the-box the +client behaves the same the hello-world version and the server behaves similar to the +example-hostname. In addition, they have been integrated with backend metrics reporting features. + +### Build the example + +Build the ORCA hello-world example client & server. From the `grpc-java/examples/examples-orca` +directory: +``` +$ ../gradlew installDist +``` + +This creates the scripts `build/install/example-orca/bin/custom-backend-metrics-client` and +`build/install/example-orca/bin/custom-backend-metrics-server`. + +### Run the example + +To use ORCA, you have to instrument both the client and the server. +At the client, in your own load balancer policy, you use gRPC APIs to install listeners to receive +per-query and out-of-band metric reports. +At the server, you add a server interceptor provided by gRPC in order to send per-query backend metrics. +And you register a bindable service, also provided by gRPC, in order to send out-of-band backend metrics. +Meanwhile, you update the metrics data from your own measurements. + +That's it! In this example, we simply put all the necessary pieces together to demonstrate the +metrics reporting mechanism. + +1. To start the ORCA enabled example server on its default port of 50051, run: +``` +$ ./build/install/example-orca/bin/custom-backend-metrics-server +``` + +2. In a different terminal window, run the ORCA enabled example client: +``` +$ ./build/install/example-orca/bin/custom-backend-metrics-client "orca tester" 1500 +``` +The first command line argument (`orca tester`) is the name you wish to include in +the greeting request to the server and the second argument +(`1500`) is the time period (in milliseconds) you want to run the client before it shut downed so that it will show +more periodic backend metrics reports. You are expected to see the metrics data printed out. Try it! + diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle new file mode 100644 index 00000000000..233f303c481 --- /dev/null +++ b/examples/example-orca/build.gradle @@ -0,0 +1,62 @@ +plugins { + id 'application' // Provide convenience executables for trying out the examples. + // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions + id 'com.google.protobuf' version '0.8.17' + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' + id 'java' +} + +repositories { + maven { // The google mirror is less flaky than mavenCentral() + url "https://maven-central.storage-download.googleapis.com/maven2/" } + mavenCentral() + mavenLocal() +} + +sourceCompatibility = 1.8 +targetCompatibility = 1.8 + +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}" + implementation "io.grpc:grpc-services:${grpcVersion}" + implementation "io.grpc:grpc-stub:${grpcVersion}" + implementation "io.grpc:grpc-xds:${grpcVersion}" + compileOnly "org.apache.tomcat:annotations-api:6.0.53" + +} + +protobuf { + protoc { artifact = "com.google.protobuf:protoc:${protocVersion}" } + plugins { + grpc { artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" } + } + generateProtoTasks { + all()*.plugins { grpc {} } + } +} + +startScripts.enabled = false + +task CustomBackendMetricsClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.orca.CustomBackendMetricsClient' + applicationName = 'custom-backend-metrics-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task CustomBackendMetricsServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.orca.CustomBackendMetricsServer' + applicationName = 'custom-backend-metrics-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +applicationDistribution.into('bin') { + from(CustomBackendMetricsClient) + from(CustomBackendMetricsServer) + fileMode = 0755 +} diff --git a/examples/example-orca/settings.gradle b/examples/example-orca/settings.gradle new file mode 100644 index 00000000000..3c62dc663ce --- /dev/null +++ b/examples/example-orca/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'example-orca' diff --git a/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsClient.java b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsClient.java new file mode 100644 index 00000000000..66143b44364 --- /dev/null +++ b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsClient.java @@ -0,0 +1,106 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.orca; + +import static io.grpc.examples.orca.CustomBackendMetricsLoadBalancerProvider.EXAMPLE_LOAD_BALANCER; + +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.LoadBalancerRegistry; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A simple xDS client that requests a greeting from {@link CustomBackendMetricsServer}. + * The client channel is configured to use an example load balancer policy + * {@link CustomBackendMetricsLoadBalancerProvider} which integrates with ORCA metrics reporting. + */ +public class CustomBackendMetricsClient { + private static final Logger logger = Logger.getLogger(CustomBackendMetricsClient.class.getName()); + + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + /** Construct client for accessing HelloWorld server using the existing channel. */ + public CustomBackendMetricsClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + /** Say hello to server. */ + public void greet(String name) { + logger.info("Will try to greet " + name + " ..."); + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. + */ + public static void main(String[] args) throws Exception { + String user = "orca tester"; + // The example defaults to the same behavior as the hello world example. + // To receive more periodic OOB metrics reports, use duration argument to a longer value. + String target = "localhost:50051"; + long timeBeforeShutdown = 1500; + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [name [duration [target]]]"); + System.err.println(""); + System.err.println(" name The name you wish to be greeted by. Defaults to " + user); + System.err.println(" duration The time period in milliseconds that the client application " + + "wait until shutdown. Defaults to " + timeBeforeShutdown); + System.err.println(" target The server to connect to. Defaults to " + target); + System.exit(1); + } + user = args[0]; + } + if (args.length > 1) { + timeBeforeShutdown = Long.parseLong(args[1]); + } + + if (args.length > 2) { + target = args[2]; + } + + LoadBalancerRegistry.getDefaultRegistry().register( + new CustomBackendMetricsLoadBalancerProvider()); + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .defaultLoadBalancingPolicy(EXAMPLE_LOAD_BALANCER) + .build(); + try { + CustomBackendMetricsClient client = new CustomBackendMetricsClient(channel); + client.greet(user); + Thread.sleep(timeBeforeShutdown); + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } +} diff --git a/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsLoadBalancerProvider.java b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsLoadBalancerProvider.java new file mode 100644 index 00000000000..c42fb7cdc10 --- /dev/null +++ b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsLoadBalancerProvider.java @@ -0,0 +1,148 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.orca; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.services.MetricReport; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.xds.orca.OrcaOobUtil; +import io.grpc.xds.orca.OrcaPerRequestUtil; +import java.util.concurrent.TimeUnit; + +/** + * Implements a test LB policy that receives ORCA load reports. + * The load balancer mostly delegates to {@link io.grpc.internal.PickFirstLoadBalancerProvider}, + * in addition, it installs {@link OrcaOobUtil.OrcaOobReportListener} and + * {@link OrcaPerRequestUtil.OrcaPerRequestReportListener} to be notified with backend metrics. + */ +final class CustomBackendMetricsLoadBalancerProvider extends LoadBalancerProvider { + + static final String EXAMPLE_LOAD_BALANCER = "example_backend_metrics_load_balancer"; + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new CustomBackendMetricsLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return EXAMPLE_LOAD_BALANCER; + } + + private final class CustomBackendMetricsLoadBalancer extends ForwardingLoadBalancer { + private LoadBalancer delegate; + + public CustomBackendMetricsLoadBalancer(LoadBalancer.Helper helper) { + this.delegate = LoadBalancerRegistry.getDefaultRegistry() + .getProvider("pick_first") + .newLoadBalancer(new CustomBackendMetricsLoadBalancerHelper(helper)); + } + + @Override + public LoadBalancer delegate() { + return delegate; + } + + private final class CustomBackendMetricsLoadBalancerHelper + extends ForwardingLoadBalancerHelper { + private final LoadBalancer.Helper orcaHelper; + + public CustomBackendMetricsLoadBalancerHelper(LoadBalancer.Helper helper) { + this.orcaHelper = OrcaOobUtil.newOrcaReportingHelper(helper); + } + + @Override + public LoadBalancer.Subchannel createSubchannel(LoadBalancer.CreateSubchannelArgs args) { + LoadBalancer.Subchannel subchannel = super.createSubchannel(args); + // Installs ORCA OOB metrics reporting listener and configures to receive report every 1s. + // The interval can not be smaller than server minimum report interval configuration, + // otherwise it is treated as server minimum report interval. + OrcaOobUtil.setListener(subchannel, new OrcaOobUtil.OrcaOobReportListener() { + @Override + public void onLoadReport(MetricReport orcaLoadReport) { + System.out.println("Example load balancer received OOB metrics report:\n" + + orcaLoadReport); + } + }, + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(1, TimeUnit.SECONDS) + .build() + ); + return subchannel; + } + + @Override + public void updateBalancingState(ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new MayReportLoadPicker(newPicker)); + } + + @Override + public LoadBalancer.Helper delegate() { + return orcaHelper; + } + } + + private final class MayReportLoadPicker extends LoadBalancer.SubchannelPicker { + private LoadBalancer.SubchannelPicker delegate; + + public MayReportLoadPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + if (result.getSubchannel() == null) { + return result; + } + // Installs ORCA per-query metrics reporting listener. + final OrcaPerRequestUtil.OrcaPerRequestReportListener orcaListener = + new OrcaPerRequestUtil.OrcaPerRequestReportListener() { + @Override + public void onLoadReport(MetricReport orcaLoadReport) { + System.out.println("Example load balancer received per-rpc metrics report:\n" + + orcaLoadReport); + } + }; + if (result.getStreamTracerFactory() == null) { + return LoadBalancer.PickResult.withSubchannel( + result.getSubchannel(), + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(orcaListener)); + } else { + return LoadBalancer.PickResult.withSubchannel( + result.getSubchannel(), + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + result.getStreamTracerFactory(), orcaListener)); + } + } + } + } +} diff --git a/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsServer.java b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsServer.java new file mode 100644 index 00000000000..b04664da363 --- /dev/null +++ b/examples/example-orca/src/main/java/io/grpc/examples/orca/CustomBackendMetricsServer.java @@ -0,0 +1,142 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.orca; + +import com.google.common.collect.ImmutableMap; +import io.grpc.BindableService; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; +import io.grpc.Server; +import io.grpc.services.CallMetricRecorder; +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricRecorder; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.orca.OrcaMetricReportingServerInterceptor; +import io.grpc.xds.orca.OrcaServiceImpl; +import java.io.IOException; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * Server that manages startup/shutdown of a {@code Greeter} server. + */ +public class CustomBackendMetricsServer { + private static final Logger logger = Logger.getLogger(CustomBackendMetricsServer.class.getName()); + + private Server server; + private static Random random = new Random(); + private MetricRecorder metricRecorder; + + private void start() throws IOException { + /* The port on which the server should run */ + int port = 50051; + + ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + metricRecorder = MetricRecorder.newInstance(); + // Configure OOB metrics reporting minimum report interval to be 1s. This allows client + // configuration to be as short as 1s, suitable for test demonstration. + BindableService orcaOobService = + OrcaServiceImpl.createService(executor, metricRecorder, 1, TimeUnit.SECONDS); + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .addService(new GreeterImpl()) + // Enable OOB custom backend metrics reporting. + .addService(orcaOobService) + // Enable per-query custom backend metrics reporting. + .intercept(OrcaMetricReportingServerInterceptor.getInstance()) + .build() + .start(); + logger.info("Server started, listening on " + port); + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + CustomBackendMetricsServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + } + }); + } + + private void stop() throws InterruptedException { + if (server != null) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + /** + * Await termination on the main thread since the grpc library uses daemon threads. + */ + private void blockUntilShutdown() throws InterruptedException { + if (server != null) { + server.awaitTermination(); + } + } + + /** + * Main launches the server from the command line. + */ + public static void main(String[] args) throws IOException, InterruptedException { + CustomBackendMetricsServer server = new CustomBackendMetricsServer(); + server.start(); + server.blockUntilShutdown(); + } + + class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName()).build(); + double cpuUtilization = random.nextDouble(); + double memoryUtilization = random.nextDouble(); + Map utilization = ImmutableMap.of("util", random.nextDouble()); + Map requestCost = ImmutableMap.of("cost", random.nextDouble()); + // Sets per-query backend metrics to a random test report. + CallMetricRecorder.getCurrent() + .recordCpuUtilizationMetric(cpuUtilization) + .recordMemoryUtilizationMetric(memoryUtilization) + .recordCallMetric("cost", requestCost.get("cost")) + .recordUtilizationMetric("util", utilization.get("util")); + System.out.println(String.format("Hello World Server updates RPC metrics data:\n" + + "cpu: %s, memory: %s, request cost: %s, utilization: %s\n", + cpuUtilization, memoryUtilization, requestCost, utilization)); + + cpuUtilization = random.nextDouble(); + memoryUtilization = random.nextDouble(); + utilization = ImmutableMap.of("util", random.nextDouble()); + // Sets OOB backend metrics to a random test report. + metricRecorder.setCpuUtilizationMetric(cpuUtilization); + metricRecorder.setMemoryUtilizationMetric(memoryUtilization); + metricRecorder.setAllUtilizationMetrics(utilization); + System.out.println(String.format("Hello World Server updates OOB metrics data:\n" + + "cpu: %s, memory: %s, utilization: %s\n", + cpuUtilization, memoryUtilization, utilization)); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-orca/src/main/proto/helloworld/helloworld.proto b/examples/example-orca/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..77184aa8326 --- /dev/null +++ b/examples/example-orca/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,37 @@ +// Copyright 2022 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-servlet/README.md b/examples/example-servlet/README.md new file mode 100644 index 00000000000..f2e6188069a --- /dev/null +++ b/examples/example-servlet/README.md @@ -0,0 +1,37 @@ +# Hello World Example using Servlets + +This example uses Java Servlets instead of Netty for the gRPC server. This example requires `grpc-java` +and `protoc-gen-grpc-java` to already be built. You are strongly encouraged to check out a git release +tag, since these builds will already be available. + +```bash +git checkout v.. +``` +Otherwise, you must follow [COMPILING](../../COMPILING.md). + +To build the example, + +1. **[Install gRPC Java library SNAPSHOT locally, including code generation plugin](../../COMPILING.md) (Only need this step for non-released versions, e.g. master HEAD).** + +2. In this directory, build the war file +```bash +$ ../gradlew war +``` + +To run this, deploy the war, now found in `build/libs/example-servlet.war` to your choice of servlet +container. Note that this container must support the Servlet 4.0 spec, for this particular example must +use `javax.servlet` packages instead of the more modern `jakarta.servlet`, though there is a `grpc-servlet-jakarta` +artifact that can be used for Jakarta support. Be sure to enable http/2 support in the servlet container, +or clients will not be able to connect. + +To test that this is working properly, build the HelloWorldClient example and direct it to connect to your +http/2 server. From the parent directory: + +1. Build the executables: +```bash +$ ../gradlew installDist +``` +2. Run the client app, specifying the name to say hello to and the server's address: +```bash +$ ./build/install/examples/bin/hello-world-client World localhost:8080 +``` \ No newline at end of file diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle new file mode 100644 index 00000000000..be58ca66794 --- /dev/null +++ b/examples/example-servlet/build.gradle @@ -0,0 +1,46 @@ +plugins { + // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions + id 'com.google.protobuf' version '0.8.17' + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' + id 'war' +} + +repositories { + maven { // The google mirror is less flaky than mavenCentral() + url "https://maven-central.storage-download.googleapis.com/maven2/" } + mavenLocal() +} + +sourceCompatibility = 1.8 +targetCompatibility = 1.8 + +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}", + "io.grpc:grpc-servlet:${grpcVersion}", + "io.grpc:grpc-stub:${grpcVersion}" + + providedImplementation "javax.servlet:javax.servlet-api:4.0.1", + "org.apache.tomcat:annotations-api:6.0.53" +} + +protobuf { + protoc { artifact = "com.google.protobuf:protoc:${protocVersion}" } + plugins { grpc { artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" } } + generateProtoTasks { + all()*.plugins { grpc {} } + } +} + +// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. +sourceSets { + main { + java { + srcDirs 'build/generated/source/proto/main/grpc' + srcDirs 'build/generated/source/proto/main/java' + } + } +} diff --git a/examples/example-servlet/settings.gradle b/examples/example-servlet/settings.gradle new file mode 100644 index 00000000000..273558dd9cf --- /dev/null +++ b/examples/example-servlet/settings.gradle @@ -0,0 +1,8 @@ +pluginManagement { + repositories { + maven { // The google mirror is less flaky than mavenCentral() + url "https://maven-central.storage-download.googleapis.com/maven2/" + } + gradlePluginPortal() + } +} diff --git a/examples/example-servlet/src/main/java/io/grpc/servlet/examples/helloworld/HelloWorldServlet.java b/examples/example-servlet/src/main/java/io/grpc/servlet/examples/helloworld/HelloWorldServlet.java new file mode 100644 index 00000000000..a970c26a119 --- /dev/null +++ b/examples/example-servlet/src/main/java/io/grpc/servlet/examples/helloworld/HelloWorldServlet.java @@ -0,0 +1,79 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet.examples.helloworld; + +import io.grpc.stub.StreamObserver; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.servlet.ServletAdapter; +import io.grpc.servlet.ServletServerBuilder; +import java.io.IOException; +import javax.servlet.annotation.WebServlet; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * A servlet that hosts a gRPC server over HTTP/2 and shares the resource URI for the normal servlet + * clients over HTTP/1.0+. + * + *

For creating a servlet that solely serves gRPC services, do not follow this example, simply + * extend or register a {@link io.grpc.servlet.GrpcServlet} instead. + */ +@WebServlet(urlPatterns = {"/helloworld.Greeter/SayHello"}, asyncSupported = true) +public class HelloWorldServlet extends HttpServlet { + private static final long serialVersionUID = 1L; + + private final ServletAdapter servletAdapter = + new ServletServerBuilder().addService(new GreeterImpl()).buildServletAdapter(); + + private static final class GreeterImpl extends GreeterGrpc.GreeterImplBase { + GreeterImpl() {} + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException { + response.setContentType("text/html"); + response.getWriter().println("

Hello World!

"); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws IOException { + if (ServletAdapter.isGrpc(request)) { + servletAdapter.doPost(request, response); + } else { + response.setContentType("text/html"); + response.getWriter().println("

Hello non-gRPC client!

"); + } + } + + @Override + public void destroy() { + servletAdapter.destroy(); + super.destroy(); + } +} diff --git a/examples/example-servlet/src/main/proto/helloworld/helloworld.proto b/examples/example-servlet/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..c60d9416f1f --- /dev/null +++ b/examples/example-servlet/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,37 @@ +// Copyright 2015 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-servlet/src/main/webapp/WEB-INF/glassfish-web.xml b/examples/example-servlet/src/main/webapp/WEB-INF/glassfish-web.xml new file mode 100644 index 00000000000..426162a9d13 --- /dev/null +++ b/examples/example-servlet/src/main/webapp/WEB-INF/glassfish-web.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/examples/example-servlet/src/main/webapp/WEB-INF/jboss-web.xml b/examples/example-servlet/src/main/webapp/WEB-INF/jboss-web.xml new file mode 100644 index 00000000000..9c83263e0c9 --- /dev/null +++ b/examples/example-servlet/src/main/webapp/WEB-INF/jboss-web.xml @@ -0,0 +1,9 @@ + + + + / + diff --git a/examples/example-tls/BUILD.bazel b/examples/example-tls/BUILD.bazel index 3daa305c167..81913836766 100644 --- a/examples/example-tls/BUILD.bazel +++ b/examples/example-tls/BUILD.bazel @@ -25,6 +25,7 @@ java_library( ), runtime_deps = [ "@maven//:io_netty_netty_tcnative_boringssl_static", + "@maven//:io_netty_netty_tcnative_classes", ], deps = [ ":helloworld_java_grpc", diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 2eddaf6afbe..14554dfb128 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.19.2' +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index a135dc51d94..f3f6a6d198b 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,15 +6,15 @@ jar - 1.45.0-SNAPSHOT + 1.53.0 example-tls https://github.com/grpc/grpc-java UTF-8 - 1.45.0-SNAPSHOT - 3.19.2 - 2.0.34.Final + 1.53.0 + 3.21.7 + 2.0.54.Final 1.7 1.7 @@ -49,13 +49,7 @@
io.grpc - grpc-netty - - - io.netty - netty-tcnative-boringssl-static - ${netty.tcnative.version} - runtime + grpc-netty-shaded diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 56fef5ef12d..c7af08d8e2d 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -22,9 +22,9 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.45.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.53.0' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' -def protocVersion = '3.19.2' +def protocVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-xds/settings.gradle b/examples/example-xds/settings.gradle index 9f46a95b6f5..878f1f23ae3 100644 --- a/examples/example-xds/settings.gradle +++ b/examples/example-xds/settings.gradle @@ -1,3 +1 @@ rootProject.name = 'example-xds' - -includeBuild '..' diff --git a/examples/gradle/wrapper/gradle-wrapper.properties b/examples/gradle/wrapper/gradle-wrapper.properties index 2e6e5897b52..070cb702f09 100644 --- a/examples/gradle/wrapper/gradle-wrapper.properties +++ b/examples/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/examples/pom.xml b/examples/pom.xml index 20c67251689..cedc58e0447 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,18 +6,18 @@ jar - 1.45.0-SNAPSHOT + 1.53.0 examples https://github.com/grpc/grpc-java UTF-8 - 1.45.0-SNAPSHOT - 3.19.2 - 3.19.2 - - 1.7 - 1.7 + 1.53.0 + 3.21.7 + 3.21.7 + + 1.8 + 1.8 @@ -54,7 +54,7 @@ com.google.code.gson gson - 2.8.9 + 2.9.0 org.apache.tomcat @@ -70,7 +70,7 @@ junit junit - 4.12 + 4.13.2 test diff --git a/examples/src/main/java/io/grpc/examples/advanced/HelloJsonClient.java b/examples/src/main/java/io/grpc/examples/advanced/HelloJsonClient.java index 264fe76b7d3..9291e45bafe 100644 --- a/examples/src/main/java/io/grpc/examples/advanced/HelloJsonClient.java +++ b/examples/src/main/java/io/grpc/examples/advanced/HelloJsonClient.java @@ -20,8 +20,9 @@ import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; import io.grpc.StatusRuntimeException; import io.grpc.examples.helloworld.GreeterGrpc; @@ -49,8 +50,7 @@ public final class HelloJsonClient { /** Construct client connecting to HelloWorld server at {@code host:port}. */ public HelloJsonClient(String host, int port) { - channel = ManagedChannelBuilder.forAddress(host, port) - .usePlaintext() + channel = Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()) .build(); blockingStub = new HelloJsonStub(channel); } diff --git a/examples/src/main/java/io/grpc/examples/advanced/HelloJsonServer.java b/examples/src/main/java/io/grpc/examples/advanced/HelloJsonServer.java index 0a656dd52b8..0f4e5d28f16 100644 --- a/examples/src/main/java/io/grpc/examples/advanced/HelloJsonServer.java +++ b/examples/src/main/java/io/grpc/examples/advanced/HelloJsonServer.java @@ -19,8 +19,9 @@ import static io.grpc.stub.ServerCalls.asyncUnaryCall; import io.grpc.BindableService; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerServiceDefinition; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -50,7 +51,7 @@ public class HelloJsonServer { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .build() .start(); diff --git a/examples/src/main/java/io/grpc/examples/errorhandling/DetailErrorSample.java b/examples/src/main/java/io/grpc/examples/errorhandling/DetailErrorSample.java index b743b46b471..b026b6f32dc 100644 --- a/examples/src/main/java/io/grpc/examples/errorhandling/DetailErrorSample.java +++ b/examples/src/main/java/io/grpc/examples/errorhandling/DetailErrorSample.java @@ -27,11 +27,12 @@ import com.google.rpc.DebugInfo; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.GreeterGrpc.GreeterBlockingStub; @@ -73,7 +74,8 @@ public static void main(String[] args) throws Exception { private ManagedChannel channel; void run() throws Exception { - Server server = ServerBuilder.forPort(0).addService(new GreeterGrpc.GreeterImplBase() { + Server server = Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(new GreeterGrpc.GreeterImplBase() { @Override public void sayHello(HelloRequest request, StreamObserver responseObserver) { Metadata trailers = new Metadata(); @@ -82,8 +84,8 @@ public void sayHello(HelloRequest request, StreamObserver responseOb .asRuntimeException(trailers)); } }).build().start(); - channel = - ManagedChannelBuilder.forAddress("localhost", server.getPort()).usePlaintext().build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), InsecureChannelCredentials.create()).build(); blockingCall(); futureCallDirect(); diff --git a/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java b/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java index 4afd39b08fa..7e310433a90 100644 --- a/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java +++ b/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java @@ -25,11 +25,12 @@ import com.google.common.util.concurrent.Uninterruptibles; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.GreeterGrpc.GreeterBlockingStub; @@ -55,15 +56,16 @@ public static void main(String [] args) throws Exception { void run() throws Exception { // Port 0 means that the operating system will pick an available port to use. - Server server = ServerBuilder.forPort(0).addService(new GreeterGrpc.GreeterImplBase() { + Server server = Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(new GreeterGrpc.GreeterImplBase() { @Override public void sayHello(HelloRequest request, StreamObserver responseObserver) { responseObserver.onError(Status.INTERNAL .withDescription("Eggplant Xerxes Crybaby Overbite Narwhal").asRuntimeException()); } }).build().start(); - channel = - ManagedChannelBuilder.forAddress("localhost", server.getPort()).usePlaintext().build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), InsecureChannelCredentials.create()).build(); blockingCall(); futureCallDirect(); diff --git a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java index 410b0c7c14c..49e9cb36d53 100644 --- a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java +++ b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldClient.java @@ -16,8 +16,9 @@ package io.grpc.examples.experimental; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -42,8 +43,7 @@ public class CompressingHelloWorldClient { /** Construct client connecting to HelloWorld server at {@code host:port}. */ public CompressingHelloWorldClient(String host, int port) { - channel = ManagedChannelBuilder.forAddress(host, port) - .usePlaintext() + channel = Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()) .build(); blockingStub = GreeterGrpc.newBlockingStub(channel); } diff --git a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerAllMethods.java b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerAllMethods.java index 23c51a6d26c..794d7196f35 100644 --- a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerAllMethods.java +++ b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerAllMethods.java @@ -20,9 +20,10 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Metadata; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; @@ -44,7 +45,7 @@ public class CompressingHelloWorldServerAllMethods { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) /* This method call adds the Interceptor to enable compressed server responses for all RPCs */ .intercept(new ServerInterceptor() { @Override diff --git a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerPerMethod.java b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerPerMethod.java index 0ccf38184d3..b7faa96d7b4 100644 --- a/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerPerMethod.java +++ b/examples/src/main/java/io/grpc/examples/experimental/CompressingHelloWorldServerPerMethod.java @@ -20,8 +20,9 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; @@ -40,7 +41,7 @@ public class CompressingHelloWorldServerPerMethod { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .build() .start(); diff --git a/examples/src/main/java/io/grpc/examples/header/CustomHeaderClient.java b/examples/src/main/java/io/grpc/examples/header/CustomHeaderClient.java index 93d106dba9e..52287040ba0 100644 --- a/examples/src/main/java/io/grpc/examples/header/CustomHeaderClient.java +++ b/examples/src/main/java/io/grpc/examples/header/CustomHeaderClient.java @@ -19,8 +19,9 @@ import io.grpc.Channel; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -43,8 +44,8 @@ public class CustomHeaderClient { * A custom client. */ private CustomHeaderClient(String host, int port) { - originChannel = ManagedChannelBuilder.forAddress(host, port) - .usePlaintext() + originChannel = Grpc + .newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()) .build(); ClientInterceptor interceptor = new HeaderClientInterceptor(); Channel channel = ClientInterceptors.intercept(originChannel, interceptor); diff --git a/examples/src/main/java/io/grpc/examples/header/CustomHeaderServer.java b/examples/src/main/java/io/grpc/examples/header/CustomHeaderServer.java index ae80045603c..75a3d24934f 100644 --- a/examples/src/main/java/io/grpc/examples/header/CustomHeaderServer.java +++ b/examples/src/main/java/io/grpc/examples/header/CustomHeaderServer.java @@ -16,8 +16,9 @@ package io.grpc.examples.header; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerInterceptors; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -39,7 +40,7 @@ public class CustomHeaderServer { private Server server; private void start() throws IOException { - server = ServerBuilder.forPort(PORT) + server = Grpc.newServerBuilderForPort(PORT, InsecureServerCredentials.create()) .addService(ServerInterceptors.intercept(new GreeterImpl(), new HeaderServerInterceptor())) .build() .start(); diff --git a/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldClient.java b/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldClient.java index 30f7beb49a0..429cceb50c3 100644 --- a/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldClient.java +++ b/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldClient.java @@ -20,6 +20,8 @@ import com.google.gson.Gson; import com.google.gson.stream.JsonReader; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; @@ -51,11 +53,8 @@ public class HedgingHelloWorldClient { /** Construct client connecting to HelloWorld server at {@code host:port}. */ public HedgingHelloWorldClient(String host, int port, boolean hedging) { - - ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forAddress(host, port) - // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid - // needing certificates. - .usePlaintext(); + ManagedChannelBuilder channelBuilder + = Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()); if (hedging) { Map hedgingServiceConfig = new Gson() diff --git a/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldServer.java b/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldServer.java index b934e8514ac..784269eea1c 100644 --- a/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/hedging/HedgingHelloWorldServer.java @@ -16,9 +16,10 @@ package io.grpc.examples.hedging; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Metadata; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; @@ -43,7 +44,7 @@ public class HedgingHelloWorldServer { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .intercept(new LatencyInjectionInterceptor()) .build() diff --git a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldClient.java b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldClient.java index d00bca1e216..6b186facf46 100644 --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldClient.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldClient.java @@ -17,8 +17,9 @@ package io.grpc.examples.helloworld; import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; import java.util.concurrent.TimeUnit; import java.util.logging.Level; @@ -81,10 +82,10 @@ public static void main(String[] args) throws Exception { // Create a communication channel to the server, known as a Channel. Channels are thread-safe // and reusable. It is common to create channels at the beginning of your application and reuse // them until the application shuts down. - ManagedChannel channel = ManagedChannelBuilder.forTarget(target) - // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid - // needing certificates. - .usePlaintext() + // + // For the example we use plaintext insecure credentials to avoid needing TLS certificates. To + // use TLS, use TlsChannelCredentials instead. + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) .build(); try { HelloWorldClient client = new HelloWorldClient(channel); diff --git a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java index 12836a4c828..81027587031 100644 --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java @@ -16,8 +16,9 @@ package io.grpc.examples.helloworld; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -34,7 +35,7 @@ public class HelloWorldServer { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .build() .start(); diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java new file mode 100644 index 00000000000..f562f0ac107 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java @@ -0,0 +1,110 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.loadbalance; + +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleServiceName; + +public class ExampleNameResolver extends NameResolver { + + private Listener2 listener; + + private final URI uri; + + private final Map> addrStore; + + public ExampleNameResolver(URI targetUri) { + this.uri = targetUri; + // This is a fake name resolver, so we just hard code the address here. + addrStore = ImmutableMap.>builder() + .put(exampleServiceName, + Stream.iterate(LoadBalanceServer.startPort,p->p+1) + .limit(LoadBalanceServer.serverCount) + .map(port->new InetSocketAddress("localhost",port)) + .collect(Collectors.toList()) + ) + .build(); + } + + @Override + public String getServiceAuthority() { + // Be consistent with behavior in grpc-go, authority is saved in Host field of URI. + if (uri.getHost() != null) { + return uri.getHost(); + } + return "no host"; + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener) { + this.listener = listener; + this.resolve(); + } + + @Override + public void refresh() { + this.resolve(); + } + + private void resolve() { + List addresses = addrStore.get(uri.getPath().substring(1)); + try { + List equivalentAddressGroup = addresses.stream() + // convert to socket address + .map(this::toSocketAddress) + // every socket address is a single EquivalentAddressGroup, so they can be accessed randomly + .map(Arrays::asList) + .map(this::addrToEquivalentAddressGroup) + .collect(Collectors.toList()); + + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddresses(equivalentAddressGroup) + .build(); + + this.listener.onResult(resolutionResult); + + } catch (Exception e){ + // when error occurs, notify listener + this.listener.onError(Status.UNAVAILABLE.withDescription("Unable to resolve host ").withCause(e)); + } + } + + private SocketAddress toSocketAddress(InetSocketAddress address) { + return new InetSocketAddress(address.getHostName(), address.getPort()); + } + + private EquivalentAddressGroup addrToEquivalentAddressGroup(List addrList) { + return new EquivalentAddressGroup(addrList); + } +} diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolverProvider.java b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolverProvider.java new file mode 100644 index 00000000000..ee966fd044c --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolverProvider.java @@ -0,0 +1,47 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.loadbalance; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; + +import java.net.URI; + +import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleScheme; + +public class ExampleNameResolverProvider extends NameResolverProvider { + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new ExampleNameResolver(targetUri); + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + // gRPC choose the first NameResolverProvider that supports the target URI scheme. + public String getDefaultScheme() { + return exampleScheme; + } +} diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceClient.java b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceClient.java new file mode 100644 index 00000000000..97444922871 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceClient.java @@ -0,0 +1,85 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.loadbalance; + +import io.grpc.*; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; + +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class LoadBalanceClient { + private static final Logger logger = Logger.getLogger(LoadBalanceClient.class.getName()); + + public static final String exampleScheme = "example"; + public static final String exampleServiceName = "lb.example.grpc.io"; + + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + public LoadBalanceClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + public void greet(String name) { + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } + + + public static void main(String[] args) throws Exception { + NameResolverRegistry.getDefaultRegistry().register(new ExampleNameResolverProvider()); + + String target = String.format("%s:///%s", exampleScheme, exampleServiceName); + + logger.info("Use default first_pick load balance policy"); + ManagedChannel channel = ManagedChannelBuilder.forTarget(target) + .usePlaintext() + .build(); + try { + LoadBalanceClient client = new LoadBalanceClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + + logger.info("Change to round_robin policy"); + channel = ManagedChannelBuilder.forTarget(target) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); + try { + LoadBalanceClient client = new LoadBalanceClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } +} diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java new file mode 100644 index 00000000000..c97d209497a --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java @@ -0,0 +1,94 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.loadbalance; + +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.stub.StreamObserver; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +public class LoadBalanceServer { + private static final Logger logger = Logger.getLogger(LoadBalanceServer.class.getName()); + static public final int serverCount = 3; + static public final int startPort = 50051; + private Server[] servers; + + private void start() throws IOException { + servers = new Server[serverCount]; + for (int i = 0; i < serverCount; i++) { + int port = startPort + i; + servers[i] = ServerBuilder.forPort(port) + .addService(new GreeterImpl(port)) + .build() + .start(); + logger.info("Server started, listening on " + port); + } + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + LoadBalanceServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + })); + } + + private void stop() throws InterruptedException { + for (int i = 0; i < serverCount; i++) { + if (servers[i] != null) { + servers[i].shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + } + + private void blockUntilShutdown() throws InterruptedException { + for (int i = 0; i < serverCount; i++) { + if (servers[i] != null) { + servers[i].awaitTermination(); + } + } + } + + public static void main(String[] args) throws IOException, InterruptedException { + final LoadBalanceServer server = new LoadBalanceServer(); + server.start(); + server.blockUntilShutdown(); + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + int port; + + public GreeterImpl(int port) { + this.port = port; + } + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolver.java b/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolver.java new file mode 100644 index 00000000000..95bf20dd580 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolver.java @@ -0,0 +1,108 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.nameresolve; + +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleServiceName; + +public class ExampleNameResolver extends NameResolver { + + private final URI uri; + private final Map> addrStore; + private Listener2 listener; + + public ExampleNameResolver(URI targetUri) { + this.uri = targetUri; + // This is a fake name resolver, so we just hard code the address here. + addrStore = ImmutableMap.>builder() + .put(exampleServiceName, + Stream.iterate(NameResolveServer.startPort, p -> p + 1) + .limit(NameResolveServer.serverCount) + .map(port -> new InetSocketAddress("localhost", port)) + .collect(Collectors.toList()) + ) + .build(); + } + + @Override + public String getServiceAuthority() { + // Be consistent with behavior in grpc-go, authority is saved in Host field of URI. + if (uri.getHost() != null) { + return uri.getHost(); + } + return "no host"; + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener) { + this.listener = listener; + this.resolve(); + } + + @Override + public void refresh() { + this.resolve(); + } + + private void resolve() { + List addresses = addrStore.get(uri.getPath().substring(1)); + try { + List equivalentAddressGroup = addresses.stream() + // convert to socket address + .map(this::toSocketAddress) + // every socket address is a single EquivalentAddressGroup, so they can be accessed randomly + .map(Arrays::asList) + .map(this::addrToEquivalentAddressGroup) + .collect(Collectors.toList()); + + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddresses(equivalentAddressGroup) + .build(); + + this.listener.onResult(resolutionResult); + + } catch (Exception e) { + // when error occurs, notify listener + this.listener.onError(Status.UNAVAILABLE.withDescription("Unable to resolve host ").withCause(e)); + } + } + + private SocketAddress toSocketAddress(InetSocketAddress address) { + return new InetSocketAddress(address.getHostName(), address.getPort()); + } + + private EquivalentAddressGroup addrToEquivalentAddressGroup(List addrList) { + return new EquivalentAddressGroup(addrList); + } +} diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolverProvider.java b/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolverProvider.java new file mode 100644 index 00000000000..cd05f3214f6 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/ExampleNameResolverProvider.java @@ -0,0 +1,47 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.nameresolve; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; + +import java.net.URI; + +import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleScheme; + +public class ExampleNameResolverProvider extends NameResolverProvider { + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new ExampleNameResolver(targetUri); + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + // gRPC choose the first NameResolverProvider that supports the target URI scheme. + public String getDefaultScheme() { + return exampleScheme; + } +} diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java new file mode 100644 index 00000000000..ac6fdd32549 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java @@ -0,0 +1,85 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.nameresolve; + +import io.grpc.*; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; + +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class NameResolveClient { + public static final String exampleScheme = "example"; + public static final String exampleServiceName = "lb.example.grpc.io"; + private static final Logger logger = Logger.getLogger(NameResolveClient.class.getName()); + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + public NameResolveClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + public static void main(String[] args) throws Exception { + NameResolverRegistry.getDefaultRegistry().register(new ExampleNameResolverProvider()); + + logger.info("Use default DNS resolver"); + ManagedChannel channel = ManagedChannelBuilder.forTarget("localhost:50051") + .usePlaintext() + .build(); + try { + NameResolveClient client = new NameResolveClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + + logger.info("Change to use example name resolver"); + /* + Dial to "example:///resolver.example.grpc.io", use {@link ExampleNameResolver} to create connection + "resolver.example.grpc.io" is converted to {@link java.net.URI.path} + */ + channel = ManagedChannelBuilder.forTarget( + String.format("%s:///%s", exampleScheme, exampleServiceName)) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); + try { + NameResolveClient client = new NameResolveClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } + + public void greet(String name) { + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } +} diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveServer.java b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveServer.java new file mode 100644 index 00000000000..0a402485906 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveServer.java @@ -0,0 +1,94 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.nameresolve; + +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.stub.StreamObserver; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +public class NameResolveServer { + static public final int serverCount = 3; + static public final int startPort = 50051; + private static final Logger logger = Logger.getLogger(NameResolveServer.class.getName()); + private Server[] servers; + + public static void main(String[] args) throws IOException, InterruptedException { + final NameResolveServer server = new NameResolveServer(); + server.start(); + server.blockUntilShutdown(); + } + + private void start() throws IOException { + servers = new Server[serverCount]; + for (int i = 0; i < serverCount; i++) { + int port = startPort + i; + servers[i] = ServerBuilder.forPort(port) + .addService(new GreeterImpl(port)) + .build() + .start(); + logger.info("Server started, listening on " + port); + } + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + NameResolveServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + })); + } + + private void stop() throws InterruptedException { + for (int i = 0; i < serverCount; i++) { + if (servers[i] != null) { + servers[i].shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + } + + private void blockUntilShutdown() throws InterruptedException { + for (int i = 0; i < serverCount; i++) { + if (servers[i] != null) { + servers[i].awaitTermination(); + } + } + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + int port; + + public GreeterImpl(int port) { + this.port = port; + } + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java b/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java index d8f2e72441b..20c9e5893fb 100644 --- a/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java +++ b/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java @@ -20,6 +20,8 @@ import com.google.gson.Gson; import com.google.gson.stream.JsonReader; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; @@ -64,10 +66,8 @@ public class RetryingHelloWorldClient { */ public RetryingHelloWorldClient(String host, int port, boolean enableRetries) { - ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forAddress(host, port) - // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid - // needing certificates. - .usePlaintext(); + ManagedChannelBuilder channelBuilder + = Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()); if (enableRetries) { Map serviceConfig = getRetryingServiceConfig(); logger.info("Client started with retrying configuration: " + serviceConfig); diff --git a/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java b/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java index 0bff00a6988..165cc72ffa3 100644 --- a/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java @@ -19,8 +19,9 @@ import java.text.DecimalFormat; import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; @@ -43,7 +44,7 @@ public class RetryingHelloWorldServer { private void start() throws IOException { /* The port on which the server should run */ int port = 50051; - server = ServerBuilder.forPort(port) + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) .build() .start(); diff --git a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideClient.java b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideClient.java index 6958c643da7..f65b1215359 100644 --- a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideClient.java +++ b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideClient.java @@ -19,8 +19,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Message; import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.examples.routeguide.RouteGuideGrpc.RouteGuideBlockingStub; @@ -259,7 +260,8 @@ public static void main(String[] args) throws InterruptedException { return; } - ManagedChannel channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build(); + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); try { RouteGuideClient client = new RouteGuideClient(channel); // Looking for a valid feature diff --git a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java index c91544ae45d..b39b06a6f92 100644 --- a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java +++ b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java @@ -25,6 +25,8 @@ import static java.lang.Math.toRadians; import static java.util.concurrent.TimeUnit.NANOSECONDS; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.stub.StreamObserver; @@ -55,7 +57,8 @@ public RouteGuideServer(int port) throws IOException { /** Create a RouteGuide server listening on {@code port} using {@code featureFile} database. */ public RouteGuideServer(int port, URL featureFile) throws IOException { - this(ServerBuilder.forPort(port), port, RouteGuideUtil.parseFeatures(featureFile)); + this(Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()), + port, RouteGuideUtil.parseFeatures(featureFile)); } /** Create a RouteGuide server using serverBuilder as a base and features as data. */ diff --git a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java index d0262d687ee..8c6cf60279a 100644 --- a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java +++ b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java @@ -40,8 +40,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and * {@link io.grpc.examples.routeguide.RouteGuideServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java index 9a20476772c..63281eeba1a 100644 --- a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java +++ b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java @@ -33,8 +33,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and * {@link io.grpc.examples.routeguide.RouteGuideServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java index be2337cc4ce..4c184fb82ee 100644 --- a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java +++ b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java @@ -53,8 +53,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and * {@link io.grpc.examples.helloworld.HelloWorldServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java index 19322c2d72c..a5a84824af6 100644 --- a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java +++ b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java @@ -50,8 +50,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and * {@link io.grpc.examples.helloworld.HelloWorldServerTest}. diff --git a/gae-interop-testing/gae-jdk8/build.gradle b/gae-interop-testing/gae-jdk8/build.gradle index 325e4651e0b..03d603bd3cf 100644 --- a/gae-interop-testing/gae-jdk8/build.gradle +++ b/gae-interop-testing/gae-jdk8/build.gradle @@ -27,6 +27,8 @@ buildscript { plugins { id "java" id "war" + + id "ru.vyarus.animalsniffer" } description = 'gRPC: gae interop testing (jdk8)' @@ -51,11 +53,12 @@ dependencies { exclude group: 'io.grpc', module: 'grpc-xds' } implementation libraries.junit - implementation libraries.protobuf - runtimeOnly libraries.netty_tcnative + implementation libraries.protobuf.java + runtimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes + signature libraries.signature.java } -compileJava { +tasks.named("compileJava").configure { // Disable "No processor claimed any of these annotations: org.junit.Ignore" options.compilerArgs += ["-Xlint:-processing"] } @@ -118,7 +121,8 @@ String getAppUrl(String project, String service, String version) { return "http://${version}.${service}.${project}.appspot.com" } -task runInteropTestRemote(dependsOn: 'appengineDeploy') { +tasks.register("runInteropTestRemote") { + dependsOn appengineDeploy doLast { // give remote app some time to settle down sleep(20000) diff --git a/gae-interop-testing/gae-jdk8/src/main/java/io/grpc/testing/integration/NettyClientInteropServlet.java b/gae-interop-testing/gae-jdk8/src/main/java/io/grpc/testing/integration/NettyClientInteropServlet.java index 6af02227143..48978fac0b3 100644 --- a/gae-interop-testing/gae-jdk8/src/main/java/io/grpc/testing/integration/NettyClientInteropServlet.java +++ b/gae-interop-testing/gae-jdk8/src/main/java/io/grpc/testing/integration/NettyClientInteropServlet.java @@ -28,6 +28,7 @@ import java.io.StringWriter; import java.text.SimpleDateFormat; import java.util.Calendar; +import java.util.Locale; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.logging.Handler; @@ -102,6 +103,7 @@ private void doGetHelper(HttpServletResponse resp) throws IOException { resp.setStatus(200); writer.println( String.format( + Locale.US, "PASS! Tests ran %d, tests ignored %d", result.getRunCount(), result.getIgnoreCount())); @@ -109,6 +111,7 @@ private void doGetHelper(HttpServletResponse resp) throws IOException { resp.setStatus(500); writer.println( String.format( + Locale.US, "FAILED! Tests ran %d, tests failed %d, tests ignored %d", result.getRunCount(), result.getFailureCount(), diff --git a/gcp-observability/build.gradle b/gcp-observability/build.gradle new file mode 100644 index 00000000000..f03d6f9620a --- /dev/null +++ b/gcp-observability/build.gradle @@ -0,0 +1,56 @@ +plugins { + id "java-library" + id "maven-publish" + + id "com.google.protobuf" + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: Google Cloud Platform Observability" + +tasks.named("compileJava").configure { + it.options.compilerArgs += [ + // only has AutoValue annotation processor + "-Xlint:-processing" + ] + appendToProperty( + it.options.errorprone.excludedPaths, + ".*/build/generated/sources/annotationProcessor/java/.*", + "|") +} + +dependencies { + def cloudLoggingVersion = '3.6.1' + + annotationProcessor libraries.auto.value + api project(':grpc-api') + + implementation project(':grpc-protobuf'), + project(':grpc-stub'), + project(':grpc-alts'), + project(':grpc-census'), + ("com.google.cloud:google-cloud-logging:${cloudLoggingVersion}"), + libraries.opencensus.contrib.grpc.metrics, + libraries.opencensus.exporter.stats.stackdriver, + libraries.opencensus.exporter.trace.stackdriver, + libraries.animalsniffer.annotations, // Prefer our version + libraries.google.auth.credentials, // Prefer our version + libraries.protobuf.java.util, // Prefer our version + libraries.gson, // Prefer our version + libraries.perfmark.api, // Prefer our version + ('com.google.guava:guava:31.1-jre') + + runtimeOnly libraries.opencensus.impl + + testImplementation project(':grpc-context').sourceSets.test.output, + project(':grpc-testing'), + project(':grpc-testing-proto'), + project(':grpc-netty-shaded') + testImplementation (libraries.guava.testlib) { + exclude group: 'junit', module: 'junit' + } + + signature libraries.signature.java +} + +configureProtoCompilation() diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java new file mode 100644 index 00000000000..770d764e0cd --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java @@ -0,0 +1,206 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; +import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; +import io.grpc.InternalGlobalInterceptors; +import io.grpc.ManagedChannelProvider.ProviderNotFoundException; +import io.grpc.ServerInterceptor; +import io.grpc.ServerStreamTracer; +import io.grpc.census.InternalCensusStatsAccessor; +import io.grpc.census.InternalCensusTracingAccessor; +import io.grpc.gcp.observability.interceptors.ConditionalClientInterceptor; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper; +import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; +import io.grpc.gcp.observability.interceptors.LogHelper; +import io.grpc.gcp.observability.logging.GcpLogSink; +import io.grpc.gcp.observability.logging.Sink; +import io.opencensus.common.Duration; +import io.opencensus.contrib.grpc.metrics.RpcViewConstants; +import io.opencensus.exporter.stats.stackdriver.StackdriverStatsConfiguration; +import io.opencensus.exporter.stats.stackdriver.StackdriverStatsExporter; +import io.opencensus.exporter.trace.stackdriver.StackdriverTraceConfiguration; +import io.opencensus.exporter.trace.stackdriver.StackdriverTraceExporter; +import io.opencensus.metrics.LabelKey; +import io.opencensus.metrics.LabelValue; +import io.opencensus.stats.Stats; +import io.opencensus.stats.ViewManager; +import io.opencensus.trace.AttributeValue; +import io.opencensus.trace.Tracing; +import io.opencensus.trace.config.TraceConfig; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; +import java.util.stream.Collectors; + +/** The main class for gRPC Google Cloud Platform Observability features. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8869") +public final class GcpObservability implements AutoCloseable { + private static final int METRICS_EXPORT_INTERVAL = 30; + private static final ImmutableSet SERVICES_TO_EXCLUDE = ImmutableSet.of( + "google.logging.v2.LoggingServiceV2", "google.monitoring.v3.MetricService", + "google.devtools.cloudtrace.v2.TraceService"); + private static GcpObservability instance = null; + private final Sink sink; + private final ObservabilityConfig config; + private final ArrayList clientInterceptors = new ArrayList<>(); + private final ArrayList serverInterceptors = new ArrayList<>(); + private final ArrayList tracerFactories = new ArrayList<>(); + + /** + * Initialize grpc-observability. + * + * @throws ProviderNotFoundException if no underlying channel/server provider is available. + */ + public static synchronized GcpObservability grpcInit() throws IOException { + if (instance == null) { + GlobalLocationTags globalLocationTags = new GlobalLocationTags(); + ObservabilityConfigImpl observabilityConfig = ObservabilityConfigImpl.getInstance(); + Sink sink = new GcpLogSink(observabilityConfig.getProjectId(), + globalLocationTags.getLocationTags(), observabilityConfig.getCustomTags(), + SERVICES_TO_EXCLUDE); + LogHelper helper = new LogHelper(sink); + ConfigFilterHelper configFilterHelper = ConfigFilterHelper.getInstance(observabilityConfig); + instance = grpcInit(sink, observabilityConfig, + new InternalLoggingChannelInterceptor.FactoryImpl(helper, configFilterHelper), + new InternalLoggingServerInterceptor.FactoryImpl(helper, configFilterHelper)); + instance.registerStackDriverExporter(observabilityConfig.getProjectId(), + observabilityConfig.getCustomTags()); + } + return instance; + } + + @VisibleForTesting + static GcpObservability grpcInit( + Sink sink, + ObservabilityConfig config, + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory, + InternalLoggingServerInterceptor.Factory serverInterceptorFactory) + throws IOException { + if (instance == null) { + instance = new GcpObservability(sink, config); + instance.setProducer(channelInterceptorFactory, serverInterceptorFactory); + } + return instance; + } + + /** Un-initialize/shutdown grpc-observability. */ + @Override + public void close() { + synchronized (GcpObservability.class) { + if (instance == null) { + throw new IllegalStateException("GcpObservability already closed!"); + } + sink.close(); + instance = null; + } + } + + // TODO(dnvindhya): Remove InterceptorFactory and replace with respective + // interceptors + private void setProducer( + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory, + InternalLoggingServerInterceptor.Factory serverInterceptorFactory) { + if (config.isEnableCloudLogging()) { + clientInterceptors.add(channelInterceptorFactory.create()); + serverInterceptors.add(serverInterceptorFactory.create()); + } + if (config.isEnableCloudMonitoring()) { + clientInterceptors.add(getConditionalInterceptor( + InternalCensusStatsAccessor.getClientInterceptor(true, true, true, true))); + tracerFactories.add( + InternalCensusStatsAccessor.getServerStreamTracerFactory(true, true, true)); + } + if (config.isEnableCloudTracing()) { + clientInterceptors.add( + getConditionalInterceptor(InternalCensusTracingAccessor.getClientInterceptor(false))); + tracerFactories.add(InternalCensusTracingAccessor.getServerStreamTracerFactory(false)); + } + + InternalGlobalInterceptors.setInterceptorsTracers( + clientInterceptors, serverInterceptors, tracerFactories); + } + + static ConditionalClientInterceptor getConditionalInterceptor(ClientInterceptor interceptor) { + return new ConditionalClientInterceptor(interceptor, + (m, c) -> !SERVICES_TO_EXCLUDE.contains(m.getServiceName())); + } + + private static void registerObservabilityViews() { + ViewManager viewManager = Stats.getViewManager(); + + // client views + viewManager.registerView(RpcViewConstants.GRPC_CLIENT_COMPLETED_RPC_VIEW); + viewManager.registerView(RpcViewConstants.GRPC_CLIENT_STARTED_RPC_VIEW); + + // server views + viewManager.registerView(RpcViewConstants.GRPC_SERVER_COMPLETED_RPC_VIEW); + viewManager.registerView(RpcViewConstants.GRPC_SERVER_STARTED_RPC_VIEW); + } + + @VisibleForTesting + void registerStackDriverExporter(String projectId, Map customTags) + throws IOException { + if (config.isEnableCloudMonitoring()) { + registerObservabilityViews(); + StackdriverStatsConfiguration.Builder statsConfigurationBuilder = + StackdriverStatsConfiguration.builder(); + if (projectId != null) { + statsConfigurationBuilder.setProjectId(projectId); + } + if (customTags != null) { + Map constantLabels = customTags.entrySet().stream() + .collect(Collectors.toMap(e -> LabelKey.create(e.getKey(), e.getKey()), + e -> LabelValue.create(e.getValue()))); + statsConfigurationBuilder.setConstantLabels(constantLabels); + } + statsConfigurationBuilder.setExportInterval(Duration.create(METRICS_EXPORT_INTERVAL, 0)); + StackdriverStatsExporter.createAndRegister(statsConfigurationBuilder.build()); + } + + if (config.isEnableCloudTracing()) { + TraceConfig traceConfig = Tracing.getTraceConfig(); + traceConfig.updateActiveTraceParams( + traceConfig.getActiveTraceParams().toBuilder().setSampler(config.getSampler()).build()); + StackdriverTraceConfiguration.Builder traceConfigurationBuilder = + StackdriverTraceConfiguration.builder(); + if (projectId != null) { + traceConfigurationBuilder.setProjectId(projectId); + } + if (customTags != null) { + Map fixedAttributes = customTags.entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey(), + e -> AttributeValue.stringAttributeValue(e.getValue()))); + traceConfigurationBuilder.setFixedAttributes(fixedAttributes); + } + StackdriverTraceExporter.createAndRegister(traceConfigurationBuilder.build()); + } + } + + private GcpObservability( + Sink sink, + ObservabilityConfig config) { + this.sink = checkNotNull(sink); + this.config = checkNotNull(config); + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/GlobalLocationTags.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/GlobalLocationTags.java new file mode 100644 index 00000000000..045671814f0 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/GlobalLocationTags.java @@ -0,0 +1,148 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.javanet.NetHttpTransport; +import com.google.api.client.util.Strings; +import com.google.auth.http.HttpTransportFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Scanner; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** A container of all global location tags used for observability. */ +final class GlobalLocationTags { + private static final Logger logger = Logger.getLogger(GlobalLocationTags.class.getName()); + + private final Map locationTags; + + GlobalLocationTags() { + ImmutableMap.Builder locationTagsBuilder = ImmutableMap.builder(); + populate(locationTagsBuilder); + locationTags = locationTagsBuilder.buildOrThrow(); + } + + private static String applyTrim(String value) { + if (!Strings.isNullOrEmpty(value)) { + value = value.trim(); + } + return value; + } + + Map getLocationTags() { + return locationTags; + } + + @VisibleForTesting + static void populateFromMetadataServer(ImmutableMap.Builder locationTags) { + MetadataConfig metadataConfig = new MetadataConfig(new DefaultHttpTransportFactory()); + metadataConfig.init(); + locationTags.putAll(metadataConfig.getAllValues()); + } + + @VisibleForTesting + static void populateFromKubernetesValues(ImmutableMap.Builder locationTags, + String namespaceFile, + String hostnameFile, String cgroupFile) { + // namespace name: contents of file /var/run/secrets/kubernetes.io/serviceaccount/namespace + populateFromFileContents(locationTags, "namespace_name", + namespaceFile, GlobalLocationTags::applyTrim); + + // pod_name: hostname i.e. contents of /etc/hostname + populateFromFileContents(locationTags, "pod_name", hostnameFile, + GlobalLocationTags::applyTrim); + + // container_id: parsed from /proc/self/cgroup . Note: only works for Linux-based containers + populateFromFileContents(locationTags, "container_id", cgroupFile, + (value) -> getContainerIdFromFileContents(value)); + } + + @VisibleForTesting + static void populateFromFileContents(ImmutableMap.Builder locationTags, + String key, String filePath, Function parser) { + String value = parser.apply(readFileContents(filePath)); + if (value != null) { + locationTags.put(key, value); + } + } + + /** + * Parse from a line such as this. + * 1:name=systemd:/kubepods/burstable/podf5143dd2/de67c4419b20924eaa141813 + * + * @param value file contents + * @return container-id parsed ("podf5143dd2/de67c4419b20924eaa141813" from the above snippet) + */ + @VisibleForTesting static String getContainerIdFromFileContents(String value) { + if (value != null) { + try (Scanner scanner = new Scanner(value)) { + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + String[] tokens = line.split(":"); + if (tokens.length == 3 && tokens[2].startsWith("/kubepods/burstable/")) { + tokens = tokens[2].split("/"); + if (tokens.length == 5) { + return tokens[4]; + } + } + } + } + } + return null; + } + + private static String readFileContents(String file) { + Path fileName = Paths.get(file); + if (Files.isReadable(fileName)) { + try { + byte[] bytes = Files.readAllBytes(fileName); + return new String(bytes, Charsets.US_ASCII); + } catch (IOException e) { + logger.log(Level.FINE, "Reading file:" + file, e); + } + } else { + logger.log(Level.FINE, "File:" + file + " is not readable (or missing?)"); + } + return null; + } + + static void populate(ImmutableMap.Builder locationTags) { + populateFromMetadataServer(locationTags); + populateFromKubernetesValues(locationTags, + "/var/run/secrets/kubernetes.io/serviceaccount/namespace", + "/etc/hostname", "/proc/self/cgroup"); + } + + private static class DefaultHttpTransportFactory implements HttpTransportFactory { + + private static final HttpTransport netHttpTransport = new NetHttpTransport(); + + @Override + public HttpTransport create() { + return netHttpTransport; + } + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/MetadataConfig.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/MetadataConfig.java new file mode 100644 index 00000000000..071636e2a92 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/MetadataConfig.java @@ -0,0 +1,107 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpStatusCodes; +import com.google.api.client.http.HttpTransport; +import com.google.auth.http.HttpTransportFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.io.InputStream; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Class to read Google Metadata Server values. */ +final class MetadataConfig { + private static final Logger logger = Logger.getLogger(MetadataConfig.class.getName()); + + private static final int TIMEOUT_MS = 5000; + private static final String METADATA_URL = "http://metadata.google.internal/computeMetadata/v1/"; + private HttpRequestFactory requestFactory; + private HttpTransportFactory transportFactory; + + @VisibleForTesting public MetadataConfig(HttpTransportFactory transportFactory) { + this.transportFactory = transportFactory; + + } + + void init() { + HttpTransport httpTransport = transportFactory.create(); + requestFactory = httpTransport.createRequestFactory(); + } + + /** gets all the values from the MDS we need to set in our logging tags. */ + ImmutableMap getAllValues() { + ImmutableMap.Builder builder = ImmutableMap.builder(); + //addValueFor(builder, "instance/hostname", "GCE_INSTANCE_HOSTNAME"); + addValueFor(builder, "instance/id", "gke_node_id"); + //addValueFor(builder, "instance/zone", "GCE_INSTANCE_ZONE"); + addValueFor(builder, "project/project-id", "project_id"); + addValueFor(builder, "project/numeric-project-id", "project_numeric_id"); + addValueFor(builder, "instance/attributes/cluster-name", "cluster_name"); + addValueFor(builder, "instance/attributes/cluster-uid", "cluster_uid"); + addValueFor(builder, "instance/attributes/cluster-location", "location"); + try { + requestFactory.getTransport().shutdown(); + } catch (IOException e) { + logger.log(Level.FINE, "Calling HttpTransport.shutdown()", e); + } + return builder.buildOrThrow(); + } + + void addValueFor(ImmutableMap.Builder builder, String attribute, String key) { + try { + String value = getAttribute(attribute); + if (value != null) { + builder.put(key, value); + } + } catch (IOException e) { + logger.log(Level.FINE, "Calling getAttribute('" + attribute + "')", e); + } + } + + String getAttribute(String attributeName) throws IOException { + GenericUrl url = new GenericUrl(METADATA_URL + attributeName); + HttpRequest request = requestFactory.buildGetRequest(url); + request = request.setReadTimeout(TIMEOUT_MS); + request = request.setConnectTimeout(TIMEOUT_MS); + request = request.setHeaders(new HttpHeaders().set("Metadata-Flavor", "Google")); + HttpResponse response = null; + try { + response = request.execute(); + if (response.getStatusCode() == HttpStatusCodes.STATUS_CODE_OK) { + InputStream stream = response.getContent(); + if (stream != null) { + byte[] bytes = new byte[stream.available()]; + stream.read(bytes); + return new String(bytes, response.getContentCharset()); + } + } + } finally { + if (response != null) { + response.disconnect(); + } + } + return null; + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java new file mode 100644 index 00000000000..0489c8b5e3b --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import io.grpc.Internal; +import io.opencensus.trace.Sampler; +import java.util.List; +import java.util.Map; +import java.util.Set; +import javax.annotation.concurrent.ThreadSafe; + +@Internal +public interface ObservabilityConfig { + /** Is Cloud Logging enabled. */ + boolean isEnableCloudLogging(); + + /** Is Cloud Monitoring enabled. */ + boolean isEnableCloudMonitoring(); + + /** Is Cloud Tracing enabled. */ + boolean isEnableCloudTracing(); + + /** Get project ID - where logs will go. */ + String getProjectId(); + + /** Get filters for client logging. */ + List getClientLogFilters(); + + /** Get filters for server logging. */ + List getServerLogFilters(); + + /** Get sampler for TraceConfig - when Cloud Tracing is enabled. */ + Sampler getSampler(); + + /** Map of all custom tags used for logging, metrics and traces. */ + Map getCustomTags(); + + /** + * POJO for representing a filter used in configuration. + */ + @ThreadSafe + class LogFilter { + /** Set of services. */ + public final Set services; + + /* Set of fullMethodNames. */ + public final Set methods; + + /** Boolean to indicate all services and methods. */ + public final boolean matchAll; + + /** Number of bytes of header to log. */ + public final int headerBytes; + + /** Number of bytes of message to log. */ + public final int messageBytes; + + /** Boolean to indicate if services and methods matching pattern needs to be excluded. */ + public final boolean excludePattern; + + /** + * Object used to represent filter used in configuration. + * @param services Set of services derived from pattern + * @param serviceMethods Set of fullMethodNames derived from pattern + * @param matchAll If true, match all services and methods + * @param headerBytes Total number of bytes of header to log + * @param messageBytes Total number of bytes of message to log + * @param excludePattern If true, services and methods matching pattern be excluded + */ + public LogFilter(Set services, Set serviceMethods, boolean matchAll, + int headerBytes, int messageBytes, + boolean excludePattern) { + this.services = services; + this.methods = serviceMethods; + this.matchAll = matchAll; + this.headerBytes = headerBytes; + this.messageBytes = messageBytes; + this.excludePattern = excludePattern; + } + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java new file mode 100644 index 00000000000..2b0a44473d0 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java @@ -0,0 +1,276 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.ServiceOptions; +import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.grpc.internal.JsonParser; +import io.grpc.internal.JsonUtil; +import io.opencensus.trace.Sampler; +import io.opencensus.trace.samplers.Samplers; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * gRPC GcpObservability configuration processor. + */ +final class ObservabilityConfigImpl implements ObservabilityConfig { + private static final Logger logger = Logger + .getLogger(ObservabilityConfigImpl.class.getName()); + private static final String CONFIG_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG"; + private static final String CONFIG_FILE_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG_FILE"; + // Tolerance for floating-point comparisons. + private static final double EPSILON = 1e-6; + + private static final Pattern METHOD_NAME_REGEX = + Pattern.compile("^([*])|((([\\w.]+)/((?:\\w+)|[*])))$"); + + private boolean enableCloudLogging = false; + private boolean enableCloudMonitoring = false; + private boolean enableCloudTracing = false; + private String projectId = null; + + private List clientLogFilters; + private List serverLogFilters; + private Sampler sampler; + private Map customTags; + + static ObservabilityConfigImpl getInstance() throws IOException { + ObservabilityConfigImpl config = new ObservabilityConfigImpl(); + String configFile = System.getenv(CONFIG_FILE_ENV_VAR_NAME); + if (configFile != null) { + config.parseFile(configFile); + } else { + config.parse(System.getenv(CONFIG_ENV_VAR_NAME)); + } + return config; + } + + void parseFile(String configFile) throws IOException { + String configFileContent = + new String(Files.readAllBytes(Paths.get(configFile)), Charsets.UTF_8); + checkArgument(!configFileContent.isEmpty(), CONFIG_FILE_ENV_VAR_NAME + " is empty!"); + parse(configFileContent); + } + + @SuppressWarnings("unchecked") + void parse(String config) throws IOException { + checkArgument(config != null, CONFIG_ENV_VAR_NAME + " value is null!"); + parseConfig((Map) JsonParser.parse(config)); + } + + private void parseConfig(Map config) { + checkArgument(config != null, "Invalid configuration"); + if (config.isEmpty()) { + clientLogFilters = Collections.emptyList(); + serverLogFilters = Collections.emptyList(); + customTags = Collections.emptyMap(); + return; + } + projectId = fetchProjectId(JsonUtil.getString(config, "project_id")); + + Map rawCloudLoggingObject = JsonUtil.getObject(config, "cloud_logging"); + if (rawCloudLoggingObject != null) { + enableCloudLogging = true; + ImmutableList.Builder clientFiltersBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder serverFiltersBuilder = new ImmutableList.Builder<>(); + parseLoggingObject(rawCloudLoggingObject, clientFiltersBuilder, serverFiltersBuilder); + clientLogFilters = clientFiltersBuilder.build(); + serverLogFilters = serverFiltersBuilder.build(); + } + + Map rawCloudMonitoringObject = JsonUtil.getObject(config, "cloud_monitoring"); + if (rawCloudMonitoringObject != null) { + enableCloudMonitoring = true; + } + + Map rawCloudTracingObject = JsonUtil.getObject(config, "cloud_trace"); + if (rawCloudTracingObject != null) { + enableCloudTracing = true; + sampler = parseTracingObject(rawCloudTracingObject); + } + + Map rawCustomTagsObject = JsonUtil.getObject(config, "labels"); + if (rawCustomTagsObject != null) { + customTags = parseCustomTags(rawCustomTagsObject); + } + + if (clientLogFilters == null) { + clientLogFilters = Collections.emptyList(); + } + if (serverLogFilters == null) { + serverLogFilters = Collections.emptyList(); + } + if (customTags == null) { + customTags = Collections.emptyMap(); + } + } + + private static String fetchProjectId(String configProjectId) { + // If project_id is not specified in config, get default GCP project id from the environment + String projectId = configProjectId != null ? configProjectId : getDefaultGcpProjectId(); + checkArgument(projectId != null, "Unable to detect project_id"); + logger.log(Level.FINEST, "Found project ID : ", projectId); + return projectId; + } + + private static String getDefaultGcpProjectId() { + return ServiceOptions.getDefaultProjectId(); + } + + private static void parseLoggingObject( + Map rawLoggingConfig, + ImmutableList.Builder clientFilters, + ImmutableList.Builder serverFilters) { + parseRpcEvents(JsonUtil.getList(rawLoggingConfig, "client_rpc_events"), clientFilters); + parseRpcEvents(JsonUtil.getList(rawLoggingConfig, "server_rpc_events"), serverFilters); + } + + private static Sampler parseTracingObject(Map rawCloudTracingConfig) { + Sampler defaultSampler = Samplers.probabilitySampler(0.0); + Double samplingRate = JsonUtil.getNumberAsDouble(rawCloudTracingConfig, "sampling_rate"); + if (samplingRate == null) { + return defaultSampler; + } + checkArgument(samplingRate >= 0.0 && samplingRate <= 1.0, + "'sampling_rate' needs to be between [0.0, 1.0]"); + // Using alwaysSample() instead of probabilitySampler() because according to + // {@link io.opencensus.trace.samplers.ProbabilitySampler#shouldSample} + // there is a (very) small chance of *not* sampling if probability = 1.00. + return 1 - samplingRate < EPSILON ? Samplers.alwaysSample() + : Samplers.probabilitySampler(samplingRate); + } + + private static Map parseCustomTags(Map rawCustomTags) { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (Map.Entry entry: rawCustomTags.entrySet()) { + checkArgument( + entry.getValue() instanceof String, + "'labels' needs to be a map of "); + builder.put(entry.getKey(), (String) entry.getValue()); + } + return builder.build(); + } + + private static void parseRpcEvents(List rpcEvents, ImmutableList.Builder filters) { + if (rpcEvents == null) { + return; + } + List> jsonRpcEvents = JsonUtil.checkObjectList(rpcEvents); + for (Map jsonClientRpcEvent : jsonRpcEvents) { + filters.add(parseJsonLogFilter(jsonClientRpcEvent)); + } + } + + private static LogFilter parseJsonLogFilter(Map logFilterMap) { + ImmutableSet.Builder servicesSetBuilder = new ImmutableSet.Builder<>(); + ImmutableSet.Builder methodsSetBuilder = new ImmutableSet.Builder<>(); + boolean wildCardFilter = false; + + boolean excludeFilter = + Boolean.TRUE.equals(JsonUtil.getBoolean(logFilterMap, "exclude")); + List methodsList = JsonUtil.getListOfStrings(logFilterMap, "methods"); + if (methodsList != null) { + wildCardFilter = extractMethodOrServicePattern( + methodsList, excludeFilter, servicesSetBuilder, methodsSetBuilder); + } + Integer maxHeaderBytes = JsonUtil.getNumberAsInteger(logFilterMap, "max_metadata_bytes"); + Integer maxMessageBytes = JsonUtil.getNumberAsInteger(logFilterMap, "max_message_bytes"); + + return new LogFilter( + servicesSetBuilder.build(), + methodsSetBuilder.build(), + wildCardFilter, + maxHeaderBytes != null ? maxHeaderBytes.intValue() : 0, + maxMessageBytes != null ? maxMessageBytes.intValue() : 0, + excludeFilter); + } + + private static boolean extractMethodOrServicePattern(List patternList, boolean exclude, + ImmutableSet.Builder servicesSetBuilder, + ImmutableSet.Builder methodsSetBuilder) { + boolean globalFilter = false; + for (String methodOrServicePattern : patternList) { + Matcher matcher = METHOD_NAME_REGEX.matcher(methodOrServicePattern); + checkArgument( + matcher.matches(), "invalid service or method filter : " + methodOrServicePattern); + if ("*".equals(methodOrServicePattern)) { + checkArgument(!exclude, "cannot have 'exclude' and '*' wildcard in the same filter"); + globalFilter = true; + } else if ("*".equals(matcher.group(5))) { + String service = matcher.group(4); + servicesSetBuilder.add(service); + } else { + methodsSetBuilder.add(methodOrServicePattern); + } + } + return globalFilter; + } + + @Override + public boolean isEnableCloudLogging() { + return enableCloudLogging; + } + + @Override + public boolean isEnableCloudMonitoring() { + return enableCloudMonitoring; + } + + @Override + public boolean isEnableCloudTracing() { + return enableCloudTracing; + } + + @Override + public String getProjectId() { + return projectId; + } + + @Override + public List getClientLogFilters() { + return clientLogFilters; + } + + @Override + public List getServerLogFilters() { + return serverLogFilters; + } + + @Override + public Sampler getSampler() { + return sampler; + } + + @Override + public Map getCustomTags() { + return customTags; + } +} diff --git a/observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingChannelInterceptor.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptor.java similarity index 53% rename from observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingChannelInterceptor.java rename to gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptor.java index 3e535de3816..5051453ce0e 100644 --- a/observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingChannelInterceptor.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptor.java @@ -14,7 +14,9 @@ * limitations under the License. */ -package io.grpc.observability.interceptors; +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.CallOptions; import io.grpc.Channel; @@ -22,27 +24,29 @@ import io.grpc.ClientInterceptor; import io.grpc.Internal; import io.grpc.MethodDescriptor; +import java.util.function.BiPredicate; -/** A logging interceptor for {@code LoggingChannelProvider}. */ +/** + * A client interceptor that conditionally calls a delegated interceptor. + */ @Internal -public final class InternalLoggingChannelInterceptor implements ClientInterceptor { - - public interface Factory { - ClientInterceptor create(); - } +public final class ConditionalClientInterceptor implements ClientInterceptor { - public static class FactoryImpl implements Factory { + private final ClientInterceptor delegate; + private final BiPredicate, CallOptions> predicate; - @Override - public ClientInterceptor create() { - return new InternalLoggingChannelInterceptor(); - } + public ConditionalClientInterceptor(ClientInterceptor delegate, + BiPredicate, CallOptions> predicate) { + this.delegate = checkNotNull(delegate, "delegate"); + this.predicate = checkNotNull(predicate, "predicate"); } @Override public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { - // TODO(dnvindhya) implement the interceptor - return null; + if (!predicate.test(method, callOptions)) { + return next.newCall(method, callOptions); + } + return delegate.interceptCall(method, callOptions, next); } } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java new file mode 100644 index 00000000000..9b05634dbfe --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java @@ -0,0 +1,107 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Internal; +import io.grpc.gcp.observability.ObservabilityConfig; +import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; +import java.util.List; + +/** + * Parses gRPC GcpObservability configuration filters for interceptors usage. + */ +@Internal +public class ConfigFilterHelper { + public static final FilterParams NO_FILTER_PARAMS + = FilterParams.create(false, 0, 0); + + private final ObservabilityConfig config; + + private ConfigFilterHelper(ObservabilityConfig config) { + this.config = config; + } + + /** + * Creates and returns helper instance for log filtering. + * + * @param config processed ObservabilityConfig object + * @return helper instance for filtering + */ + public static ConfigFilterHelper getInstance(ObservabilityConfig config) { + return new ConfigFilterHelper(config); + } + + + /** + * Checks if the corresponding service/method passed needs to be logged according to user provided + * observability configuration. + * Filters are evaluated in text order, first match is used. + * + * @param fullMethodName the fully qualified name of the method + * @param client set to true if method being checked is a client method; false otherwise + * @return FilterParams object 1. specifies if the corresponding method needs to be logged + * (log field will be set to true) 2. values of payload limits retrieved from configuration + */ + public FilterParams logRpcMethod(String fullMethodName, boolean client) { + FilterParams params = NO_FILTER_PARAMS; + + int index = checkNotNull(fullMethodName, "fullMethodName").lastIndexOf('/'); + String serviceName = fullMethodName.substring(0, index); + + List logFilters = + client ? config.getClientLogFilters() : config.getServerLogFilters(); + + // TODO (dnvindhya): Optimize by caching results for fullMethodName. + for (LogFilter logFilter : logFilters) { + if (logFilter.matchAll + || logFilter.services.contains(serviceName) + || logFilter.methods.contains(fullMethodName)) { + if (logFilter.excludePattern) { + return params; + } + int currentHeaderBytes = logFilter.headerBytes; + int currentMessageBytes = logFilter.messageBytes; + return FilterParams.create(true, currentHeaderBytes, currentMessageBytes); + } + } + return params; + } + + /** + * Class containing results for method/service filter information, such as flag for logging + * method/service and payload limits to be used for filtering. + */ + @AutoValue + public abstract static class FilterParams { + + abstract boolean log(); + + abstract int headerBytes(); + + abstract int messageBytes(); + + @VisibleForTesting + public static FilterParams create(boolean log, int headerBytes, int messageBytes) { + return new AutoValue_ConfigFilterHelper_FilterParams( + log, headerBytes, messageBytes); + } + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InetAddressUtil.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InetAddressUtil.java new file mode 100644 index 00000000000..376ada4f736 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InetAddressUtil.java @@ -0,0 +1,94 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.primitives.Ints; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.util.Arrays; + +// This is copied from guava 20.0 because it is a @Beta api +final class InetAddressUtil { + private static final int IPV6_PART_COUNT = 8; + + public static String toAddrString(InetAddress ip) { + checkNotNull(ip); + if (ip instanceof Inet4Address) { + // For IPv4, Java's formatting is good enough. + return ip.getHostAddress(); + } + checkArgument(ip instanceof Inet6Address); + byte[] bytes = ip.getAddress(); + int[] hextets = new int[IPV6_PART_COUNT]; + for (int i = 0; i < hextets.length; i++) { + hextets[i] = Ints.fromBytes((byte) 0, (byte) 0, bytes[2 * i], bytes[2 * i + 1]); + } + compressLongestRunOfZeroes(hextets); + return hextetsToIPv6String(hextets); + } + + private static void compressLongestRunOfZeroes(int[] hextets) { + int bestRunStart = -1; + int bestRunLength = -1; + int runStart = -1; + for (int i = 0; i < hextets.length + 1; i++) { + if (i < hextets.length && hextets[i] == 0) { + if (runStart < 0) { + runStart = i; + } + } else if (runStart >= 0) { + int runLength = i - runStart; + if (runLength > bestRunLength) { + bestRunStart = runStart; + bestRunLength = runLength; + } + runStart = -1; + } + } + if (bestRunLength >= 2) { + Arrays.fill(hextets, bestRunStart, bestRunStart + bestRunLength, -1); + } + } + + private static String hextetsToIPv6String(int[] hextets) { + // While scanning the array, handle these state transitions: + // start->num => "num" start->gap => "::" + // num->num => ":num" num->gap => "::" + // gap->num => "num" gap->gap => "" + StringBuilder buf = new StringBuilder(39); + boolean lastWasNumber = false; + for (int i = 0; i < hextets.length; i++) { + boolean thisIsNumber = hextets[i] >= 0; + if (thisIsNumber) { + if (lastWasNumber) { + buf.append(':'); + } + buf.append(Integer.toHexString(hextets[i])); + } else { + if (i == 0 || lastWasNumber) { + buf.append("::"); + } + } + lastWasNumber = thisIsNumber; + } + return buf.toString(); + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java new file mode 100644 index 00000000000..517745a5afc --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java @@ -0,0 +1,257 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Context; +import io.grpc.Deadline; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Internal; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A logging client interceptor for Observability. + */ +@Internal +public final class InternalLoggingChannelInterceptor implements ClientInterceptor { + + private static final Logger logger = Logger + .getLogger(InternalLoggingChannelInterceptor.class.getName()); + + private final LogHelper helper; + private final ConfigFilterHelper filterHelper; + + // TODO(dnvindhya): Remove factory and use interceptors directly + public interface Factory { + ClientInterceptor create(); + } + + public static class FactoryImpl implements Factory { + + private final LogHelper helper; + private final ConfigFilterHelper filterHelper; + + /** + * Create the {@link Factory} we need to create our {@link ClientInterceptor}s. + */ + public FactoryImpl(LogHelper helper, ConfigFilterHelper filterHelper) { + this.helper = helper; + this.filterHelper = filterHelper; + } + + @Override + public ClientInterceptor create() { + return new InternalLoggingChannelInterceptor(helper, filterHelper); + } + } + + private InternalLoggingChannelInterceptor(LogHelper helper, ConfigFilterHelper filterHelper) { + this.helper = helper; + this.filterHelper = filterHelper; + } + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + + final AtomicLong seq = new AtomicLong(1); + final String callId = UUID.randomUUID().toString(); + final String authority = next.authority(); + final String serviceName = method.getServiceName(); + final String methodName = method.getBareMethodName(); + // Get the stricter deadline to calculate the timeout once the call starts + final Deadline deadline = LogHelper.min(callOptions.getDeadline(), + Context.current().getDeadline()); + + FilterParams filterParams = filterHelper.logRpcMethod(method.getFullMethodName(), true); + if (!filterParams.log()) { + return next.newCall(method, callOptions); + } + + final int maxHeaderBytes = filterParams.headerBytes(); + final int maxMessageBytes = filterParams.messageBytes(); + + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + + @Override + public void start(Listener responseListener, Metadata headers) { + // Event: EventType.CLIENT_HEADER + // The timeout should reflect the time remaining when the call is started, so compute + // remaining time here. + final Duration timeout = deadline == null ? null + : Durations.fromNanos(deadline.timeRemaining(TimeUnit.NANOSECONDS)); + + try { + helper.logClientHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + timeout, + headers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + null); + } catch (Exception e) { + // Catching generic exceptions instead of specific ones for all the events. + // This way we can catch both expected and unexpected exceptions instead of re-throwing + // exceptions to callers which will lead to RPC getting aborted. + // Expected exceptions to be caught: + // 1. IllegalArgumentException + // 2. NullPointerException + logger.log(Level.SEVERE, "Unable to log request header", e); + } + + Listener observabilityListener = + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onMessage(RespT message) { + // Event: EventType.SERVER_MESSAGE + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventType.SERVER_MESSAGE, + message, + maxMessageBytes, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response message", e); + } + super.onMessage(message); + } + + @Override + public void onHeaders(Metadata headers) { + // Event: EventType.SERVER_HEADER + try { + helper.logServerHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + headers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + LogHelper.getPeerAddress(getAttributes())); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response header", e); + } + super.onHeaders(headers); + } + + @Override + public void onClose(Status status, Metadata trailers) { + // Event: EventType.SERVER_TRAILER + try { + helper.logTrailer( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + status, + trailers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + LogHelper.getPeerAddress(getAttributes())); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log trailer", e); + } + super.onClose(status, trailers); + } + }; + super.start(observabilityListener, headers); + } + + @Override + public void sendMessage(ReqT message) { + // Event: EventType.CLIENT_MESSAGE + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + message, + maxMessageBytes, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log request message", e); + } + super.sendMessage(message); + } + + @Override + public void halfClose() { + // Event: EventType.CLIENT_HALF_CLOSE + try { + helper.logHalfClose( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log half close", e); + } + super.halfClose(); + } + + @Override + public void cancel(String message, Throwable cause) { + // Event: EventType.CANCEL + try { + helper.logCancel( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log cancel", e); + } + super.cancel(message, cause); + } + }; + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java new file mode 100644 index 00000000000..acb8df29166 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java @@ -0,0 +1,251 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.Context; +import io.grpc.Deadline; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Internal; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import java.net.SocketAddress; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A logging server interceptor for Observability. + */ +@Internal +public final class InternalLoggingServerInterceptor implements ServerInterceptor { + + private static final Logger logger = Logger + .getLogger(InternalLoggingServerInterceptor.class.getName()); + + private final LogHelper helper; + private final ConfigFilterHelper filterHelper; + + // TODO(dnvindhya): Remove factory and use interceptors directly + public interface Factory { + ServerInterceptor create(); + } + + public static class FactoryImpl implements Factory { + + private final LogHelper helper; + private final ConfigFilterHelper filterHelper; + + /** + * Create the {@link Factory} we need to create our {@link ServerInterceptor}s. + */ + public FactoryImpl(LogHelper helper, ConfigFilterHelper filterHelper) { + this.helper = helper; + this.filterHelper = filterHelper; + } + + @Override + public ServerInterceptor create() { + return new InternalLoggingServerInterceptor(helper, filterHelper); + } + } + + private InternalLoggingServerInterceptor(LogHelper helper, ConfigFilterHelper filterHelper) { + this.helper = helper; + this.filterHelper = filterHelper; + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + final AtomicLong seq = new AtomicLong(1); + final String callId = UUID.randomUUID().toString(); + final String authority = call.getAuthority(); + final String serviceName = call.getMethodDescriptor().getServiceName(); + final String methodName = call.getMethodDescriptor().getBareMethodName(); + final SocketAddress peerAddress = LogHelper.getPeerAddress(call.getAttributes()); + Deadline deadline = Context.current().getDeadline(); + final Duration timeout = deadline == null ? null + : Durations.fromNanos(deadline.timeRemaining(TimeUnit.NANOSECONDS)); + + FilterParams filterParams = + filterHelper.logRpcMethod(call.getMethodDescriptor().getFullMethodName(), false); + if (!filterParams.log()) { + return next.startCall(call, headers); + } + + final int maxHeaderBytes = filterParams.headerBytes(); + final int maxMessageBytes = filterParams.messageBytes(); + + // Event: EventType.CLIENT_HEADER + try { + helper.logClientHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + timeout, + headers, + maxHeaderBytes, + EventLogger.SERVER, + callId, + peerAddress); + } catch (Exception e) { + // Catching generic exceptions instead of specific ones for all the events. + // This way we can catch both expected and unexpected exceptions instead of re-throwing + // exceptions to callers which will lead to RPC getting aborted. + // Expected exceptions to be caught: + // 1. IllegalArgumentException + // 2. NullPointerException + logger.log(Level.SEVERE, "Unable to log request header", e); + } + + ServerCall wrapperCall = + new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata headers) { + // Event: EventType.SERVER_HEADER + try { + helper.logServerHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + headers, + maxHeaderBytes, + EventLogger.SERVER, + callId, + null); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response header", e); + } + super.sendHeaders(headers); + } + + @Override + public void sendMessage(RespT message) { + // Event: EventType.SERVER_MESSAGE + EventType responseMessageType = EventType.SERVER_MESSAGE; + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + responseMessageType, + message, + maxMessageBytes, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response message", e); + } + super.sendMessage(message); + } + + @Override + public void close(Status status, Metadata trailers) { + // Event: EventType.SERVER_TRAILER + try { + helper.logTrailer( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + status, + trailers, + maxHeaderBytes, + EventLogger.SERVER, + callId, + null); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log trailer", e); + } + super.close(status, trailers); + } + }; + + ServerCall.Listener listener = next.startCall(wrapperCall, headers); + return new SimpleForwardingServerCallListener(listener) { + @Override + public void onMessage(ReqT message) { + + // Event: EventType.CLIENT_MESSAGE + EventType requestMessageType = EventType.CLIENT_MESSAGE; + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + requestMessageType, + message, + maxMessageBytes, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log request message", e); + } + super.onMessage(message); + } + + @Override + public void onHalfClose() { + // Event: EventType.CLIENT_HALF_CLOSE + try { + helper.logHalfClose( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log half close", e); + } + super.onHalfClose(); + } + + @Override + public void onCancel() { + // Event: EventType.CANCEL + try { + helper.logCancel( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log cancel", e); + } + super.onCancel(); + } + }; + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java new file mode 100644 index 00000000000..9b46699efaf --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java @@ -0,0 +1,447 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; + +import com.google.common.base.Joiner; +import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; +import io.grpc.Attributes; +import io.grpc.Deadline; +import io.grpc.Grpc; +import io.grpc.Internal; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.gcp.observability.logging.Sink; +import io.grpc.observabilitylog.v1.Address; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import io.grpc.observabilitylog.v1.Payload; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Helper class for GCP observability logging. + */ +@Internal +public class LogHelper { + private static final Logger logger = Logger.getLogger(LogHelper.class.getName()); + + // TODO(DNVindhya): Define it in one places(TBD) to make it easily accessible from everywhere + static final Metadata.Key STATUS_DETAILS_KEY = + Metadata.Key.of( + "grpc-status-details-bin", + Metadata.BINARY_BYTE_MARSHALLER); + + private final Sink sink; + + /** + * Creates a LogHelper instance. + * @param sink sink + * + */ + public LogHelper(Sink sink) { + this.sink = sink; + } + + /** + * Logs the request header. Binary logging equivalent of logClientHeader. + */ + void logClientHeader( + long seqId, + String serviceName, + String methodName, + String authority, + @Nullable Duration timeout, + Metadata metadata, + int maxHeaderBytes, + GrpcLogRecord.EventLogger eventLogger, + String callId, + // null on client side + @Nullable SocketAddress peerAddress) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); + checkArgument( + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.SERVER, + "peerAddress can only be specified by server"); + PayloadBuilderHelper pair = + createMetadataProto(metadata, maxHeaderBytes); + if (timeout != null) { + pair.payloadBuilder.setTimeout(timeout); + } + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.CLIENT_HEADER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) + .setPayloadTruncated(pair.truncated) + .setCallId(callId); + if (peerAddress != null) { + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); + } + sink.write(logEntryBuilder.build()); + } + + /** + * Logs the response header. Binary logging equivalent of logServerHeader. + */ + void logServerHeader( + long seqId, + String serviceName, + String methodName, + String authority, + Metadata metadata, + int maxHeaderBytes, + GrpcLogRecord.EventLogger eventLogger, + String callId, + @Nullable SocketAddress peerAddress) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); + // Logging peer address only on the first incoming event. On server side, peer address will + // of logging request header + checkArgument( + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.CLIENT, + "peerAddress can only be specified for client"); + + PayloadBuilderHelper pair = + createMetadataProto(metadata, maxHeaderBytes); + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.SERVER_HEADER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) + .setPayloadTruncated(pair.truncated) + .setCallId(callId); + if (peerAddress != null) { + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); + } + sink.write(logEntryBuilder.build()); + } + + /** + * Logs the server trailer. + */ + void logTrailer( + long seqId, + String serviceName, + String methodName, + String authority, + Status status, + Metadata metadata, + int maxHeaderBytes, + GrpcLogRecord.EventLogger eventLogger, + String callId, + @Nullable SocketAddress peerAddress) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(status, "status"); + checkNotNull(callId, "callId"); + checkArgument( + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.CLIENT, + "peerAddress can only be specified for client"); + + PayloadBuilderHelper pair = + createMetadataProto(metadata, maxHeaderBytes); + pair.payloadBuilder.setStatusCode(status.getCode().value()); + String statusDescription = status.getDescription(); + if (statusDescription != null) { + pair.payloadBuilder.setStatusMessage(statusDescription); + } + byte[] statusDetailBytes = metadata.get(STATUS_DETAILS_KEY); + if (statusDetailBytes != null) { + pair.payloadBuilder.setStatusDetails(ByteString.copyFrom(statusDetailBytes)); + } + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.SERVER_TRAILER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) + .setPayloadTruncated(pair.truncated) + .setCallId(callId); + if (peerAddress != null) { + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); + } + sink.write(logEntryBuilder.build()); + } + + /** + * Logs the RPC message. + */ + void logRpcMessage( + long seqId, + String serviceName, + String methodName, + String authority, + EventType eventType, + T message, + int maxMessageBytes, + EventLogger eventLogger, + String callId) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); + checkArgument( + eventType == EventType.CLIENT_MESSAGE + || eventType == EventType.SERVER_MESSAGE, + "event type must correspond to client message or server message"); + checkNotNull(message, "message"); + + // TODO(DNVindhya): Implement conversion of generics to ByteString + // Following is a temporary workaround to log if message is of following types : + // 1. com.google.protobuf.Message + // 2. byte[] + byte[] messageBytesArray = null; + if (message instanceof com.google.protobuf.Message) { + messageBytesArray = ((com.google.protobuf.Message) message).toByteArray(); + } else if (message instanceof byte[]) { + messageBytesArray = (byte[]) message; + } else { + logger.log(Level.WARNING, "message is of UNKNOWN type, message and payload_size fields " + + "of GrpcLogRecord proto will not be logged"); + } + PayloadBuilderHelper pair = null; + if (messageBytesArray != null) { + pair = createMessageProto(messageBytesArray, maxMessageBytes); + } + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(eventType) + .setLogger(eventLogger) + .setCallId(callId); + if (pair != null) { + logEntryBuilder.setPayload(pair.payloadBuilder) + .setPayloadTruncated(pair.truncated); + } + sink.write(logEntryBuilder.build()); + } + + /** + * Logs half close. + */ + void logHalfClose( + long seqId, + String serviceName, + String methodName, + String authority, + GrpcLogRecord.EventLogger eventLogger, + String callId) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); + + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.CLIENT_HALF_CLOSE) + .setLogger(eventLogger) + .setCallId(callId); + sink.write(logEntryBuilder.build()); + } + + /** + * Logs cancellation. + */ + void logCancel( + long seqId, + String serviceName, + String methodName, + String authority, + GrpcLogRecord.EventLogger eventLogger, + String callId) { + checkNotNull(serviceName, "serviceName"); + checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); + + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.CANCEL) + .setLogger(eventLogger) + .setCallId(callId); + sink.write(logEntryBuilder.build()); + } + + // TODO(DNVindhya): Evaluate if we need following clause for metadata logging in GcpObservability + // Leaving the implementation for now as is to have same behavior across Java and Go + private static final Set NEVER_INCLUDED_METADATA = new HashSet<>( + Collections.singletonList( + // grpc-status-details-bin is already logged in `status_details` field of the + // observabilitylog proto + STATUS_DETAILS_KEY.name())); + private static final Set ALWAYS_INCLUDED_METADATA = new HashSet<>( + Collections.singletonList( + "grpc-trace-bin")); + + static final class PayloadBuilderHelper { + T payloadBuilder; + boolean truncated; + + private PayloadBuilderHelper(T payload, boolean truncated) { + this.payloadBuilder = payload; + this.truncated = truncated; + } + } + + static PayloadBuilderHelper createMetadataProto(Metadata metadata, + int maxHeaderBytes) { + checkNotNull(metadata, "metadata"); + checkArgument(maxHeaderBytes >= 0, + "maxHeaderBytes must be non negative"); + Joiner joiner = Joiner.on(",").skipNulls(); + Payload.Builder payloadBuilder = Payload.newBuilder(); + boolean truncated = false; + int totalMetadataBytes = 0; + for (String key : metadata.keys()) { + if (NEVER_INCLUDED_METADATA.contains(key)) { + continue; + } + boolean forceInclude = ALWAYS_INCLUDED_METADATA.contains(key); + String metadataValue; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Iterable metadataValues = + metadata.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + List numList = new ArrayList(); + metadataValues.forEach( + (element) -> { + numList.add(BASE64_ENCODING_OMIT_PADDING.encode(element)); + }); + metadataValue = joiner.join(numList); + } else { + Iterable metadataValues = metadata.getAll( + Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + metadataValue = joiner.join(metadataValues); + } + + int metadataBytesAfterAdd = totalMetadataBytes + key.length() + metadataValue.length(); + if (!forceInclude && metadataBytesAfterAdd > maxHeaderBytes) { + truncated = true; + continue; + } + payloadBuilder.putMetadata(key, metadataValue); + if (!forceInclude) { + // force included keys do not count towards the size limit + totalMetadataBytes = metadataBytesAfterAdd; + } + } + return new PayloadBuilderHelper<>(payloadBuilder, truncated); + } + + static PayloadBuilderHelper createMessageProto( + byte[] message, int maxMessageBytes) { + checkArgument(maxMessageBytes >= 0, + "maxMessageBytes must be non negative"); + Payload.Builder payloadBuilder = Payload.newBuilder(); + int desiredBytes = 0; + int messageLength = message.length; + if (maxMessageBytes > 0) { + desiredBytes = Math.min(maxMessageBytes, messageLength); + } + ByteString messageData = + ByteString.copyFrom(message, 0, desiredBytes); + payloadBuilder.setMessage(messageData); + payloadBuilder.setMessageLength(messageLength); + + return new PayloadBuilderHelper<>(payloadBuilder, + maxMessageBytes < message.length); + } + + static Address socketAddressToProto(SocketAddress address) { + checkNotNull(address, "address"); + Address.Builder builder = Address.newBuilder(); + if (address instanceof InetSocketAddress) { + InetAddress inetAddress = ((InetSocketAddress) address).getAddress(); + if (inetAddress instanceof Inet4Address) { + builder.setType(Address.Type.TYPE_IPV4) + .setAddress(InetAddressUtil.toAddrString(inetAddress)); + } else if (inetAddress instanceof Inet6Address) { + builder.setType(Address.Type.TYPE_IPV6) + .setAddress(InetAddressUtil.toAddrString(inetAddress)); + } else { + logger.log(Level.SEVERE, "unknown type of InetSocketAddress: {}", address); + builder.setAddress(address.toString()); + } + builder.setIpPort(((InetSocketAddress) address).getPort()); + } else if (address.getClass().getName().equals("io.netty.channel.unix.DomainSocketAddress")) { + // To avoid a compiled time dependency on grpc-netty, we check against the + // runtime class name. + builder.setType(Address.Type.TYPE_UNIX) + .setAddress(address.toString()); + } else { + builder.setType(Address.Type.TYPE_UNKNOWN).setAddress(address.toString()); + } + return builder.build(); + } + + /** + * Retrieves socket address. + */ + static SocketAddress getPeerAddress(Attributes streamAttributes) { + return streamAttributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + } + + /** + * Checks deadline for timeout. + */ + static Deadline min(@Nullable Deadline deadline0, @Nullable Deadline deadline1) { + if (deadline0 == null) { + return deadline1; + } + if (deadline1 == null) { + return deadline0; + } + return deadline0.minimum(deadline1); + } +} diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java new file mode 100644 index 00000000000..e91f310e647 --- /dev/null +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java @@ -0,0 +1,191 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.logging; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.MonitoredResource; +import com.google.cloud.logging.LogEntry; +import com.google.cloud.logging.Logging; +import com.google.cloud.logging.LoggingOptions; +import com.google.cloud.logging.Payload.JsonPayload; +import com.google.cloud.logging.Severity; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.util.JsonFormat; +import io.grpc.Internal; +import io.grpc.internal.JsonParser; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import java.io.IOException; +import java.time.Instant; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Sink for Google Cloud Logging. + */ +@Internal +public class GcpLogSink implements Sink { + private final Logger logger = Logger.getLogger(GcpLogSink.class.getName()); + + private static final String DEFAULT_LOG_NAME = + "microservices.googleapis.com%2Fobservability%2Fgrpc"; + private static final Severity DEFAULT_LOG_LEVEL = Severity.DEBUG; + private static final String K8S_MONITORED_RESOURCE_TYPE = "k8s_container"; + private static final Set kubernetesResourceLabelSet + = ImmutableSet.of("project_id", "location", "cluster_name", "namespace_name", + "pod_name", "container_name"); + private final String projectId; + private final Map customTags; + private final MonitoredResource kubernetesResource; + /** Lazily initialize cloud logging client to avoid circular initialization. Because cloud + * logging APIs also uses gRPC. */ + private volatile Logging gcpLoggingClient; + private final Collection servicesToExclude; + + @VisibleForTesting + GcpLogSink(Logging loggingClient, String projectId, Map locationTags, + Map customTags, Collection servicesToExclude) { + this(projectId, locationTags, customTags, servicesToExclude); + this.gcpLoggingClient = loggingClient; + } + + /** + * Retrieves a single instance of GcpLogSink. + * + * @param projectId GCP project id to write logs + * @param servicesToExclude service names for which log entries should not be generated + */ + public GcpLogSink(String projectId, Map locationTags, + Map customTags, Collection servicesToExclude) { + this.projectId = projectId; + this.customTags = getCustomTags(customTags, locationTags, projectId); + this.kubernetesResource = getResource(locationTags); + this.servicesToExclude = checkNotNull(servicesToExclude, "servicesToExclude"); + } + + /** + * Writes logs to GCP Cloud Logging. + * + * @param logProto gRPC logging proto containing the message to be logged + */ + @Override + public void write(GrpcLogRecord logProto) { + if (gcpLoggingClient == null) { + synchronized (this) { + if (gcpLoggingClient == null) { + gcpLoggingClient = createLoggingClient(); + } + } + } + if (servicesToExclude.contains(logProto.getServiceName())) { + return; + } + try { + GrpcLogRecord.EventType eventType = logProto.getType(); + // TODO(DNVindhya): make sure all (int, long) values are not displayed as double + // For now, every value is being converted as string because of JsonFormat.printer().print + Map logProtoMap = protoToMapConverter(logProto); + LogEntry.Builder grpcLogEntryBuilder = + LogEntry.newBuilder(JsonPayload.of(logProtoMap)) + .setSeverity(DEFAULT_LOG_LEVEL) + .setLogName(DEFAULT_LOG_NAME) + .setResource(kubernetesResource) + .setTimestamp(Instant.now()); + + if (!customTags.isEmpty()) { + grpcLogEntryBuilder.setLabels(customTags); + } + LogEntry grpcLogEntry = grpcLogEntryBuilder.build(); + synchronized (this) { + logger.log(Level.FINEST, "Writing gRPC event : {0} to Cloud Logging", eventType); + gcpLoggingClient.write(Collections.singleton(grpcLogEntry)); + } + } catch (Exception e) { + logger.log(Level.SEVERE, "Caught exception while writing to Cloud Logging", e); + } + } + + Logging createLoggingClient() { + LoggingOptions.Builder builder = LoggingOptions.newBuilder(); + if (!Strings.isNullOrEmpty(projectId)) { + builder.setProjectId(projectId); + } + return builder.build().getService(); + } + + @VisibleForTesting + static Map getCustomTags(Map customTags, + Map locationTags, String projectId) { + ImmutableMap.Builder tagsBuilder = ImmutableMap.builder(); + String sourceProjectId = locationTags.get("project_id"); + if (!Strings.isNullOrEmpty(projectId) + && !Strings.isNullOrEmpty(sourceProjectId) + && !Objects.equals(sourceProjectId, projectId)) { + tagsBuilder.put("source_project_id", sourceProjectId); + } + if (customTags != null) { + tagsBuilder.putAll(customTags); + } + return tagsBuilder.buildOrThrow(); + } + + @VisibleForTesting + static MonitoredResource getResource(Map resourceTags) { + MonitoredResource.Builder builder = MonitoredResource.newBuilder(K8S_MONITORED_RESOURCE_TYPE); + if ((resourceTags != null) && !resourceTags.isEmpty()) { + for (Map.Entry entry : resourceTags.entrySet()) { + String resourceKey = entry.getKey(); + if (kubernetesResourceLabelSet.contains(resourceKey)) { + builder.addLabel(resourceKey, entry.getValue()); + } + } + } + return builder.build(); + } + + @SuppressWarnings("unchecked") + private Map protoToMapConverter(GrpcLogRecord logProto) + throws IOException { + JsonFormat.Printer printer = JsonFormat.printer(); + String recordJson = printer.print(logProto); + return (Map) JsonParser.parse(recordJson); + } + + /** + * Closes Cloud Logging Client. + */ + @Override + public synchronized void close() { + if (gcpLoggingClient == null) { + logger.log(Level.WARNING, "Attempt to close after GcpLogSink is closed."); + return; + } + try { + gcpLoggingClient.close(); + } catch (Exception e) { + logger.log(Level.SEVERE, "Caught exception while closing", e); + } + } +} diff --git a/observability/src/main/java/io/grpc/observability/logging/LogRecordExtension.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/Sink.java similarity index 51% rename from observability/src/main/java/io/grpc/observability/logging/LogRecordExtension.java rename to gcp-observability/src/main/java/io/grpc/gcp/observability/logging/Sink.java index a1220659302..c0908cfe3db 100644 --- a/observability/src/main/java/io/grpc/observability/logging/LogRecordExtension.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/Sink.java @@ -14,31 +14,23 @@ * limitations under the License. */ -package io.grpc.observability.logging; +package io.grpc.gcp.observability.logging; import io.grpc.Internal; import io.grpc.observabilitylog.v1.GrpcLogRecord; -import java.util.logging.Level; -import java.util.logging.LogRecord; /** - * An extension of java.util.logging.LogRecord which includes gRPC observability logging specific - * fields. + * Sink for GCP observability. */ @Internal -public final class LogRecordExtension extends LogRecord { +public interface Sink { + /** + * Writes the {@code message} to the destination. + */ + void write(GrpcLogRecord message); - private final GrpcLogRecord grpcLogRecord; - - public LogRecordExtension(Level recordLevel, GrpcLogRecord record) { - super(recordLevel, null); - this.grpcLogRecord = record; - } - - public GrpcLogRecord getGrpcLogRecord() { - return grpcLogRecord; - } - - // Adding a serial version UID since base class i.e LogRecord is Serializable - private static final long serialVersionUID = 1L; + /** + * Closes the sink. + */ + void close(); } diff --git a/observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto b/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto similarity index 52% rename from observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto rename to gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto index d2e72329cde..85ef00ac2dd 100644 --- a/observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto +++ b/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto @@ -28,151 +28,99 @@ option java_outer_classname = "ObservabilityLogProto"; message GrpcLogRecord { // List of event types enum EventType { - GRPC_CALL_UNKNOWN = 0; + EVENT_TYPE_UNKNOWN = 0; // Header sent from client to server - GRPC_CALL_REQUEST_HEADER = 1; + CLIENT_HEADER = 1; // Header sent from server to client - GRPC_CALL_RESPONSE_HEADER = 2; + SERVER_HEADER = 2; // Message sent from client to server - GRPC_CALL_REQUEST_MESSAGE = 3; + CLIENT_MESSAGE = 3; // Message sent from server to client - GRPC_CALL_RESPONSE_MESSAGE = 4; - // Trailer indicates the end of the gRPC call - GRPC_CALL_TRAILER = 5; + SERVER_MESSAGE = 4; // A signal that client is done sending - GRPC_CALL_HALF_CLOSE = 6; + CLIENT_HALF_CLOSE = 5; + // Trailer indicates the end of the gRPC call + SERVER_TRAILER = 6; // A signal that the rpc is canceled - GRPC_CALL_CANCEL = 7; + CANCEL = 7; } + // The entity that generates the log entry enum EventLogger { LOGGER_UNKNOWN = 0; - LOGGER_CLIENT = 1; - LOGGER_SERVER = 2; - } - // The log severity level of the log entry - enum LogLevel { - LOG_LEVEL_UNKNOWN = 0; - LOG_LEVEL_TRACE = 1; - LOG_LEVEL_DEBUG = 2; - LOG_LEVEL_INFO = 3; - LOG_LEVEL_WARN = 4; - LOG_LEVEL_ERROR = 5; - LOG_LEVEL_CRITICAL = 6; + CLIENT = 1; + SERVER = 2; } - // The timestamp of the log event - google.protobuf.Timestamp timestamp = 1; - - // Uniquely identifies a call. The value must not be 0 in order to disambiguate - // from an unset value. - // Each call may have several log entries. They will all have the same rpc_id. + // Uniquely identifies a call. + // Each call may have several log entries. They will all have the same call_id. // Nothing is guaranteed about their value other than they are unique across // different RPCs in the same gRPC process. - uint64 rpc_id = 2; + string call_id = 2; - EventType event_type = 3; // one of the above EventType enum - EventLogger event_logger = 4; // one of the above EventLogger enum + // The entry sequence ID for this call. The first message has a value of 1, + // to disambiguate from an unset value. The purpose of this field is to + // detect missing entries in environments where durability or ordering is + // not guaranteed. + uint64 sequence_id = 3; - // the name of the service - string service_name = 5; - // the name of the RPC method - string method_name = 6; + EventType type = 4; // one of the above EventType enum + EventLogger logger = 5; // one of the above EventLogger enum - LogLevel log_level = 7; // one of the above LogLevel enum + // Payload for log entry. + // It can include a combination of {metadata, message, status based on type of + // the event event being logged and config options. + Payload payload = 6; + // true if message or metadata field is either truncated or omitted due + // to config options + bool payload_truncated = 7; // Peer address information. On client side, peer is logged on server // header event or trailer event (if trailer-only). On server side, peer // is always logged on the client header event. - Address peer_address = 8; - - // the RPC timeout value - google.protobuf.Duration timeout = 11; + Address peer = 8; // A single process may be used to run multiple virtual servers with // different identities. // The authority is the name of such a server identify. It is typically a // portion of the URI in the form of or :. - string authority = 12; - - // Size of the message or metadata, depending on the event type, - // regardless of whether the full message or metadata is being logged - // (i.e. could be truncated or omitted). - uint32 payload_size = 13; - - // true if message or metadata field is either truncated or omitted due - // to config options - bool payload_truncated = 14; - - // Used by header event or trailer event - Metadata metadata = 15; - - // The entry sequence ID for this call. The first message has a value of 1, - // to disambiguate from an unset value. The purpose of this field is to - // detect missing entries in environments where durability or ordering is - // not guaranteed. - uint64 sequence_id = 16; - - // Used by message event - bytes message = 17; + string authority = 10; + // the name of the service + string service_name = 11; + // the name of the RPC method + string method_name = 12; +} +message Payload { + // A list of metadata pairs + map metadata = 1; + // the RPC timeout value + google.protobuf.Duration timeout = 2; // The gRPC status code - uint32 status_code = 18; - + uint32 status_code = 3; // The gRPC status message - string status_message = 19; - + string status_message = 4; // The value of the grpc-status-details-bin metadata key, if any. // This is always an encoded google.rpc.Status message - bytes status_details = 20; - - // Attributes of the environment generating log record. The purpose of this - // field is to identify the source environment. - EnvironmentTags env_tags = 21; - - // A list of non-gRPC custom values specified by the application - repeated CustomTags custom_tags = 22; - - // A list of metadata pairs - message Metadata { - repeated MetadataEntry entry = 1; - } - - // One metadata key value pair - message MetadataEntry { - string key = 1; - bytes value = 2; - } - - // Address information - message Address { - enum Type { - TYPE_UNKNOWN = 0; - TYPE_IPV4 = 1; // in 1.2.3.4 form - TYPE_IPV6 = 2; // IPv6 canonical form (RFC5952 section 4) - TYPE_UNIX = 3; // UDS string - } - Type type = 1; - string address = 2; - // only for TYPE_IPV4 and TYPE_IPV6 - uint32 ip_port = 3; - } - - // Source Environment information - message EnvironmentTags { - string gcp_project_id = 1; - string gcp_numeric_project_id = 2; - string gce_instance_id = 3; - string gce_instance_hostname = 4; - string gce_instance_zone = 5; - string gke_cluster_uid = 6; - string gke_cluster_name = 7; - string gke_cluster_location = 8; - } + bytes status_details = 5; + // Size of the message or metadata, depending on the event type, + // regardless of whether the full message or metadata is being logged + // (i.e. could be truncated or omitted). + uint32 message_length = 6; + // Used by message event + bytes message = 7; +} - // Custom key value pair - message CustomTags { - string key = 1; - string value = 2; +// Address information +message Address { + enum Type { + TYPE_UNKNOWN = 0; + TYPE_IPV4 = 1; // in 1.2.3.4 form + TYPE_IPV6 = 2; // IPv6 canonical form (RFC5952 section 4) + TYPE_UNIX = 3; // UDS string } + Type type = 1; + string address = 2; + // only for TYPE_IPV4 and TYPE_IPV6 + uint32 ip_port = 3; } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java new file mode 100644 index 00000000000..c42d7b65c08 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java @@ -0,0 +1,256 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.InternalGlobalInterceptors; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.StaticTestingClassLoader; +import io.grpc.gcp.observability.interceptors.ConditionalClientInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; +import io.grpc.gcp.observability.logging.Sink; +import io.opencensus.trace.samplers.Samplers; +import java.io.IOException; +import java.util.List; +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GcpObservabilityTest { + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader( + getClass().getClassLoader(), + Pattern.compile( + "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + + "io\\.grpc\\.gcp\\.observability\\.[^.]+|" + + "io\\.grpc\\.gcp\\.observability\\.interceptors\\.[^.]+|" + + "io\\.grpc\\.gcp\\.observability\\.GcpObservabilityTest\\$.*")); + + @Test + public void initFinish() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassInitFinish.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void enableObservability() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassEnableObservability.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void disableObservability() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassDisableObservability.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + @SuppressWarnings("unchecked") + public void conditionalInterceptor() { + ClientInterceptor delegate = mock(ClientInterceptor.class); + Channel channel = mock(Channel.class); + ClientCall returnedCall = mock(ClientCall.class); + + ConditionalClientInterceptor conditionalClientInterceptor + = GcpObservability.getConditionalInterceptor( + delegate); + MethodDescriptor method = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("google.logging.v2.LoggingServiceV2/method") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT); + ClientCall clientCall = conditionalClientInterceptor.interceptCall(method, + CallOptions.DEFAULT, channel); + verifyNoInteractions(delegate); + assertThat(clientCall).isSameInstanceAs(returnedCall); + + method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("google.monitoring.v3.MetricService/method2") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT); + clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel); + verifyNoInteractions(delegate); + assertThat(clientCall).isSameInstanceAs(returnedCall); + + method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("google.devtools.cloudtrace.v2.TraceService/method3") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT); + clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel); + verifyNoInteractions(delegate); + assertThat(clientCall).isSameInstanceAs(returnedCall); + + reset(channel); + ClientCall interceptedCall = mock(ClientCall.class); + method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("some.other.random.service/method4") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + doReturn(interceptedCall).when(delegate).interceptCall(method, CallOptions.DEFAULT, channel); + clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel); + verifyNoInteractions(channel); + assertThat(clientCall).isSameInstanceAs(interceptedCall); + } + + // UsedReflectively + public static final class StaticTestingClassInitFinish implements Runnable { + + @Override + public void run() { + Sink sink = mock(Sink.class); + ObservabilityConfig config = mock(ObservabilityConfig.class); + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + mock(InternalLoggingChannelInterceptor.Factory.class); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + mock(InternalLoggingServerInterceptor.Factory.class); + GcpObservability observability1; + try { + GcpObservability observability = + GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory); + observability1 = + GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory); + assertThat(observability1).isSameInstanceAs(observability); + observability.close(); + verify(sink).close(); + try { + observability1.close(); + fail("should have failed for calling close() second time"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains("GcpObservability already closed!"); + } + } catch (IOException e) { + fail("Encountered exception: " + e); + } + } + } + + public static final class StaticTestingClassEnableObservability implements Runnable { + + @Override + public void run() { + Sink sink = mock(Sink.class); + ObservabilityConfig config = mock(ObservabilityConfig.class); + when(config.isEnableCloudLogging()).thenReturn(true); + when(config.isEnableCloudMonitoring()).thenReturn(true); + when(config.isEnableCloudTracing()).thenReturn(true); + when(config.getSampler()).thenReturn(Samplers.neverSample()); + + ClientInterceptor clientInterceptor = + mock(ClientInterceptor.class, delegatesTo(new NoopClientInterceptor())); + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + mock(InternalLoggingChannelInterceptor.Factory.class); + when(channelInterceptorFactory.create()).thenReturn(clientInterceptor); + + ServerInterceptor serverInterceptor = + mock(ServerInterceptor.class, delegatesTo(new NoopServerInterceptor())); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + mock(InternalLoggingServerInterceptor.Factory.class); + when(serverInterceptorFactory.create()).thenReturn(serverInterceptor); + + try (GcpObservability unused = + GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory)) { + List list = InternalGlobalInterceptors.getClientInterceptors(); + assertThat(list).hasSize(3); + assertThat(list.get(1)).isInstanceOf(ConditionalClientInterceptor.class); + assertThat(list.get(2)).isInstanceOf(ConditionalClientInterceptor.class); + assertThat(InternalGlobalInterceptors.getServerInterceptors()).hasSize(1); + assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).hasSize(2); + } catch (Exception e) { + fail("Encountered exception: " + e); + } + } + } + + public static final class StaticTestingClassDisableObservability implements Runnable { + + @Override + public void run() { + Sink sink = mock(Sink.class); + ObservabilityConfig config = mock(ObservabilityConfig.class); + when(config.isEnableCloudLogging()).thenReturn(false); + when(config.isEnableCloudMonitoring()).thenReturn(false); + when(config.isEnableCloudTracing()).thenReturn(false); + when(config.getSampler()).thenReturn(Samplers.neverSample()); + + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + mock(InternalLoggingChannelInterceptor.Factory.class); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + mock(InternalLoggingServerInterceptor.Factory.class); + + try (GcpObservability unused = + GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory)) { + assertThat(InternalGlobalInterceptors.getClientInterceptors()).isEmpty(); + assertThat(InternalGlobalInterceptors.getServerInterceptors()).isEmpty(); + assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).isEmpty(); + } catch (Exception e) { + fail("Encountered exception: " + e); + } + verify(sink).close(); + } + } + + private static class NoopClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + } + + private static class NoopServerInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/GlobalLocationTagsTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/GlobalLocationTagsTest.java new file mode 100644 index 00000000000..86feb1aaec6 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/GlobalLocationTagsTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Files; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GlobalLocationTagsTest { + private static String FILE_CONTENTS = + "12:perf_event:/kubepods/burstable/podc43b6442-0725-4fb8-bb1c-d17f5122155c/" + + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7\n" + + "11:freezer:/kubepods/burstable/podc43b6442-0725-4fb8-bb1c-d17f5122155c/" + + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7\n" + + "2:rdma:/\n" + + "1:name=systemd:/kubepods/burstable/podc43b6442-0725-4fb8-bb1c-d17f5122155c/" + + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7\n" + + "0::/system.slice/containerd.service\n"; + + private static String FILE_CONTENTS_LAST_LINE = + "0::/system.slice/containerd.service\n" + + "6442-0725-4fb8-bb1c-d17f5122155cslslsl/fe61ca6482b58f4a9831d08d6ea15db25f\n" + + "\n" + + "12:perf_event:/kubepods/burstable/podc43b6442-0725-4fb8-bb1c-d17f5122155c/e19a54df\n"; + + @Rule public TemporaryFolder namespaceFolder = new TemporaryFolder(); + @Rule public TemporaryFolder hostnameFolder = new TemporaryFolder(); + @Rule public TemporaryFolder cgroupFolder = new TemporaryFolder(); + + @Test + public void testContainerIdParsing_lastLine() { + String containerId = GlobalLocationTags.getContainerIdFromFileContents(FILE_CONTENTS_LAST_LINE); + assertThat(containerId).isEqualTo("e19a54df"); + } + + @Test + public void testContainerIdParsing_fewerFields_notFound() { + String containerId = GlobalLocationTags.getContainerIdFromFileContents( + "12:/kubepods/burstable/podc43b6442-0725-4fb8-bb1c-d17f5122155c/" + + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7\n"); + assertThat(containerId).isNull(); + } + + @Test + public void testContainerIdParsing_fewerPaths_notFound() { + String containerId = GlobalLocationTags.getContainerIdFromFileContents( + "12:xdf:/kubepods/podc43b6442-0725-4fb8-bb1c-d17f5122155c/" + + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7\n"); + assertThat(containerId).isNull(); + } + + @Test + public void testPopulateKubernetesValues() throws IOException { + File namespaceFile = namespaceFolder.newFile(); + File hostnameFile = hostnameFolder.newFile(); + File cgroupFile = cgroupFolder.newFile(); + + Files.write("test-namespace1".getBytes(StandardCharsets.UTF_8), namespaceFile); + Files.write("test-hostname2\n".getBytes(StandardCharsets.UTF_8), hostnameFile); + Files.write(FILE_CONTENTS.getBytes(StandardCharsets.UTF_8), cgroupFile); + + ImmutableMap.Builder locationTags = ImmutableMap.builder(); + GlobalLocationTags.populateFromKubernetesValues(locationTags, namespaceFile.getAbsolutePath(), + hostnameFile.getAbsolutePath(), cgroupFile.getAbsolutePath()); + assertThat(locationTags.buildOrThrow()).containsExactly("container_id", + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7", "namespace_name", + "test-namespace1", "pod_name", "test-hostname2"); + } + + @Test + public void testNonKubernetesInstanceValues() throws IOException { + String namespaceFilePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"; + File hostnameFile = hostnameFolder.newFile(); + File cgroupFile = cgroupFolder.newFile(); + + Files.write("test-hostname2\n".getBytes(StandardCharsets.UTF_8), hostnameFile); + Files.write(FILE_CONTENTS.getBytes(StandardCharsets.UTF_8), cgroupFile); + + ImmutableMap.Builder locationTags = ImmutableMap.builder(); + GlobalLocationTags.populateFromKubernetesValues(locationTags, + namespaceFilePath, hostnameFile.getAbsolutePath(), cgroupFile.getAbsolutePath()); + assertThat(locationTags.buildOrThrow()).containsExactly("container_id", + "fe61ca6482b58f4a9831d08d6ea15db25f9fd19b4be19a54df8c6c0eab8742b7", + "pod_name", "test-hostname2"); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java new file mode 100644 index 00000000000..992ccc5dbf5 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java @@ -0,0 +1,250 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.StaticTestingClassLoader; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; +import io.grpc.gcp.observability.interceptors.LogHelper; +import io.grpc.gcp.observability.logging.GcpLogSink; +import io.grpc.gcp.observability.logging.Sink; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.IOException; +import java.util.Collections; +import java.util.regex.Pattern; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class LoggingTest { + + @ClassRule + public static final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + private static final String PROJECT_ID = "PROJECT"; + private static final ImmutableMap LOCATION_TAGS = ImmutableMap.of( + "project_id", "PROJECT", + "location", "us-central1-c", + "cluster_name", "grpc-observability-cluster", + "namespace_name", "default" , + "pod_name", "app1-6c7c58f897-n92c5"); + private static final ImmutableMap CUSTOM_TAGS = ImmutableMap.of( + "KEY1", "Value1", + "KEY2", "VALUE2"); + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader(getClass().getClassLoader(), Pattern.compile("io\\.grpc\\..*")); + + /** + * Cloud logging test using GlobalInterceptors. + * + *

Ignoring test, because it calls external Cloud Logging APIs. + * To test cloud logging setup locally, + * 1. Set up Cloud auth credentials + * 2. Assign permissions to service account to write logs to project specified by + * variable PROJECT_ID + * 3. Comment @Ignore annotation + * 4. This test is expected to pass when ran with above setup. This has been verified manually. + *

+ */ + @Ignore + @Test + public void clientServer_interceptorCalled_logAlways() throws Exception { + Class runnable = + classLoader.loadClass(LoggingTest.StaticTestingClassEndtoEndLogging.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void clientServer_interceptorCalled_logNever() throws Exception { + Class runnable = + classLoader.loadClass(LoggingTest.StaticTestingClassLogNever.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void clientServer_interceptorCalled_logEvents_usingMockSink() throws Exception { + Class runnable = + classLoader.loadClass(StaticTestingClassLogEventsUsingMockSink.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassEndtoEndLogging implements Runnable { + + @Override + public void run() { + Sink sink = + new GcpLogSink( + PROJECT_ID, LOCATION_TAGS, CUSTOM_TAGS, Collections.emptySet()); + ObservabilityConfig config = mock(ObservabilityConfig.class); + LogHelper spyLogHelper = spy(new LogHelper(sink)); + ConfigFilterHelper mockFilterHelper = mock(ConfigFilterHelper.class); + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + new InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + new InternalLoggingServerInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); + + when(config.isEnableCloudLogging()).thenReturn(true); + FilterParams logAlwaysFilterParams = FilterParams.create(true, 1024, 10); + when(mockFilterHelper.logRpcMethod(anyString(), eq(true))) + .thenReturn(logAlwaysFilterParams); + when(mockFilterHelper.logRpcMethod(anyString(), eq(false))) + .thenReturn(logAlwaysFilterParams); + + try (GcpObservability unused = + GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory)) { + Server server = + ServerBuilder.forPort(0) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) + .build() + .start(); + int port = cleanupRule.register(server).getPort(); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub( + cleanupRule.register( + ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + .isEqualTo("Hello buddy"); + assertThat(Mockito.mockingDetails(spyLogHelper).getInvocations().size()).isGreaterThan(11); + } catch (IOException e) { + throw new AssertionError("Exception while testing logging", e); + } + } + } + + public static final class StaticTestingClassLogNever implements Runnable { + + @Override + public void run() { + Sink mockSink = mock(GcpLogSink.class); + ObservabilityConfig config = mock(ObservabilityConfig.class); + LogHelper spyLogHelper = spy(new LogHelper(mockSink)); + ConfigFilterHelper mockFilterHelper = mock(ConfigFilterHelper.class); + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + new InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + new InternalLoggingServerInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); + + when(config.isEnableCloudLogging()).thenReturn(true); + FilterParams logNeverFilterParams = FilterParams.create(false, 0, 0); + when(mockFilterHelper.logRpcMethod(anyString(), eq(true))) + .thenReturn(logNeverFilterParams); + when(mockFilterHelper.logRpcMethod(anyString(), eq(false))) + .thenReturn(logNeverFilterParams); + + try (GcpObservability unused = + GcpObservability.grpcInit( + mockSink, config, channelInterceptorFactory, serverInterceptorFactory)) { + Server server = + ServerBuilder.forPort(0) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) + .build() + .start(); + int port = cleanupRule.register(server).getPort(); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub( + cleanupRule.register( + ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + .isEqualTo("Hello buddy"); + verifyNoInteractions(spyLogHelper); + verifyNoInteractions(mockSink); + } catch (IOException e) { + throw new AssertionError("Exception while testing logging event filter", e); + } + } + } + + public static final class StaticTestingClassLogEventsUsingMockSink implements Runnable { + + @Override + public void run() { + Sink mockSink = mock(GcpLogSink.class); + ObservabilityConfig config = mock(ObservabilityConfig.class); + LogHelper spyLogHelper = spy(new LogHelper(mockSink)); + ConfigFilterHelper mockFilterHelper2 = mock(ConfigFilterHelper.class); + InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = + new InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper2); + InternalLoggingServerInterceptor.Factory serverInterceptorFactory = + new InternalLoggingServerInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper2); + + when(config.isEnableCloudLogging()).thenReturn(true); + FilterParams logAlwaysFilterParams = FilterParams.create(true, 0, 0); + when(mockFilterHelper2.logRpcMethod(anyString(), eq(true))) + .thenReturn(logAlwaysFilterParams); + when(mockFilterHelper2.logRpcMethod(anyString(), eq(false))) + .thenReturn(logAlwaysFilterParams); + + try (GcpObservability observability = + GcpObservability.grpcInit( + mockSink, config, channelInterceptorFactory, serverInterceptorFactory)) { + Server server = + ServerBuilder.forPort(0) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) + .build() + .start(); + int port = cleanupRule.register(server).getPort(); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub( + cleanupRule.register( + ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + .isEqualTo("Hello buddy"); + // Total number of calls should have been 14 (6 from client and 6 from server) + // Since cancel is not invoked, it will be 12. + // Request message(Total count:2 (1 from client and 1 from server) and Response + // message(count:2) + // events are not in the event_types list, i.e 14 - 2(cancel) - 2(req_msg) - 2(resp_msg) + // = 8 + assertThat(Mockito.mockingDetails(mockSink).getInvocations().size()).isEqualTo(12); + ArgumentCaptor captor = ArgumentCaptor.forClass(GrpcLogRecord.class); + verify(mockSink, times(12)).write(captor.capture()); + for (GrpcLogRecord record : captor.getAllValues()) { + assertThat(record.getType()).isInstanceOf(GrpcLogRecord.EventType.class); + assertThat(record.getLogger()).isInstanceOf(GrpcLogRecord.EventLogger.class); + } + } catch (IOException e) { + throw new AssertionError("Exception while testing logging using mock sink", e); + } + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/MetadataConfigTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetadataConfigTest.java new file mode 100644 index 00000000000..9a4bf2057c9 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetadataConfigTest.java @@ -0,0 +1,56 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.when; + +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import com.google.auth.http.HttpTransportFactory; +import java.io.IOException; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +@RunWith(JUnit4.class) +public class MetadataConfigTest { + + @Mock HttpTransportFactory httpTransportFactory; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testGetAttribute() throws IOException { + MockHttpTransport.Builder builder = new MockHttpTransport.Builder(); + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + response.setContent("foo"); + builder.setLowLevelHttpResponse(response); + MockHttpTransport httpTransport = builder.build(); + when(httpTransportFactory.create()).thenReturn(httpTransport); + MetadataConfig metadataConfig = new MetadataConfig(httpTransportFactory); + metadataConfig.init(); + String val = metadataConfig.getAttribute("instance/attributes/cluster-name"); + assertThat(val).isEqualTo("foo"); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java new file mode 100644 index 00000000000..046799cc9d2 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java @@ -0,0 +1,159 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.monitoring.v3.MetricServiceClient; +import com.google.cloud.monitoring.v3.MetricServiceClient.ListTimeSeriesPagedResponse; +import com.google.monitoring.v3.ListTimeSeriesRequest; +import com.google.monitoring.v3.ProjectName; +import com.google.monitoring.v3.TimeInterval; +import com.google.monitoring.v3.TimeSeries; +import com.google.protobuf.util.Timestamps; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.StaticTestingClassLoader; +import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; +import io.grpc.gcp.observability.logging.GcpLogSink; +import io.grpc.gcp.observability.logging.Sink; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetricsTest { + + @ClassRule + public static final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + private static final String PROJECT_ID = "PROJECT"; + private static final String TEST_CLIENT_METHOD = "grpc.testing.SimpleService/UnaryRpc"; + private static final String CUSTOM_TAG_KEY = "Version"; + private static final String CUSTOM_TAG_VALUE = + String.format("C67J9A-%s", String.valueOf(System.currentTimeMillis())); + private static final Map CUSTOM_TAGS = Collections.singletonMap(CUSTOM_TAG_KEY, + CUSTOM_TAG_VALUE); + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader(getClass().getClassLoader(), + Pattern.compile("io\\.grpc\\..*|io\\.opencensus\\..*")); + + /** + * End to end cloud monitoring test. + * + *

Ignoring test, because it calls external Cloud Monitoring APIs. To test cloud monitoring + * setup locally, + * 1. Set up Cloud auth credentials + * 2. Assign permissions to service account to write metrics to project specified by variable + * PROJECT_ID + * 3. Comment @Ignore annotation + * 4. This test is expected to pass when ran with above setup. This has been verified manually. + */ + @Ignore + @Test + public void testMetricsExporter() throws Exception { + Class runnable = + classLoader.loadClass(MetricsTest.StaticTestingClassTestMetricsExporter.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + public static final class StaticTestingClassTestMetricsExporter implements Runnable { + + @Override + public void run() { + Sink mockSink = mock(GcpLogSink.class); + ObservabilityConfig mockConfig = mock(ObservabilityConfig.class); + InternalLoggingChannelInterceptor.Factory mockChannelInterceptorFactory = + mock(InternalLoggingChannelInterceptor.Factory.class); + InternalLoggingServerInterceptor.Factory mockServerInterceptorFactory = + mock(InternalLoggingServerInterceptor.Factory.class); + + when(mockConfig.isEnableCloudMonitoring()).thenReturn(true); + when(mockConfig.getProjectId()).thenReturn(PROJECT_ID); + + try { + GcpObservability observability = + GcpObservability.grpcInit( + mockSink, mockConfig, mockChannelInterceptorFactory, mockServerInterceptorFactory); + observability.registerStackDriverExporter(PROJECT_ID, CUSTOM_TAGS); + + Server server = + ServerBuilder.forPort(0) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) + .build() + .start(); + int port = cleanupRule.register(server).getPort(); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub( + cleanupRule.register( + ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + .isEqualTo("Hello buddy"); + // Adding sleep to ensure metrics are exported before querying cloud monitoring backend + TimeUnit.SECONDS.sleep(40); + + // This checks Cloud monitoring for the new metrics that was just exported. + MetricServiceClient metricServiceClient = MetricServiceClient.create(); + // Restrict time to last 1 minute + long startMillis = System.currentTimeMillis() - ((60 * 1) * 1000); + TimeInterval interval = + TimeInterval.newBuilder() + .setStartTime(Timestamps.fromMillis(startMillis)) + .setEndTime(Timestamps.fromMillis(System.currentTimeMillis())) + .build(); + // Timeseries data + String metricsFilter = + String.format( + "metric.type=\"custom.googleapis.com/opencensus/grpc.io/client/completed_rpcs\"" + + " AND metric.labels.grpc_client_method=\"%s\"" + + " AND metric.labels.%s=%s", + TEST_CLIENT_METHOD, CUSTOM_TAG_KEY, CUSTOM_TAG_VALUE); + ListTimeSeriesRequest metricsRequest = + ListTimeSeriesRequest.newBuilder() + .setName(ProjectName.of(PROJECT_ID).toString()) + .setFilter(metricsFilter) + .setInterval(interval) + .build(); + ListTimeSeriesPagedResponse response = metricServiceClient.listTimeSeries(metricsRequest); + assertThat(response.iterateAll()).isNotEmpty(); + for (TimeSeries ts : response.iterateAll()) { + assertThat(ts.getMetric().getLabelsMap().get("grpc_client_method")) + .isEqualTo(TEST_CLIENT_METHOD); + assertThat(ts.getMetric().getLabelsMap().get("grpc_client_status")).isEqualTo("OK"); + assertThat(ts.getPoints(0).getValue().getInt64Value()).isEqualTo(1); + } + observability.close(); + } catch (IOException | InterruptedException e) { + throw new AssertionError("Exception while testing metrics", e); + } + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java new file mode 100644 index 00000000000..d6f23fbcc9a --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java @@ -0,0 +1,491 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.base.Charsets; +import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; +import io.opencensus.trace.Sampler; +import io.opencensus.trace.samplers.Samplers; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ObservabilityConfigImplTest { + private static final String LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 4096\n" + + " }" + + " ],\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 32,\n" + + " \"max_message_bytes\": 64\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + private static final String CLIENT_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 4096,\n" + + " \"max_message_bytes\": 2048\n" + + " }," + + " {\n" + + " \"methods\": [\"service1/Method2\", \"Service2/*\"],\n" + + " \"exclude\": true\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + private static final String SERVER_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"service1/method4\", \"service2/method234\"],\n" + + " \"max_metadata_bytes\": 32,\n" + + " \"max_message_bytes\": 64\n" + + " }," + + " {\n" + + " \"methods\": [\"service4/*\", \"Service2/*\"],\n" + + " \"exclude\": true\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + private static final String VALID_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"service.Service1/*\", \"service2.Service4/method4\"],\n" + + " \"max_metadata_bytes\": 16,\n" + + " \"max_message_bytes\": 64\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + + private static final String PROJECT_ID = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {},\n" + + " \"project_id\": \"grpc-testing\"\n" + + "}"; + + private static final String EMPTY_CONFIG = "{}"; + + private static final String ENABLE_CLOUD_MONITORING_AND_TRACING = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {},\n" + + " \"cloud_trace\": {}\n" + + "}"; + + private static final String ENABLE_CLOUD_MONITORING = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {}\n" + + "}"; + + private static final String ENABLE_CLOUD_TRACE = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {}\n" + + "}"; + + private static final String TRACING_ALWAYS_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 1.00\n" + + " }\n" + + "}"; + + private static final String TRACING_NEVER_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 0.00\n" + + " }\n" + + "}"; + + private static final String TRACING_PROBABILISTIC_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 0.75\n" + + " }\n" + + "}"; + + private static final String TRACING_DEFAULT_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {}\n" + + "}"; + + private static final String GLOBAL_TRACING_BAD_PROBABILISTIC_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": -0.75\n" + + " }\n" + + "}"; + + private static final String CUSTOM_TAGS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {},\n" + + " \"labels\": {\n" + + " \"SOURCE_VERSION\" : \"J2e1Cf\",\n" + + " \"SERVICE_NAME\" : \"payment-service\",\n" + + " \"ENTRYPOINT_SCRIPT\" : \"entrypoint.sh\"\n" + + " }\n" + + "}"; + + private static final String BAD_CUSTOM_TAGS = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {},\n" + + " \"labels\": {\n" + + " \"SOURCE_VERSION\" : \"J2e1Cf\",\n" + + " \"SERVICE_NAME\" : { \"SUB_SERVICE_NAME\" : \"payment-service\"},\n" + + " \"ENTRYPOINT_SCRIPT\" : \"entrypoint.sh\"\n" + + " }\n" + + "}"; + + private static final String LOG_FILTER_GLOBAL_EXCLUDE = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"service1/Method2\", \"*\"],\n" + + " \"max_metadata_bytes\": 20,\n" + + " \"max_message_bytes\": 15,\n" + + " \"exclude\": true\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + private static final String LOG_FILTER_INVALID_METHOD = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"s*&%ervice1/Method2\", \"*\"],\n" + + " \"max_metadata_bytes\": 20\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + ObservabilityConfigImpl observabilityConfig = new ObservabilityConfigImpl(); + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void nullConfig() throws IOException { + try { + observabilityConfig.parse(null); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo("GRPC_GCP_OBSERVABILITY_CONFIG value is null!"); + } + } + + @Test + public void emptyConfig() throws IOException { + observabilityConfig.parse(EMPTY_CONFIG); + assertFalse(observabilityConfig.isEnableCloudLogging()); + assertFalse(observabilityConfig.isEnableCloudMonitoring()); + assertFalse(observabilityConfig.isEnableCloudTracing()); + assertThat(observabilityConfig.getClientLogFilters()).isEmpty(); + assertThat(observabilityConfig.getServerLogFilters()).isEmpty(); + assertThat(observabilityConfig.getSampler()).isNull(); + assertThat(observabilityConfig.getProjectId()).isNull(); + assertThat(observabilityConfig.getCustomTags()).isEmpty(); + } + + @Test + public void emptyConfigFile() throws IOException { + File configFile = tempFolder.newFile(); + try { + observabilityConfig.parseFile(configFile.getAbsolutePath()); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "GRPC_GCP_OBSERVABILITY_CONFIG_FILE is empty!"); + } + } + + @Test + public void setProjectId() throws IOException { + observabilityConfig.parse(PROJECT_ID); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + } + + @Test + public void logFilters() throws IOException { + observabilityConfig.parse(LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + + List clientLogFilters = observabilityConfig.getClientLogFilters(); + assertThat(clientLogFilters).hasSize(1); + assertThat(clientLogFilters.get(0).headerBytes).isEqualTo(4096); + assertThat(clientLogFilters.get(0).messageBytes).isEqualTo(0); + assertThat(clientLogFilters.get(0).excludePattern).isFalse(); + assertThat(clientLogFilters.get(0).matchAll).isTrue(); + assertThat(clientLogFilters.get(0).services).isEmpty(); + assertThat(clientLogFilters.get(0).methods).isEmpty(); + + List serverLogFilters = observabilityConfig.getServerLogFilters(); + assertThat(serverLogFilters).hasSize(1); + assertThat(serverLogFilters.get(0).headerBytes).isEqualTo(32); + assertThat(serverLogFilters.get(0).messageBytes).isEqualTo(64); + assertThat(serverLogFilters.get(0).excludePattern).isFalse(); + assertThat(serverLogFilters.get(0).matchAll).isTrue(); + assertThat(serverLogFilters.get(0).services).isEmpty(); + assertThat(serverLogFilters.get(0).methods).isEmpty(); + } + + @Test + public void setClientLogFilters() throws IOException { + observabilityConfig.parse(CLIENT_LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilterList = observabilityConfig.getClientLogFilters(); + assertThat(logFilterList).hasSize(2); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(4096); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(2048); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isTrue(); + assertThat(logFilterList.get(0).services).isEmpty(); + assertThat(logFilterList.get(0).methods).isEmpty(); + + assertThat(logFilterList.get(1).headerBytes).isEqualTo(0); + assertThat(logFilterList.get(1).messageBytes).isEqualTo(0); + assertThat(logFilterList.get(1).excludePattern).isTrue(); + assertThat(logFilterList.get(1).matchAll).isFalse(); + assertThat(logFilterList.get(1).services).isEqualTo(Collections.singleton("Service2")); + assertThat(logFilterList.get(1).methods) + .isEqualTo(Collections.singleton("service1/Method2")); + } + + @Test + public void setServerLogFilters() throws IOException { + Set expectedMethods = Stream.of("service1/method4", "service2/method234") + .collect(Collectors.toCollection(HashSet::new)); + observabilityConfig.parse(SERVER_LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + List logFilterList = observabilityConfig.getServerLogFilters(); + assertThat(logFilterList).hasSize(2); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(32); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(64); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isFalse(); + assertThat(logFilterList.get(0).services).isEmpty(); + assertThat(logFilterList.get(0).methods) + .isEqualTo(expectedMethods); + + Set expectedServices = Stream.of("service4", "Service2") + .collect(Collectors.toCollection(HashSet::new)); + assertThat(logFilterList.get(1).headerBytes).isEqualTo(0); + assertThat(logFilterList.get(1).messageBytes).isEqualTo(0); + assertThat(logFilterList.get(1).excludePattern).isTrue(); + assertThat(logFilterList.get(1).matchAll).isFalse(); + assertThat(logFilterList.get(1).services).isEqualTo(expectedServices); + assertThat(logFilterList.get(1).methods).isEmpty(); + } + + @Test + public void enableCloudMonitoring() throws IOException { + observabilityConfig.parse(ENABLE_CLOUD_MONITORING); + assertTrue(observabilityConfig.isEnableCloudMonitoring()); + } + + @Test + public void enableCloudTracing() throws IOException { + observabilityConfig.parse(ENABLE_CLOUD_TRACE); + assertTrue(observabilityConfig.isEnableCloudTracing()); + } + + @Test + public void enableCloudMonitoringAndTracing() throws IOException { + observabilityConfig.parse(ENABLE_CLOUD_MONITORING_AND_TRACING); + assertFalse(observabilityConfig.isEnableCloudLogging()); + assertTrue(observabilityConfig.isEnableCloudMonitoring()); + assertTrue(observabilityConfig.isEnableCloudTracing()); + } + + @Test + public void alwaysSampler() throws IOException { + observabilityConfig.parse(TRACING_ALWAYS_SAMPLER); + assertTrue(observabilityConfig.isEnableCloudTracing()); + Sampler sampler = observabilityConfig.getSampler(); + assertThat(sampler).isNotNull(); + assertThat(sampler).isEqualTo(Samplers.alwaysSample()); + } + + @Test + public void neverSampler() throws IOException { + observabilityConfig.parse(TRACING_NEVER_SAMPLER); + assertTrue(observabilityConfig.isEnableCloudTracing()); + Sampler sampler = observabilityConfig.getSampler(); + assertThat(sampler).isNotNull(); + assertThat(sampler).isEqualTo(Samplers.probabilitySampler(0.0)); + } + + @Test + public void probabilisticSampler() throws IOException { + observabilityConfig.parse(TRACING_PROBABILISTIC_SAMPLER); + assertTrue(observabilityConfig.isEnableCloudTracing()); + Sampler sampler = observabilityConfig.getSampler(); + assertThat(sampler).isNotNull(); + assertThat(sampler).isEqualTo(Samplers.probabilitySampler(0.75)); + } + + @Test + public void defaultSampler() throws IOException { + observabilityConfig.parse(TRACING_DEFAULT_SAMPLER); + assertTrue(observabilityConfig.isEnableCloudTracing()); + Sampler sampler = observabilityConfig.getSampler(); + assertThat(sampler).isNotNull(); + assertThat(sampler).isEqualTo(Samplers.probabilitySampler(0.00)); + } + + @Test + public void badProbabilisticSampler_error() throws IOException { + try { + observabilityConfig.parse(GLOBAL_TRACING_BAD_PROBABILISTIC_SAMPLER); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "'sampling_rate' needs to be between [0.0, 1.0]"); + } + } + + @Test + public void configFileLogFilters() throws Exception { + File configFile = tempFolder.newFile(); + Files.write( + Paths.get(configFile.getAbsolutePath()), CLIENT_LOG_FILTERS.getBytes(Charsets.US_ASCII)); + observabilityConfig.parseFile(configFile.getAbsolutePath()); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilters = observabilityConfig.getClientLogFilters(); + assertThat(logFilters).hasSize(2); + assertThat(logFilters.get(0).headerBytes).isEqualTo(4096); + assertThat(logFilters.get(0).messageBytes).isEqualTo(2048); + assertThat(logFilters.get(1).headerBytes).isEqualTo(0); + assertThat(logFilters.get(1).messageBytes).isEqualTo(0); + + assertThat(logFilters).hasSize(2); + assertThat(logFilters.get(0).headerBytes).isEqualTo(4096); + assertThat(logFilters.get(0).messageBytes).isEqualTo(2048); + assertThat(logFilters.get(0).excludePattern).isFalse(); + assertThat(logFilters.get(0).matchAll).isTrue(); + assertThat(logFilters.get(0).services).isEmpty(); + assertThat(logFilters.get(0).methods).isEmpty(); + + assertThat(logFilters.get(1).headerBytes).isEqualTo(0); + assertThat(logFilters.get(1).messageBytes).isEqualTo(0); + assertThat(logFilters.get(1).excludePattern).isTrue(); + assertThat(logFilters.get(1).matchAll).isFalse(); + assertThat(logFilters.get(1).services).isEqualTo(Collections.singleton("Service2")); + assertThat(logFilters.get(1).methods) + .isEqualTo(Collections.singleton("service1/Method2")); + } + + @Test + public void customTags() throws IOException { + observabilityConfig.parse(CUSTOM_TAGS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + Map customTags = observabilityConfig.getCustomTags(); + assertThat(customTags).hasSize(3); + assertThat(customTags).containsEntry("SOURCE_VERSION", "J2e1Cf"); + assertThat(customTags).containsEntry("SERVICE_NAME", "payment-service"); + assertThat(customTags).containsEntry("ENTRYPOINT_SCRIPT", "entrypoint.sh"); + } + + @Test + public void badCustomTags() throws IOException { + try { + observabilityConfig.parse(BAD_CUSTOM_TAGS); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "'labels' needs to be a map of "); + } + } + + @Test + public void globalLogFilterExclude() throws IOException { + try { + observabilityConfig.parse(LOG_FILTER_GLOBAL_EXCLUDE); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "cannot have 'exclude' and '*' wildcard in the same filter"); + } + } + + @Test + public void logFilterInvalidMethod() throws IOException { + try { + observabilityConfig.parse(LOG_FILTER_INVALID_METHOD); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).contains( + "invalid service or method filter"); + } + } + + @Test + public void validLogFilter() throws Exception { + observabilityConfig.parse(VALID_LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilterList = observabilityConfig.getServerLogFilters(); + assertThat(logFilterList).hasSize(1); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(16); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(64); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isFalse(); + assertThat(logFilterList.get(0).services).isEqualTo(Collections.singleton("service.Service1")); + assertThat(logFilterList.get(0).methods) + .isEqualTo(Collections.singleton("service2.Service4/method4")); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java new file mode 100644 index 00000000000..ebb73ec76a1 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; + +public class ObservabilityTestHelper { + + static String makeUnaryRpcViaClientStub( + String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { + SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); + SimpleResponse response = blockingStub.unaryRpc(request); + return response.getResponseMessage(); + } + + static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + + @Override + public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder() + .setResponseMessage("Hello " + req.getRequestMessage()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java new file mode 100644 index 00000000000..ae7aa63befc --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java @@ -0,0 +1,165 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.trace.v1.TraceServiceClient; +import com.google.cloud.trace.v1.TraceServiceClient.ListTracesPagedResponse; +import com.google.devtools.cloudtrace.v1.GetTraceRequest; +import com.google.devtools.cloudtrace.v1.ListTracesRequest; +import com.google.devtools.cloudtrace.v1.Trace; +import com.google.devtools.cloudtrace.v1.TraceSpan; +import com.google.protobuf.util.Timestamps; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.StaticTestingClassLoader; +import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; +import io.grpc.gcp.observability.logging.GcpLogSink; +import io.grpc.gcp.observability.logging.Sink; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.opencensus.trace.samplers.Samplers; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TracesTest { + + @ClassRule + public static final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + private static final String PROJECT_ID = "PROJECT"; + private static final String CUSTOM_TAG_KEY = "service"; + private static final String CUSTOM_TAG_VALUE = + String.format("payment-%s", String.valueOf(System.currentTimeMillis())); + private static final Map CUSTOM_TAGS = + Collections.singletonMap(CUSTOM_TAG_KEY, CUSTOM_TAG_VALUE); + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader(getClass().getClassLoader(), + Pattern.compile("io\\.grpc\\..*|io\\.opencensus\\..*")); + + /** + * End to end cloud trace test. + * + *

Ignoring test, because it calls external Cloud Tracing APIs. To test cloud trace setup + * locally, + * 1. Set up Cloud auth credentials + * 2. Assign permissions to service account to write traces to project specified by variable + * PROJECT_ID + * 3. Comment @Ignore annotation + * 4. This test is expected to pass when ran with above setup. This has been verified manually. + */ + @Ignore + @Test + public void testTracesExporter() throws Exception { + Class runnable = + classLoader.loadClass(TracesTest.StaticTestingClassTestTracesExporter.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + public static final class StaticTestingClassTestTracesExporter implements Runnable { + + @Override + public void run() { + Sink mockSink = mock(GcpLogSink.class); + ObservabilityConfig mockConfig = mock(ObservabilityConfig.class); + InternalLoggingChannelInterceptor.Factory mockChannelInterceptorFactory = + mock(InternalLoggingChannelInterceptor.Factory.class); + InternalLoggingServerInterceptor.Factory mockServerInterceptorFactory = + mock(InternalLoggingServerInterceptor.Factory.class); + + when(mockConfig.isEnableCloudTracing()).thenReturn(true); + when(mockConfig.getSampler()).thenReturn(Samplers.alwaysSample()); + when(mockConfig.getProjectId()).thenReturn(PROJECT_ID); + + try { + GcpObservability observability = + GcpObservability.grpcInit( + mockSink, mockConfig, mockChannelInterceptorFactory, mockServerInterceptorFactory); + observability.registerStackDriverExporter(PROJECT_ID, CUSTOM_TAGS); + + Server server = + ServerBuilder.forPort(0) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) + .build() + .start(); + int port = cleanupRule.register(server).getPort(); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub( + cleanupRule.register( + ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + .isEqualTo("Hello buddy"); + // Adding sleep to ensure traces are exported before querying cloud tracing backend + TimeUnit.SECONDS.sleep(10); + + TraceServiceClient traceServiceClient = TraceServiceClient.create(); + String traceFilter = + String.format( + "span:Sent.grpc.testing.SimpleService +%s:%s", CUSTOM_TAG_KEY, CUSTOM_TAG_VALUE); + String traceOrder = "start"; + // Restrict time to last 1 minute + long startMillis = System.currentTimeMillis() - ((60 * 1) * 1000); + ListTracesRequest traceRequest = + ListTracesRequest.newBuilder() + .setProjectId(PROJECT_ID) + .setStartTime(Timestamps.fromMillis(startMillis)) + .setEndTime(Timestamps.fromMillis(System.currentTimeMillis())) + .setFilter(traceFilter) + .setOrderBy(traceOrder) + .build(); + ListTracesPagedResponse traceResponse = traceServiceClient.listTraces(traceRequest); + assertThat(traceResponse.iterateAll()).isNotEmpty(); + List traceIdList = new ArrayList<>(); + for (Trace t : traceResponse.iterateAll()) { + traceIdList.add(t.getTraceId()); + } + + for (String traceId : traceIdList) { + // This checks Cloud trace for the new trace that was just created. + GetTraceRequest getTraceRequest = + GetTraceRequest.newBuilder().setProjectId(PROJECT_ID).setTraceId(traceId).build(); + Trace trace = traceServiceClient.getTrace(getTraceRequest); + assertThat(trace.getSpansList()).hasSize(3); + for (TraceSpan span : trace.getSpansList()) { + assertThat(span.getName()).contains("grpc.testing.SimpleService.UnaryRpc"); + assertThat(span.getLabelsMap().get(CUSTOM_TAG_KEY)).isEqualTo(CUSTOM_TAG_VALUE); + } + } + observability.close(); + } catch (IOException | InterruptedException e) { + throw new AssertionError("Exception while testing traces", e); + } + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptorTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptorTest.java new file mode 100644 index 00000000000..5fdd2e185bd --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConditionalClientInterceptorTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import java.util.function.BiPredicate; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Tests for {@link ConditionalClientInterceptor}. + */ +@RunWith(JUnit4.class) +public class ConditionalClientInterceptorTest { + + private ConditionalClientInterceptor conditionalClientInterceptor; + @Mock private ClientInterceptor delegate; + @Mock private BiPredicate, CallOptions> predicate; + @Mock private Channel channel; + @Mock private ClientCall returnedCall; + private MethodDescriptor method; + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + conditionalClientInterceptor = new ConditionalClientInterceptor( + delegate, predicate); + method = MethodDescriptor.newBuilder().setType(MethodType.UNARY) + .setFullMethodName("service/method") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + } + + @Test + @SuppressWarnings("unchecked") + public void predicateFalse() { + when(predicate.test(any(MethodDescriptor.class), any(CallOptions.class))).thenReturn(false); + doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT); + ClientCall clientCall = conditionalClientInterceptor.interceptCall(method, + CallOptions.DEFAULT, channel); + assertThat(clientCall).isSameInstanceAs(returnedCall); + verify(delegate, never()).interceptCall(any(MethodDescriptor.class), any(CallOptions.class), + any(Channel.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void predicateTrue() { + when(predicate.test(any(MethodDescriptor.class), any(CallOptions.class))).thenReturn(true); + doReturn(returnedCall).when(delegate).interceptCall(method, CallOptions.DEFAULT, channel); + ClientCall clientCall = conditionalClientInterceptor.interceptCall(method, + CallOptions.DEFAULT, channel); + assertThat(clientCall).isSameInstanceAs(returnedCall); + verify(channel, never()).newCall(any(MethodDescriptor.class), any(CallOptions.class)); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java new file mode 100644 index 00000000000..971e6070777 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.grpc.gcp.observability.ObservabilityConfig; +import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import java.util.Collections; +import java.util.List; +import org.junit.Before; +import org.junit.Test; + +public class ConfigFilterHelperTest { + private static final ImmutableList configLogFilters = + ImmutableList.of( + new LogFilter(Collections.emptySet(), Collections.singleton("service1/Method2"), false, + 1024, 1024, false), + new LogFilter( + Collections.singleton("service2"), Collections.singleton("service4/method2"), false, + 2048, 1024, false), + new LogFilter( + Collections.singleton("service2"), Collections.singleton("service4/method3"), false, + 2048, 1024, true), + new LogFilter( + Collections.emptySet(), Collections.emptySet(), true, + 128, 128, false)); + + private ObservabilityConfig mockConfig; + private ConfigFilterHelper configFilterHelper; + + @Before + public void setup() { + mockConfig = mock(ObservabilityConfig.class); + configFilterHelper = ConfigFilterHelper.getInstance(mockConfig); + } + + @Test + public void enableCloudLogging_withoutLogFilters() { + when(mockConfig.isEnableCloudLogging()).thenReturn(true); + assertThat(mockConfig.getClientLogFilters()).isEmpty(); + assertThat(mockConfig.getServerLogFilters()).isEmpty(); + } + + @Test + public void checkMethodLog_withoutLogFilters() { + when(mockConfig.isEnableCloudLogging()).thenReturn(true); + assertThat(mockConfig.getClientLogFilters()).isEmpty(); + assertThat(mockConfig.getServerLogFilters()).isEmpty(); + + FilterParams expectedParams = + FilterParams.create(false, 0, 0); + FilterParams clientResultParams + = configFilterHelper.logRpcMethod("service3/Method3", true); + assertThat(clientResultParams).isEqualTo(expectedParams); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service3/Method3", false); + assertThat(serverResultParams).isEqualTo(expectedParams); + } + + @Test + public void checkMethodAlwaysLogged() { + List sampleLogFilters = + ImmutableList.of( + new LogFilter( + Collections.emptySet(), Collections.emptySet(), true, + 4096, 4096, false)); + when(mockConfig.getClientLogFilters()).thenReturn(sampleLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(sampleLogFilters); + + FilterParams expectedParams = + FilterParams.create(true, 4096, 4096); + FilterParams clientResultParams + = configFilterHelper.logRpcMethod("service1/Method6", true); + assertThat(clientResultParams).isEqualTo(expectedParams); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service1/Method6", false); + assertThat(serverResultParams).isEqualTo(expectedParams); + } + + @Test + public void checkMethodNotToBeLogged() { + List sampleLogFilters = + ImmutableList.of( + new LogFilter(Collections.emptySet(), Collections.singleton("service2/*"), false, + 1024, 1024, true), + new LogFilter( + Collections.singleton("service2/Method1"), Collections.emptySet(), false, + 2048, 1024, false)); + when(mockConfig.getClientLogFilters()).thenReturn(sampleLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(sampleLogFilters); + + FilterParams expectedParams = + FilterParams.create(false, 0, 0); + FilterParams clientResultParams1 + = configFilterHelper.logRpcMethod("service3/Method3", true); + assertThat(clientResultParams1).isEqualTo(expectedParams); + + FilterParams clientResultParams2 + = configFilterHelper.logRpcMethod("service2/Method1", true); + assertThat(clientResultParams2).isEqualTo(expectedParams); + + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service2/Method1", false); + assertThat(serverResultParams).isEqualTo(expectedParams); + } + + @Test + public void checkMethodToBeLoggedConditional() { + when(mockConfig.getClientLogFilters()).thenReturn(configLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(configLogFilters); + + FilterParams expectedParams = + FilterParams.create(true, 1024, 1024); + FilterParams resultParams + = configFilterHelper.logRpcMethod("service1/Method2", true); + assertThat(resultParams).isEqualTo(expectedParams); + + FilterParams expectedParamsWildCard = + FilterParams.create(true, 2048, 1024); + FilterParams resultParamsWildCard + = configFilterHelper.logRpcMethod("service2/Method1", true); + assertThat(resultParamsWildCard).isEqualTo(expectedParamsWildCard); + + FilterParams excludeParams = + FilterParams.create(false, 0, 0); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service4/method3", false); + assertThat(serverResultParams).isEqualTo(excludeParams); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java new file mode 100644 index 00000000000..2a2e1d4c229 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java @@ -0,0 +1,636 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.gcp.observability.interceptors.LogHelperTest.BYTEARRAY_MARSHALLER; +import static org.junit.Assert.assertSame; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.Context; +import io.grpc.Deadline; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Status; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import io.grpc.internal.NoopClientCall; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.AdditionalMatchers; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Tests for {@link InternalLoggingChannelInterceptor}. + */ +@RunWith(JUnit4.class) +public class InternalLoggingChannelInterceptorTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + private static final Charset US_ASCII = StandardCharsets.US_ASCII; + + private InternalLoggingChannelInterceptor.Factory factory; + private AtomicReference> interceptedListener; + private AtomicReference actualClientInitial; + private AtomicReference actualRequest; + private SettableFuture halfCloseCalled; + private SettableFuture cancelCalled; + private SocketAddress peer; + private LogHelper mockLogHelper; + private ConfigFilterHelper mockFilterHelper; + private FilterParams filterParams; + + @Before + public void setup() throws Exception { + mockLogHelper = mock(LogHelper.class); + mockFilterHelper = mock(ConfigFilterHelper.class); + factory = new InternalLoggingChannelInterceptor.FactoryImpl(mockLogHelper, mockFilterHelper); + interceptedListener = new AtomicReference<>(); + actualClientInitial = new AtomicReference<>(); + actualRequest = new AtomicReference<>(); + halfCloseCalled = SettableFuture.create(); + cancelCalled = SettableFuture.create(); + peer = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 1234); + filterParams = FilterParams.create(true, 0, 0); + } + + @Test + @SuppressWarnings("unchecked") + public void internalLoggingChannelInterceptor() throws Exception { + Channel channel = new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + return new NoopClientCall() { + @Override + @SuppressWarnings("unchecked") + public void start(Listener responseListener, Metadata headers) { + interceptedListener.set((Listener) responseListener); + actualClientInitial.set(headers); + } + + @Override + public void sendMessage(RequestT message) { + actualRequest.set(message); + } + + @Override + public void cancel(String message, Throwable cause) { + cancelCalled.set(null); + } + + @Override + public void halfClose() { + halfCloseCalled.set(null); + } + + @Override + public Attributes getAttributes() { + return Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer).build(); + } + }; + } + + @Override + public String authority() { + return "the-authority"; + } + }; + + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(filterParams); + + ClientCall interceptedLoggingCall = + factory.create() + .interceptCall(method, + CallOptions.DEFAULT, + channel); + + // send request header + { + Metadata clientInitial = new Metadata(); + String dataA = "aaaaaaaaa"; + String dataB = "bbbbbbbbb"; + Metadata.Key keyA = + Metadata.Key.of("a", Metadata.ASCII_STRING_MARSHALLER); + Metadata.Key keyB = + Metadata.Key.of("b", Metadata.ASCII_STRING_MARSHALLER); + clientInitial.put(keyA, dataA); + clientInitial.put(keyB, dataB); + interceptedLoggingCall.start(mockListener, clientInitial); + verify(mockLogHelper).logClientHeader( + /*seq=*/ eq(1L), + eq("service"), + eq("method"), + eq("the-authority"), + ArgumentMatchers.isNull(), + same(clientInitial), + eq(filterParams.headerBytes()), + eq(EventLogger.CLIENT), + anyString(), + ArgumentMatchers.isNull()); + verifyNoMoreInteractions(mockLogHelper); + assertSame(clientInitial, actualClientInitial.get()); + } + + reset(mockLogHelper); + reset(mockListener); + + // receive response header + { + Metadata serverInitial = new Metadata(); + interceptedListener.get().onHeaders(serverInitial); + verify(mockLogHelper).logServerHeader( + /*seq=*/ eq(2L), + eq("service"), + eq("method"), + eq("the-authority"), + same(serverInitial), + eq(filterParams.headerBytes()), + eq(EventLogger.CLIENT), + anyString(), + same(peer)); + verifyNoMoreInteractions(mockLogHelper); + verify(mockListener).onHeaders(same(serverInitial)); + } + + reset(mockLogHelper); + reset(mockListener); + + // send request message + { + byte[] request = "this is a request".getBytes(US_ASCII); + interceptedLoggingCall.sendMessage(request); + verify(mockLogHelper).logRpcMessage( + /*seq=*/ eq(3L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventType.CLIENT_MESSAGE), + same(request), + eq(filterParams.messageBytes()), + eq(EventLogger.CLIENT), + anyString()); + verifyNoMoreInteractions(mockLogHelper); + assertSame(request, actualRequest.get()); + } + + reset(mockLogHelper); + reset(mockListener); + + // client half close + { + interceptedLoggingCall.halfClose(); + verify(mockLogHelper).logHalfClose( + /*seq=*/ eq(4L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventLogger.CLIENT), + anyString()); + halfCloseCalled.get(1, TimeUnit.MILLISECONDS); + verifyNoMoreInteractions(mockLogHelper); + } + + reset(mockLogHelper); + reset(mockListener); + + // receive response message + { + byte[] response = "this is a response".getBytes(US_ASCII); + interceptedListener.get().onMessage(response); + verify(mockLogHelper).logRpcMessage( + /*seq=*/ eq(5L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventType.SERVER_MESSAGE), + same(response), + eq(filterParams.messageBytes()), + eq(EventLogger.CLIENT), + anyString()); + verifyNoMoreInteractions(mockLogHelper); + verify(mockListener).onMessage(same(response)); + } + + reset(mockLogHelper); + reset(mockListener); + + // receive trailer + { + Status status = Status.INTERNAL.withDescription("trailer description"); + Metadata trailers = new Metadata(); + interceptedListener.get().onClose(status, trailers); + verify(mockLogHelper).logTrailer( + /*seq=*/ eq(6L), + eq("service"), + eq("method"), + eq("the-authority"), + same(status), + same(trailers), + eq(filterParams.headerBytes()), + eq(EventLogger.CLIENT), + anyString(), + same(peer)); + verifyNoMoreInteractions(mockLogHelper); + verify(mockListener).onClose(same(status), same(trailers)); + } + + reset(mockLogHelper); + reset(mockListener); + + // cancel + { + interceptedLoggingCall.cancel(null, null); + verify(mockLogHelper).logCancel( + /*seq=*/ eq(7L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventLogger.CLIENT), + anyString()); + cancelCalled.get(1, TimeUnit.MILLISECONDS); + } + } + + @Test + public void clientDeadLineLogged_deadlineSetViaCallOption() { + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(filterParams); + @SuppressWarnings("unchecked") + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + + ClientCall interceptedLoggingCall = + factory.create() + .interceptCall( + method, + CallOptions.DEFAULT.withDeadlineAfter(1, TimeUnit.SECONDS), + new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, + CallOptions callOptions) { + return new NoopClientCall<>(); + } + + @Override + public String authority() { + return "the-authority"; + } + }); + interceptedLoggingCall.start(mockListener, new Metadata()); + ArgumentCaptor callOptTimeoutCaptor = ArgumentCaptor.forClass(Duration.class); + verify(mockLogHelper, times(1)) + .logClientHeader( + anyLong(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + callOptTimeoutCaptor.capture(), + any(Metadata.class), + anyInt(), + any(GrpcLogRecord.EventLogger.class), + anyString(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), + ArgumentMatchers.any())); + Duration timeout = callOptTimeoutCaptor.getValue(); + assertThat(TimeUnit.SECONDS.toNanos(1) - Durations.toNanos(timeout)) + .isAtMost(TimeUnit.MILLISECONDS.toNanos(250)); + } + + @Test + public void clientDeadlineLogged_deadlineSetViaContext() throws Exception { + final SettableFuture> callFuture = SettableFuture.create(); + Context.current() + .withDeadline( + Deadline.after(1, TimeUnit.SECONDS), + Executors.newSingleThreadScheduledExecutor()) + .run(() -> { + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(filterParams); + + callFuture.set( + factory.create() + .interceptCall( + method, + CallOptions.DEFAULT.withDeadlineAfter(1, TimeUnit.SECONDS), + new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, + CallOptions callOptions) { + return new NoopClientCall<>(); + } + + @Override + public String authority() { + return "the-authority"; + } + })); + }); + @SuppressWarnings("unchecked") + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + Objects.requireNonNull(callFuture.get()).start(mockListener, new Metadata()); + ArgumentCaptor contextTimeoutCaptor = ArgumentCaptor.forClass(Duration.class); + verify(mockLogHelper, times(1)) + .logClientHeader( + anyLong(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + contextTimeoutCaptor.capture(), + any(Metadata.class), + anyInt(), + any(GrpcLogRecord.EventLogger.class), + anyString(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), + ArgumentMatchers.any())); + Duration timeout = contextTimeoutCaptor.getValue(); + assertThat(TimeUnit.SECONDS.toNanos(1) - Durations.toNanos(timeout)) + .isAtMost(TimeUnit.MILLISECONDS.toNanos(250)); + } + + @Test + public void clientDeadlineLogged_deadlineSetViaContextAndCallOptions() throws Exception { + Deadline contextDeadline = Deadline.after(10, TimeUnit.SECONDS); + Deadline callOptionsDeadline = CallOptions.DEFAULT + .withDeadlineAfter(15, TimeUnit.SECONDS).getDeadline(); + + final SettableFuture> callFuture = SettableFuture.create(); + Context.current() + .withDeadline( + contextDeadline, Executors.newSingleThreadScheduledExecutor()) + .run(() -> { + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(filterParams); + + callFuture.set( + factory.create() + .interceptCall( + method, + CallOptions.DEFAULT.withDeadlineAfter(15, TimeUnit.SECONDS), + new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, + CallOptions callOptions) { + return new NoopClientCall<>(); + } + + @Override + public String authority() { + return "the-authority"; + } + })); + }); + @SuppressWarnings("unchecked") + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + Objects.requireNonNull(callFuture.get()).start(mockListener, new Metadata()); + ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Duration.class); + verify(mockLogHelper, times(1)) + .logClientHeader( + anyLong(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), + timeoutCaptor.capture(), + any(Metadata.class), + anyInt(), + any(GrpcLogRecord.EventLogger.class), + anyString(), + AdditionalMatchers.or(ArgumentMatchers.isNull(), + ArgumentMatchers.any())); + Duration timeout = timeoutCaptor.getValue(); + assertThat(LogHelper.min(contextDeadline, callOptionsDeadline)) + .isSameInstanceAs(contextDeadline); + assertThat(TimeUnit.SECONDS.toNanos(10) - Durations.toNanos(timeout)) + .isAtMost(TimeUnit.MILLISECONDS.toNanos(250)); + } + + @Test + public void clientMethodOrServiceFilter_disabled() { + Channel channel = new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + return new NoopClientCall() { + @Override + @SuppressWarnings("unchecked") + public void start(Listener responseListener, Metadata headers) { + interceptedListener.set((Listener) responseListener); + actualClientInitial.set(headers); + } + + @Override + public void sendMessage(RequestT message) { + actualRequest.set(message); + } + + @Override + public void cancel(String message, Throwable cause) { + cancelCalled.set(null); + } + + @Override + public void halfClose() { + halfCloseCalled.set(null); + } + + @Override + public Attributes getAttributes() { + return Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer).build(); + } + }; + } + + @Override + public String authority() { + return "the-authority"; + } + }; + + @SuppressWarnings("unchecked") + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(FilterParams.create(false, 0, 0)); + + ClientCall interceptedLoggingCall = + factory.create() + .interceptCall(method, + CallOptions.DEFAULT, + channel); + + interceptedLoggingCall.start(mockListener, new Metadata()); + verifyNoInteractions(mockLogHelper); + } + + @Test + public void clientMethodOrServiceFilter_enabled() { + Channel channel = new Channel() { + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + return new NoopClientCall() { + @Override + @SuppressWarnings("unchecked") + public void start(Listener responseListener, Metadata headers) { + interceptedListener.set((Listener) responseListener); + actualClientInitial.set(headers); + } + + @Override + public void sendMessage(RequestT message) { + actualRequest.set(message); + } + + @Override + public void cancel(String message, Throwable cause) { + cancelCalled.set(null); + } + + @Override + public void halfClose() { + halfCloseCalled.set(null); + } + + @Override + public Attributes getAttributes() { + return Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer).build(); + } + }; + } + + @Override + public String authority() { + return "the-authority"; + } + }; + + @SuppressWarnings("unchecked") + ClientCall.Listener mockListener = mock(ClientCall.Listener.class); + + MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) + .thenReturn(FilterParams.create(true, 10, 10)); + + ClientCall interceptedLoggingCall = + factory.create() + .interceptCall(method, + CallOptions.DEFAULT, + channel); + + { + interceptedLoggingCall.start(mockListener, new Metadata()); + interceptedListener.get().onHeaders(new Metadata()); + byte[] request = "this is a request".getBytes(US_ASCII); + interceptedLoggingCall.sendMessage(request); + interceptedLoggingCall.halfClose(); + byte[] response = "this is a response".getBytes(US_ASCII); + interceptedListener.get().onMessage(response); + Status status = Status.INTERNAL.withDescription("trailer description"); + Metadata trailers = new Metadata(); + interceptedListener.get().onClose(status, trailers); + interceptedLoggingCall.cancel(null, null); + assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(7); + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java new file mode 100644 index 00000000000..fee936dfbca --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java @@ -0,0 +1,481 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.gcp.observability.interceptors.LogHelperTest.BYTEARRAY_MARSHALLER; +import static org.junit.Assert.assertSame; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.Attributes; +import io.grpc.Context; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; +import io.grpc.Status; +import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; +import io.grpc.internal.NoopServerCall; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Tests for {@link InternalLoggingServerInterceptor}. + */ +@RunWith(JUnit4.class) +public class InternalLoggingServerInterceptorTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + private static final Charset US_ASCII = StandardCharsets.US_ASCII; + + private InternalLoggingServerInterceptor.Factory factory; + private AtomicReference> interceptedLoggingCall; + ServerCall.Listener capturedListener; + private ServerCall.Listener mockListener; + // capture these manually because ServerCall can not be mocked + private AtomicReference actualServerInitial; + private AtomicReference actualResponse; + private AtomicReference actualStatus; + private AtomicReference actualTrailers; + private LogHelper mockLogHelper; + private ConfigFilterHelper mockFilterHelper; + private SocketAddress peer; + + @Before + @SuppressWarnings("unchecked") + public void setup() throws Exception { + mockLogHelper = mock(LogHelper.class); + mockFilterHelper = mock(ConfigFilterHelper.class); + factory = new InternalLoggingServerInterceptor.FactoryImpl(mockLogHelper, mockFilterHelper); + interceptedLoggingCall = new AtomicReference<>(); + mockListener = mock(ServerCall.Listener.class); + actualServerInitial = new AtomicReference<>(); + actualResponse = new AtomicReference<>(); + actualStatus = new AtomicReference<>(); + actualTrailers = new AtomicReference<>(); + peer = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 1234); + } + + @Test + @SuppressWarnings("unchecked") + public void internalLoggingServerInterceptor() { + Metadata clientInitial = new Metadata(); + final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + FilterParams filterParams = FilterParams.create(true, 0, 0); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)).thenReturn(filterParams); + capturedListener = + factory.create() + .interceptCall( + new NoopServerCall() { + @Override + public void sendHeaders(Metadata headers) { + actualServerInitial.set(headers); + } + + @Override + public void sendMessage(byte[] message) { + actualResponse.set(message); + } + + @Override + public void close(Status status, Metadata trailers) { + actualStatus.set(status); + actualTrailers.set(trailers); + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return method; + } + + @Override + public Attributes getAttributes() { + return Attributes + .newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer) + .build(); + } + + @Override + public String getAuthority() { + return "the-authority"; + } + }, + clientInitial, + (call, headers) -> { + interceptedLoggingCall.set(call); + return mockListener; + }); + // receive request header + { + verify(mockLogHelper).logClientHeader( + /*seq=*/ eq(1L), + eq("service"), + eq("method"), + eq("the-authority"), + ArgumentMatchers.isNull(), + same(clientInitial), + eq(filterParams.headerBytes()), + eq(EventLogger.SERVER), + anyString(), + same(peer)); + verifyNoMoreInteractions(mockLogHelper); + } + + reset(mockLogHelper); + reset(mockListener); + + // send response header + { + Metadata serverInitial = new Metadata(); + interceptedLoggingCall.get().sendHeaders(serverInitial); + verify(mockLogHelper).logServerHeader( + /*seq=*/ eq(2L), + eq("service"), + eq("method"), + eq("the-authority"), + same(serverInitial), + eq(filterParams.headerBytes()), + eq(EventLogger.SERVER), + anyString(), + ArgumentMatchers.isNull()); + verifyNoMoreInteractions(mockLogHelper); + assertSame(serverInitial, actualServerInitial.get()); + } + + reset(mockLogHelper); + reset(mockListener); + + // receive request message + { + byte[] request = "this is a request".getBytes(US_ASCII); + capturedListener.onMessage(request); + verify(mockLogHelper).logRpcMessage( + /*seq=*/ eq(3L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventType.CLIENT_MESSAGE), + same(request), + eq(filterParams.messageBytes()), + eq(EventLogger.SERVER), + anyString()); + verifyNoMoreInteractions(mockLogHelper); + verify(mockListener).onMessage(same(request)); + } + + reset(mockLogHelper); + reset(mockListener); + + // client half close + { + capturedListener.onHalfClose(); + verify(mockLogHelper).logHalfClose( + /*seq=*/ eq(4L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventLogger.SERVER), + anyString()); + verifyNoMoreInteractions(mockLogHelper); + verify(mockListener).onHalfClose(); + } + + reset(mockLogHelper); + reset(mockListener); + + // send response message + { + byte[] response = "this is a response".getBytes(US_ASCII); + interceptedLoggingCall.get().sendMessage(response); + verify(mockLogHelper).logRpcMessage( + /*seq=*/ eq(5L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventType.SERVER_MESSAGE), + same(response), + eq(filterParams.messageBytes()), + eq(EventLogger.SERVER), + anyString()); + verifyNoMoreInteractions(mockLogHelper); + assertSame(response, actualResponse.get()); + } + + reset(mockLogHelper); + reset(mockListener); + + // send trailer + { + Status status = Status.INTERNAL.withDescription("trailer description"); + Metadata trailers = new Metadata(); + interceptedLoggingCall.get().close(status, trailers); + verify(mockLogHelper).logTrailer( + /*seq=*/ eq(6L), + eq("service"), + eq("method"), + eq("the-authority"), + same(status), + same(trailers), + eq(filterParams.headerBytes()), + eq(EventLogger.SERVER), + anyString(), + ArgumentMatchers.isNull()); + verifyNoMoreInteractions(mockLogHelper); + assertSame(status, actualStatus.get()); + assertSame(trailers, actualTrailers.get()); + } + + reset(mockLogHelper); + reset(mockListener); + + // cancel + { + capturedListener.onCancel(); + verify(mockLogHelper).logCancel( + /*seq=*/ eq(7L), + eq("service"), + eq("method"), + eq("the-authority"), + eq(EventLogger.SERVER), + anyString()); + verify(mockListener).onCancel(); + } + } + + @Test + public void serverDeadlineLogged() { + final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + FilterParams filterParams = FilterParams.create(true, 0, 0); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)).thenReturn(filterParams); + final ServerCall noopServerCall = new NoopServerCall() { + @Override + public MethodDescriptor getMethodDescriptor() { + return method; + } + + @Override + public String getAuthority() { + return "the-authority"; + } + }; + + // We expect the contents of the "grpc-timeout" header to be installed the context + Context.current() + .withDeadlineAfter(1, TimeUnit.SECONDS, Executors.newSingleThreadScheduledExecutor()) + .run( + () -> { + ServerCall.Listener unused = + factory.create() + .interceptCall(noopServerCall, + new Metadata(), + (call, headers) -> new ServerCall.Listener() {}); + }); + ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Duration.class); + verify(mockLogHelper, times(1)) + .logClientHeader( + /*seq=*/ eq(1L), + eq("service"), + eq("method"), + eq("the-authority"), + timeoutCaptor.capture(), + any(Metadata.class), + eq(filterParams.headerBytes()), + eq(EventLogger.SERVER), + anyString(), + ArgumentMatchers.isNull()); + verifyNoMoreInteractions(mockLogHelper); + Duration timeout = timeoutCaptor.getValue(); + assertThat(TimeUnit.SECONDS.toNanos(1) - Durations.toNanos(timeout)) + .isAtMost(TimeUnit.MILLISECONDS.toNanos(250)); + } + + @Test + public void serverMethodOrServiceFilter_disabled() { + Metadata clientInitial = new Metadata(); + final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)) + .thenReturn(FilterParams.create(false, 0, 0)); + capturedListener = + factory.create() + .interceptCall( + new NoopServerCall() { + @Override + public void sendHeaders(Metadata headers) { + actualServerInitial.set(headers); + } + + @Override + public void sendMessage(byte[] message) { + actualResponse.set(message); + } + + @Override + public void close(Status status, Metadata trailers) { + actualStatus.set(status); + actualTrailers.set(trailers); + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return method; + } + + @Override + public Attributes getAttributes() { + return Attributes + .newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer) + .build(); + } + + @Override + public String getAuthority() { + return "the-authority"; + } + }, + clientInitial, + (call, headers) -> { + interceptedLoggingCall.set(call); + return mockListener; + }); + verifyNoInteractions(mockLogHelper); + } + + @Test + public void serverMethodOrServiceFilter_enabled() { + Metadata clientInitial = new Metadata(); + final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(BYTEARRAY_MARSHALLER) + .setResponseMarshaller(BYTEARRAY_MARSHALLER) + .build(); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)) + .thenReturn(FilterParams.create(true, 10, 10)); + + capturedListener = + factory.create() + .interceptCall( + new NoopServerCall() { + @Override + public void sendHeaders(Metadata headers) { + actualServerInitial.set(headers); + } + + @Override + public void sendMessage(byte[] message) { + actualResponse.set(message); + } + + @Override + public void close(Status status, Metadata trailers) { + actualStatus.set(status); + actualTrailers.set(trailers); + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return method; + } + + @Override + public Attributes getAttributes() { + return Attributes + .newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer) + .build(); + } + + @Override + public String getAuthority() { + return "the-authority"; + } + }, + clientInitial, + (call, headers) -> { + interceptedLoggingCall.set(call); + return mockListener; + }); + + { + interceptedLoggingCall.get().sendHeaders(new Metadata()); + byte[] request = "this is a request".getBytes(US_ASCII); + capturedListener.onMessage(request); + capturedListener.onHalfClose(); + byte[] response = "this is a response".getBytes(US_ASCII); + interceptedLoggingCall.get().sendMessage(response); + Status status = Status.INTERNAL.withDescription("trailer description"); + Metadata trailers = new Metadata(); + interceptedLoggingCall.get().close(status, trailers); + capturedListener.onCancel(); + assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(7); + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java new file mode 100644 index 00000000000..73704eb4181 --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java @@ -0,0 +1,756 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.interceptors; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; +import com.google.protobuf.util.Durations; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor.Marshaller; +import io.grpc.Status; +import io.grpc.gcp.observability.interceptors.LogHelper.PayloadBuilderHelper; +import io.grpc.gcp.observability.logging.GcpLogSink; +import io.grpc.gcp.observability.logging.Sink; +import io.grpc.observabilitylog.v1.Address; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import io.grpc.observabilitylog.v1.Payload; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link LogHelper}. + */ +@RunWith(JUnit4.class) +public class LogHelperTest { + public static final Marshaller BYTEARRAY_MARSHALLER = new ByteArrayMarshaller(); + private static final String DATA_A = "aaaaaaaaa"; + private static final String DATA_B = "bbbbbbbbb"; + private static final String DATA_C = "ccccccccc"; + private static final Metadata.Key KEY_A = + Metadata.Key.of("a", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key KEY_B = + Metadata.Key.of("b", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key KEY_C = + Metadata.Key.of("c", Metadata.ASCII_STRING_MARSHALLER); + private static final int HEADER_LIMIT = 10; + private static final int MESSAGE_LIMIT = Integer.MAX_VALUE; + + private final Metadata nonEmptyMetadata = new Metadata(); + private final Sink sink = mock(GcpLogSink.class); + private final LogHelper logHelper = new LogHelper(sink); + + @Before + public void setUp() { + nonEmptyMetadata.put(KEY_A, DATA_A); + nonEmptyMetadata.put(KEY_B, DATA_B); + nonEmptyMetadata.put(KEY_C, DATA_C); + } + + @Test + public void socketToProto_ipv4() throws Exception { + InetAddress address = InetAddress.getByName("127.0.0.1"); + int port = 12345; + InetSocketAddress socketAddress = new InetSocketAddress(address, port); + assertThat(LogHelper.socketAddressToProto(socketAddress)) + .isEqualTo(Address + .newBuilder() + .setType(Address.Type.TYPE_IPV4) + .setAddress("127.0.0.1") + .setIpPort(12345) + .build()); + } + + @Test + public void socketToProto_ipv6() throws Exception { + // this is a ipv6 link local address + InetAddress address = InetAddress.getByName("2001:db8:0:0:0:0:2:1"); + int port = 12345; + InetSocketAddress socketAddress = new InetSocketAddress(address, port); + assertThat(LogHelper.socketAddressToProto(socketAddress)) + .isEqualTo(Address + .newBuilder() + .setType(Address.Type.TYPE_IPV6) + .setAddress("2001:db8::2:1") // RFC 5952 section 4: ipv6 canonical form required + .setIpPort(12345) + .build()); + } + + @Test + public void socketToProto_unknown() { + SocketAddress unknownSocket = new SocketAddress() { + @Override + public String toString() { + return "some-socket-address"; + } + }; + assertThat(LogHelper.socketAddressToProto(unknownSocket)) + .isEqualTo(Address.newBuilder() + .setType(Address.Type.TYPE_UNKNOWN) + .setAddress("some-socket-address") + .build()); + } + + @Test + public void metadataToProto_empty() { + assertThat(metadataToProtoTestHelper( + EventType.CLIENT_HEADER, new Metadata(), Integer.MAX_VALUE)) + .isEqualTo(GrpcLogRecord.newBuilder() + .setType(EventType.CLIENT_HEADER) + .setPayload( + Payload.newBuilder().putAllMetadata(new HashMap<>())) + .build()); + } + + @Test + public void metadataToProto() { + Payload.Builder payloadBuilder = Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .putMetadata("c", DATA_C); + + assertThat(metadataToProtoTestHelper( + EventType.CLIENT_HEADER, nonEmptyMetadata, Integer.MAX_VALUE)) + .isEqualTo(GrpcLogRecord.newBuilder() + .setType(EventType.CLIENT_HEADER) + .setPayload(payloadBuilder) + .build()); + } + + @Test + public void metadataToProto_setsTruncated() { + assertTrue(LogHelper.createMetadataProto(nonEmptyMetadata, 0).truncated); + } + + @Test + public void metadataToProto_truncated() { + // 0 byte limit not enough for any metadata + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 0).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putAllMetadata(new HashMap<>()) + .build()); + // not enough bytes for first key value + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 9).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putAllMetadata(new HashMap<>()) + .build()); + // enough for first key value + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 10).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder().putMetadata("a", DATA_A).build()); + // Test edge cases for >= 2 key values + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 19).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder().putMetadata("a", DATA_A).build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 20).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 29).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .build()); + // not truncated: enough for all keys + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 30).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .putMetadata("c", DATA_C) + .build()); + } + + @Test + public void messageToProto() { + byte[] bytes + = "this is a long message: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".getBytes( + StandardCharsets.US_ASCII); + assertThat(messageTestHelper(bytes, Integer.MAX_VALUE)) + .isEqualTo(GrpcLogRecord.newBuilder() + .setPayload( + Payload.newBuilder() + .setMessage( + ByteString.copyFrom(bytes)) + .setMessageLength(bytes.length)) + .build()); + } + + @Test + public void messageToProto_truncated() { + byte[] bytes + = "this is a long message: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".getBytes( + StandardCharsets.US_ASCII); + assertThat(messageTestHelper(bytes, 0)) + .isEqualTo(GrpcLogRecord.newBuilder() + .setPayload( + Payload.newBuilder() + .setMessageLength(bytes.length)) + .setPayloadTruncated(true) + .build()); + + int limit = 10; + String truncatedMessage = "this is a "; + assertThat(messageTestHelper(bytes, limit)) + .isEqualTo( + GrpcLogRecord.newBuilder() + .setPayload( + Payload.newBuilder() + .setMessage( + ByteString.copyFrom( + truncatedMessage.getBytes(StandardCharsets.US_ASCII))) + .setMessageLength(bytes.length)) + .setPayloadTruncated(true) + .build()); + } + + + @Test + public void logRequestHeader() throws Exception { + long seqId = 1; + String serviceName = "service"; + String methodName = "method"; + String authority = "authority"; + Duration timeout = Durations.fromMillis(1234); + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + InetAddress address = InetAddress.getByName("127.0.0.1"); + int port = 12345; + InetSocketAddress peerAddress = new InetSocketAddress(address, port); + + GrpcLogRecord.Builder builder = + metadataToProtoTestHelper(EventType.CLIENT_HEADER, nonEmptyMetadata, + HEADER_LIMIT) + .toBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setType(EventType.CLIENT_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setAuthority(authority); + builder.setPayload(builder.getPayload().toBuilder().setTimeout(timeout).build()); + GrpcLogRecord base = builder.build(); + + // logged on client + { + logHelper.logClientHeader( + seqId, + serviceName, + methodName, + authority, + timeout, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + null); + verify(sink).write(base); + } + + // logged on server + { + logHelper.logClientHeader( + seqId, + serviceName, + methodName, + authority, + timeout, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.SERVER, + callId, + peerAddress); + verify(sink).write( + base.toBuilder() + .setPeer(LogHelper.socketAddressToProto(peerAddress)) + .setLogger(EventLogger.SERVER) + .build()); + } + + // timeout is null + { + logHelper.logClientHeader( + seqId, + serviceName, + methodName, + authority, + null, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + null); + verify(sink).write( + base.toBuilder() + .setPayload(base.getPayload().toBuilder().clearTimeout().build()) + .build()); + } + + // peerAddress is not null (error on client) + try { + logHelper.logClientHeader( + seqId, + serviceName, + methodName, + authority, + timeout, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + peerAddress); + fail(); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("peerAddress can only be specified by server"); + } + } + + @Test + public void logResponseHeader() throws Exception { + long seqId = 1; + String serviceName = "service"; + String methodName = "method"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + InetAddress address = InetAddress.getByName("127.0.0.1"); + int port = 12345; + InetSocketAddress peerAddress = new InetSocketAddress(address, port); + + GrpcLogRecord.Builder builder = + metadataToProtoTestHelper(EventType.SERVER_HEADER, nonEmptyMetadata, + HEADER_LIMIT) + .toBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.SERVER_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setPeer(LogHelper.socketAddressToProto(peerAddress)); + GrpcLogRecord base = builder.build(); + + // logged on client + { + logHelper.logServerHeader( + seqId, + serviceName, + methodName, + authority, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + peerAddress); + verify(sink).write(base); + } + + // logged on server + { + logHelper.logServerHeader( + seqId, + serviceName, + methodName, + authority, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.SERVER, + callId, + null); + verify(sink).write( + base.toBuilder() + .setLogger(EventLogger.SERVER) + .clearPeer() + .build()); + } + + // peerAddress is not null (error on server) + try { + logHelper.logServerHeader( + seqId, + serviceName, + methodName, + authority, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.SERVER, + callId, + peerAddress); + + fail(); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat() + .contains("peerAddress can only be specified for client"); + } + } + + @Test + public void logTrailer() throws Exception { + long seqId = 1; + String serviceName = "service"; + String methodName = "method"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + InetAddress address = InetAddress.getByName("127.0.0.1"); + int port = 12345; + InetSocketAddress peer = new InetSocketAddress(address, port); + Status statusDescription = Status.INTERNAL.withDescription("test description"); + + GrpcLogRecord.Builder builder = + metadataToProtoTestHelper(EventType.SERVER_HEADER, nonEmptyMetadata, + HEADER_LIMIT) + .toBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.SERVER_TRAILER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setPeer(LogHelper.socketAddressToProto(peer)); + builder.setPayload( + builder.getPayload().toBuilder() + .setStatusCode(Status.INTERNAL.getCode().value()) + .setStatusMessage("test description") + .build()); + GrpcLogRecord base = builder.build(); + + // logged on client + { + logHelper.logTrailer( + seqId, + serviceName, + methodName, + authority, + statusDescription, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + peer); + verify(sink).write(base); + } + + // logged on server + { + logHelper.logTrailer( + seqId, + serviceName, + methodName, + authority, + statusDescription, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.SERVER, + callId, + null); + verify(sink).write( + base.toBuilder() + .clearPeer() + .setLogger(EventLogger.SERVER) + .build()); + } + + // peer address is null + { + logHelper.logTrailer( + seqId, + serviceName, + methodName, + authority, + statusDescription, + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + null); + verify(sink).write( + base.toBuilder() + .clearPeer() + .build()); + } + + // status description is null + { + logHelper.logTrailer( + seqId, + serviceName, + methodName, + authority, + statusDescription.getCode().toStatus(), + nonEmptyMetadata, + HEADER_LIMIT, + EventLogger.CLIENT, + callId, + peer); + verify(sink).write( + base.toBuilder() + .setPayload(base.getPayload().toBuilder().clearStatusMessage().build()) + .build()); + } + } + + @Test + public void alwaysLoggedMetadata_grpcTraceBin() { + Metadata.Key key + = Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata metadata = new Metadata(); + metadata.put(key, new byte[1]); + int zeroHeaderBytes = 0; + PayloadBuilderHelper pair = + LogHelper.createMetadataProto(metadata, zeroHeaderBytes); + assertThat(pair.payloadBuilder.build().getMetadataMap().containsKey(key.name())).isTrue(); + assertFalse(pair.truncated); + } + + @Test + public void neverLoggedMetadata_grpcStatusDetailsBin() { + Metadata.Key key + = Metadata.Key.of("grpc-status-details-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata metadata = new Metadata(); + metadata.put(key, new byte[1]); + int unlimitedHeaderBytes = Integer.MAX_VALUE; + PayloadBuilderHelper pair + = LogHelper.createMetadataProto(metadata, unlimitedHeaderBytes); + assertThat(pair.payloadBuilder.getMetadataMap()).isEmpty(); + assertFalse(pair.truncated); + } + + @Test + public void logRpcMessage() { + long seqId = 1; + String serviceName = "service"; + String methodName = "method"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + byte[] message = new byte[100]; + + GrpcLogRecord.Builder builder = messageTestHelper(message, MESSAGE_LIMIT) + .toBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.CLIENT_MESSAGE) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + GrpcLogRecord base = builder.build(); + // request message + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + message, + MESSAGE_LIMIT, + EventLogger.CLIENT, + callId); + verify(sink).write(base); + } + // response message, logged on client + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.SERVER_MESSAGE, + message, + MESSAGE_LIMIT, + EventLogger.CLIENT, + callId); + verify(sink).write( + base.toBuilder() + .setType(EventType.SERVER_MESSAGE) + .build()); + } + // request message, logged on server + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + message, + MESSAGE_LIMIT, + EventLogger.SERVER, + callId); + verify(sink).write( + base.toBuilder() + .setLogger(EventLogger.SERVER) + .build()); + } + // response message, logged on server + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.SERVER_MESSAGE, + message, + MESSAGE_LIMIT, + EventLogger.SERVER, + callId); + verify(sink).write( + base.toBuilder() + .setType(EventType.SERVER_MESSAGE) + .setLogger(EventLogger.SERVER) + .build()); + } + // message is not of type : com.google.protobuf.Message or byte[] + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + "message", + MESSAGE_LIMIT, + EventLogger.CLIENT, + callId); + verify(sink).write( + base.toBuilder() + .clearPayload() + .clearPayloadTruncated() + .build()); + } + } + + @Test + public void getPeerAddressTest() throws Exception { + SocketAddress peer = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 1234); + assertNull(LogHelper.getPeerAddress(Attributes.EMPTY)); + assertSame( + peer, + LogHelper.getPeerAddress( + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer).build())); + } + + private static GrpcLogRecord metadataToProtoTestHelper( + EventType type, Metadata metadata, int maxHeaderBytes) { + GrpcLogRecord.Builder builder = GrpcLogRecord.newBuilder(); + PayloadBuilderHelper pair + = LogHelper.createMetadataProto(metadata, maxHeaderBytes); + builder.setPayload(pair.payloadBuilder); + builder.setPayloadTruncated(pair.truncated); + builder.setType(type); + return builder.build(); + } + + private static GrpcLogRecord messageTestHelper(byte[] message, int maxMessageBytes) { + GrpcLogRecord.Builder builder = GrpcLogRecord.newBuilder(); + PayloadBuilderHelper pair + = LogHelper.createMessageProto(message, maxMessageBytes); + builder.setPayload(pair.payloadBuilder); + builder.setPayloadTruncated(pair.truncated); + return builder.build(); + } + + // Used only in tests + // Copied from internal + static final class ByteArrayMarshaller implements Marshaller { + + @Override + public InputStream stream(byte[] value) { + return new ByteArrayInputStream(value); + } + + @Override + public byte[] parse(InputStream stream) { + try { + return parseHelper(stream); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private byte[] parseHelper(InputStream stream) throws IOException { + try { + return IoUtils.toByteArray(stream); + } finally { + stream.close(); + } + } + } + + // Copied from internal + static final class IoUtils { + + /** maximum buffer to be read is 16 KB. */ + private static final int MAX_BUFFER_LENGTH = 16384; + + /** Returns the byte array. */ + public static byte[] toByteArray(InputStream in) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + copy(in, out); + return out.toByteArray(); + } + + /** Copies the data from input stream to output stream. */ + public static long copy(InputStream from, OutputStream to) throws IOException { + // Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta) + checkNotNull(from); + checkNotNull(to); + byte[] buf = new byte[MAX_BUFFER_LENGTH]; + long total = 0; + while (true) { + int r = from.read(buf); + if (r == -1) { + break; + } + to.write(buf, 0, r); + total += r; + } + return total; + } + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java new file mode 100644 index 00000000000..e02cc6dd4eb --- /dev/null +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java @@ -0,0 +1,207 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.observability.logging; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.anyIterable; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.cloud.MonitoredResource; +import com.google.cloud.logging.LogEntry; +import com.google.cloud.logging.Logging; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Duration; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.protobuf.util.Durations; +import io.grpc.observabilitylog.v1.GrpcLogRecord; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; +import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Tests for {@link GcpLogSink}. + */ +@RunWith(JUnit4.class) +public class GcpLogSinkTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + private static final ImmutableMap LOCATION_TAGS = + ImmutableMap.of("project_id", "PROJECT", + "location", "us-central1-c", + "cluster_name", "grpc-observability-cluster", + "namespace_name", "default" , + "pod_name", "app1-6c7c58f897-n92c5"); + private static final ImmutableMap CUSTOM_TAGS = + ImmutableMap.of("KEY1", "Value1", + "KEY2", "VALUE2"); + // gRPC is expected to always use this log name when reporting to GCP cloud logging. + private static final String EXPECTED_LOG_NAME = + "microservices.googleapis.com%2Fobservability%2Fgrpc"; + private static final long SEQ_ID = 1; + private static final String DEST_PROJECT_NAME = "PROJECT"; + private static final String SERVICE_NAME = "service"; + private static final String METHOD_NAME = "method"; + private static final String AUTHORITY = "authority"; + private static final Duration TIMEOUT = Durations.fromMillis(1234); + private static final String CALL_ID = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + private static final GrpcLogRecord LOG_PROTO = GrpcLogRecord.newBuilder() + .setSequenceId(SEQ_ID) + .setServiceName(SERVICE_NAME) + .setMethodName(METHOD_NAME) + .setAuthority(AUTHORITY) + .setPayload(io.grpc.observabilitylog.v1.Payload.newBuilder().setTimeout(TIMEOUT)) + .setType(EventType.CLIENT_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(CALL_ID) + .build(); + // .putFields("timeout", Value.newBuilder().setStringValue("1.234s").build()) + private static final Struct struct = + Struct.newBuilder() + .putFields("timeout", Value.newBuilder().setStringValue("1.234s").build()) + .build(); + private static final Struct EXPECTED_STRUCT_LOG_PROTO = Struct.newBuilder() + .putFields("sequenceId", Value.newBuilder().setStringValue(String.valueOf(SEQ_ID)).build()) + .putFields("serviceName", Value.newBuilder().setStringValue(SERVICE_NAME).build()) + .putFields("methodName", Value.newBuilder().setStringValue(METHOD_NAME).build()) + .putFields("authority", Value.newBuilder().setStringValue(AUTHORITY).build()) + .putFields("payload", Value.newBuilder().setStructValue(struct).build()) + .putFields("type", Value.newBuilder().setStringValue( + String.valueOf(EventType.CLIENT_HEADER)).build()) + .putFields("logger", Value.newBuilder().setStringValue( + String.valueOf(EventLogger.CLIENT)).build()) + .putFields("callId", Value.newBuilder().setStringValue(CALL_ID).build()) + .build(); + @Mock + private Logging mockLogging; + + @Test + @SuppressWarnings("unchecked") + public void verifyWrite() throws Exception { + GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, + CUSTOM_TAGS, Collections.emptySet()); + sink.write(LOG_PROTO); + + ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( + (Class) Collection.class); + verify(mockLogging, times(1)).write(logEntrySetCaptor.capture()); + for (Iterator it = logEntrySetCaptor.getValue().iterator(); it.hasNext(); ) { + LogEntry entry = it.next(); + assertThat(entry.getPayload().getData()).isEqualTo(EXPECTED_STRUCT_LOG_PROTO); + assertThat(entry.getLogName()).isEqualTo(EXPECTED_LOG_NAME); + } + verifyNoMoreInteractions(mockLogging); + } + + @Test + @SuppressWarnings("unchecked") + public void verifyWriteWithTags() { + GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, + CUSTOM_TAGS, Collections.emptySet()); + MonitoredResource expectedMonitoredResource = GcpLogSink.getResource(LOCATION_TAGS); + sink.write(LOG_PROTO); + + ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( + (Class) Collection.class); + verify(mockLogging, times(1)).write(logEntrySetCaptor.capture()); + System.out.println(logEntrySetCaptor.getValue()); + for (Iterator it = logEntrySetCaptor.getValue().iterator(); it.hasNext(); ) { + LogEntry entry = it.next(); + assertThat(entry.getResource()).isEqualTo(expectedMonitoredResource); + assertThat(entry.getLabels()).isEqualTo(CUSTOM_TAGS); + assertThat(entry.getPayload().getData()).isEqualTo(EXPECTED_STRUCT_LOG_PROTO); + assertThat(entry.getLogName()).isEqualTo(EXPECTED_LOG_NAME); + } + verifyNoMoreInteractions(mockLogging); + } + + @Test + @SuppressWarnings("unchecked") + public void emptyCustomTags_labelsNotSet() { + Map emptyCustomTags = null; + Map expectedEmptyLabels = new HashMap<>(); + GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, + emptyCustomTags, Collections.emptySet()); + sink.write(LOG_PROTO); + + ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( + (Class) Collection.class); + verify(mockLogging, times(1)).write(logEntrySetCaptor.capture()); + for (Iterator it = logEntrySetCaptor.getValue().iterator(); it.hasNext(); ) { + LogEntry entry = it.next(); + assertThat(entry.getLabels()).isEqualTo(expectedEmptyLabels); + assertThat(entry.getPayload().getData()).isEqualTo(EXPECTED_STRUCT_LOG_PROTO); + } + } + + @Test + @SuppressWarnings("unchecked") + public void emptyCustomTags_setSourceProject() { + Map emptyCustomTags = null; + String projectId = "PROJECT"; + Map expectedLabels = GcpLogSink.getCustomTags(emptyCustomTags, LOCATION_TAGS, + projectId); + GcpLogSink sink = new GcpLogSink(mockLogging, projectId, LOCATION_TAGS, + emptyCustomTags, Collections.emptySet()); + sink.write(LOG_PROTO); + + ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( + (Class) Collection.class); + verify(mockLogging, times(1)).write(logEntrySetCaptor.capture()); + for (Iterator it = logEntrySetCaptor.getValue().iterator(); it.hasNext(); ) { + LogEntry entry = it.next(); + assertThat(entry.getLabels()).isEqualTo(expectedLabels); + assertThat(entry.getPayload().getData()).isEqualTo(EXPECTED_STRUCT_LOG_PROTO); + } + } + + @Test + public void verifyClose() throws Exception { + GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, + CUSTOM_TAGS, Collections.emptySet()); + sink.write(LOG_PROTO); + verify(mockLogging, times(1)).write(anyIterable()); + sink.close(); + verify(mockLogging).close(); + verifyNoMoreInteractions(mockLogging); + } + + @Test + public void verifyExclude() throws Exception { + Sink mockSink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, + CUSTOM_TAGS, Collections.singleton("service")); + mockSink.write(LOG_PROTO); + verifyNoInteractions(mockLogging); + } +} diff --git a/googleapis/BUILD.bazel b/googleapis/BUILD.bazel new file mode 100644 index 00000000000..77b0bcd93b9 --- /dev/null +++ b/googleapis/BUILD.bazel @@ -0,0 +1,14 @@ +java_library( + name = "googleapis", + srcs = glob([ + "src/main/java/**/*.java", + ]), + visibility = ["//visibility:public"], + deps = [ + "//alts", + "//api", + "//core:internal", + "//xds", + "@com_google_guava_guava//jar", + ], +) diff --git a/googleapis/build.gradle b/googleapis/build.gradle new file mode 100644 index 00000000000..d829b1d28e9 --- /dev/null +++ b/googleapis/build.gradle @@ -0,0 +1,36 @@ +plugins { + id "java-library" + id "maven-publish" + + id "ru.vyarus.animalsniffer" +} + +description = 'gRPC: googleapis' + +dependencies { + api project(':grpc-api') + implementation project(':grpc-alts'), + project(':grpc-core'), + project(':grpc-xds'), + libraries.guava + testImplementation project(':grpc-core').sourceSets.test.output + + signature libraries.signature.java +} + +publishing { + publications { + maven(MavenPublication) { + pom { + withXml { + // Since internal APIs are used, pin the version. + asNode().dependencies.'*'.findAll() { dep -> + dep.artifactId.text() in ['grpc-alts', 'grpc-xds'] + }.each() { core -> + core.version*.value = "[" + core.version.text() + "]" + } + } + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java similarity index 62% rename from xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolverProvider.java rename to googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java index e7f9cb45ff8..349e1c94380 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2021 The gRPC Authors + * Copyright 2022 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,45 +14,39 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.googleapis; import io.grpc.Internal; import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; -import io.grpc.internal.GrpcUtil; import java.net.URI; /** - * A provider for {@link GoogleCloudToProdNameResolver}. + * A provider for {@link GoogleCloudToProdNameResolver}, with experimental scheme. */ @Internal -public final class GoogleCloudToProdNameResolverProvider extends NameResolverProvider { - - private static final String SCHEME = "google-c2p-experimental"; +public final class GoogleCloudToProdExperimentalNameResolverProvider extends NameResolverProvider { + private final GoogleCloudToProdNameResolverProvider delegate = + new GoogleCloudToProdNameResolverProvider("google-c2p-experimental"); @Override public NameResolver newNameResolver(URI targetUri, Args args) { - if (SCHEME.equals(targetUri.getScheme())) { - return new GoogleCloudToProdNameResolver( - targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR, - SharedXdsClientPoolProvider.getDefaultProvider()); - } - return null; + return delegate.newNameResolver(targetUri, args); } @Override public String getDefaultScheme() { - return SCHEME; + return delegate.getDefaultScheme(); } @Override protected boolean isAvailable() { - return true; + return delegate.isAvailable(); } @Override protected int priority() { - return 4; + return delegate.priority(); } } diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java similarity index 76% rename from xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java rename to googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java index 3ec0434a22a..1db4825ccb8 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java @@ -14,13 +14,14 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.googleapis; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Charsets; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; @@ -32,7 +33,6 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder.Resource; -import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; import java.io.IOException; import java.io.InputStreamReader; import java.io.Reader; @@ -40,6 +40,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; @@ -54,6 +55,7 @@ final class GoogleCloudToProdNameResolver extends NameResolver { @VisibleForTesting static final String METADATA_URL_SUPPORT_IPV6 = "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ipv6s"; + static final String C2P_AUTHORITY = "traffic-director-c2p.xds.googleapis.com"; @VisibleForTesting static boolean isOnGcp = InternalCheckGcpEnvironment.isOnGcp(); @VisibleForTesting @@ -62,6 +64,10 @@ final class GoogleCloudToProdNameResolver extends NameResolver { || System.getProperty("io.grpc.xds.bootstrap") != null || System.getenv("GRPC_XDS_BOOTSTRAP_CONFIG") != null || System.getProperty("io.grpc.xds.bootstrapConfig") != null; + @VisibleForTesting + static boolean enableFederation = + !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")) + && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")); private static final String serverUriOverride = System.getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI"); @@ -70,13 +76,16 @@ final class GoogleCloudToProdNameResolver extends NameResolver { private final String authority; private final SynchronizationContext syncContext; private final Resource executorResource; - private final XdsClientPoolFactory xdsClientPoolFactory; + private final BootstrapSetter bootstrapSetter; private final NameResolver delegate; private final Random rand; private final boolean usingExecutorResource; // It's not possible to use both PSM and DirectPath C2P in the same application. // Delegate to DNS if user-provided bootstrap is found. - private final String schemeOverride = !isOnGcp || xdsBootstrapProvided ? "dns" : "xds"; + private final String schemeOverride = + !isOnGcp + || (xdsBootstrapProvided && !enableFederation) + ? "dns" : "xds"; private Executor executor; private Listener2 listener; private boolean succeeded; @@ -84,17 +93,16 @@ final class GoogleCloudToProdNameResolver extends NameResolver { private boolean shutdown; GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - XdsClientPoolFactory xdsClientPoolFactory) { - this(targetUri, args, executorResource, new Random(), xdsClientPoolFactory, + BootstrapSetter bootstrapSetter) { + this(targetUri, args, executorResource, new Random(), bootstrapSetter, NameResolverRegistry.getDefaultRegistry().asFactory()); } @VisibleForTesting GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - Random rand, XdsClientPoolFactory xdsClientPoolFactory, - NameResolver.Factory nameResolverFactory) { + Random rand, BootstrapSetter bootstrapSetter, NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); - this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.bootstrapSetter = checkNotNull(bootstrapSetter, "bootstrapSetter"); this.rand = checkNotNull(rand, "rand"); String targetPath = checkNotNull(checkNotNull(targetUri, "targetUri").getPath(), "targetPath"); Preconditions.checkArgument( @@ -104,8 +112,12 @@ final class GoogleCloudToProdNameResolver extends NameResolver { targetUri); authority = GrpcUtil.checkAuthority(targetPath.substring(1)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); + targetUri = overrideUriScheme(targetUri, schemeOverride); + if (schemeOverride.equals("xds") && enableFederation) { + targetUri = overrideUriAuthority(targetUri, C2P_AUTHORITY); + } delegate = checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver( - overrideUriScheme(targetUri, schemeOverride), args); + targetUri, args); executor = args.getOffloadExecutor(); usingExecutorResource = executor == null; } @@ -144,22 +156,28 @@ private void resolve() { class Resolve implements Runnable { @Override public void run() { - String zone; - boolean supportIpv6; ImmutableMap rawBootstrap = null; try { - zone = queryZoneMetadata(METADATA_URL_ZONE); - supportIpv6 = queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6); - rawBootstrap = generateBootstrap(zone, supportIpv6); + // User provided bootstrap configs are only supported with federation. If federation is + // not enabled or there is no user provided config, we set a custom bootstrap override. + // Otherwise, we don't set the override, which will allow a user provided bootstrap config + // to take effect. + if (!enableFederation || !xdsBootstrapProvided) { + rawBootstrap = generateBootstrap(queryZoneMetadata(METADATA_URL_ZONE), + queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6)); + } } catch (IOException e) { - listener.onError(Status.INTERNAL.withDescription("Unable to get metadata").withCause(e)); + listener.onError( + Status.INTERNAL.withDescription("Unable to get metadata").withCause(e)); } finally { final ImmutableMap finalRawBootstrap = rawBootstrap; syncContext.execute(new Runnable() { @Override public void run() { - if (!shutdown && finalRawBootstrap != null) { - xdsClientPoolFactory.setBootstrapOverride(finalRawBootstrap); + if (!shutdown) { + if (finalRawBootstrap != null) { + bootstrapSetter.setBootstrap(finalRawBootstrap); + } delegate.start(listener); succeeded = true; } @@ -192,9 +210,14 @@ public void run() { serverBuilder.put("channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default"))); serverBuilder.put("server_features", ImmutableList.of("xds_v3")); + ImmutableMap.Builder authoritiesBuilder = ImmutableMap.builder(); + authoritiesBuilder.put( + C2P_AUTHORITY, + ImmutableMap.of("xds_servers", ImmutableList.of(serverBuilder.buildOrThrow()))); return ImmutableMap.of( - "node", nodeBuilder.build(), - "xds_servers", ImmutableList.of(serverBuilder.build())); + "node", nodeBuilder.buildOrThrow(), + "xds_servers", ImmutableList.of(serverBuilder.buildOrThrow()), + "authorities", authoritiesBuilder.buildOrThrow()); } @Override @@ -267,6 +290,16 @@ private static URI overrideUriScheme(URI uri, String scheme) { return res; } + private static URI overrideUriAuthority(URI uri, String authority) { + URI res; + try { + res = new URI(uri.getScheme(), authority, uri.getPath(), uri.getQuery(), uri.getFragment()); + } catch (URISyntaxException ex) { + throw new IllegalArgumentException("Invalid authority: " + authority, ex); + } + return res; + } + private enum HttpConnectionFactory implements HttpConnectionProvider { INSTANCE; @@ -284,4 +317,8 @@ public HttpURLConnection createConnection(String url) throws IOException { interface HttpConnectionProvider { HttpURLConnection createConnection(String url) throws IOException; } + + public interface BootstrapSetter { + void setBootstrap(Map bootstrap); + } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java new file mode 100644 index 00000000000..ce833d5c4e0 --- /dev/null +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -0,0 +1,88 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.googleapis; + +import com.google.common.base.Preconditions; +import io.grpc.Internal; +import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; +import io.grpc.NameResolverProvider; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.InternalSharedXdsClientPoolProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; + +/** + * A provider for {@link GoogleCloudToProdNameResolver}. + */ +@Internal +public final class GoogleCloudToProdNameResolverProvider extends NameResolverProvider { + + private static final String SCHEME = "google-c2p"; + + private final String scheme; + + public GoogleCloudToProdNameResolverProvider() { + this(SCHEME); + } + + GoogleCloudToProdNameResolverProvider(String scheme) { + this.scheme = Preconditions.checkNotNull(scheme, "scheme"); + } + + @Override + public NameResolver newNameResolver(URI targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + return new GoogleCloudToProdNameResolver( + targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR, + new SharedXdsClientPoolProviderBootstrapSetter()); + } + return null; + } + + @Override + public String getDefaultScheme() { + return scheme; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 4; + } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + + private static final class SharedXdsClientPoolProviderBootstrapSetter + implements GoogleCloudToProdNameResolver.BootstrapSetter { + @Override + public void setBootstrap(Map bootstrap) { + InternalSharedXdsClientPoolProvider.setDefaultProviderBootstrapOverride(bootstrap); + } + } +} diff --git a/googleapis/src/main/resources/META-INF/services/io.grpc.NameResolverProvider b/googleapis/src/main/resources/META-INF/services/io.grpc.NameResolverProvider new file mode 100644 index 00000000000..1d08ff2bb0b --- /dev/null +++ b/googleapis/src/main/resources/META-INF/services/io.grpc.NameResolverProvider @@ -0,0 +1,2 @@ +io.grpc.googleapis.GoogleCloudToProdExperimentalNameResolverProvider +io.grpc.googleapis.GoogleCloudToProdNameResolverProvider diff --git a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverProviderTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java similarity index 79% rename from xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverProviderTest.java rename to googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java index cc6621028c8..447b102c8c7 100644 --- a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverProviderTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.googleapis; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; @@ -71,9 +71,28 @@ NameResolverProvider.class, getClass().getClassLoader())) { fail("GoogleCloudToProdNameResolverProvider not registered"); } + @Test + public void experimentalProvided() { + for (NameResolverProvider current + : InternalServiceProviders.getCandidatesViaServiceLoader( + NameResolverProvider.class, getClass().getClassLoader())) { + if (current instanceof GoogleCloudToProdExperimentalNameResolverProvider) { + return; + } + } + fail("GoogleCloudToProdExperimentalNameResolverProvider not registered"); + } + @Test public void newNameResolver() { assertThat(provider + .newNameResolver(URI.create("google-c2p:///foo.googleapis.com"), args)) + .isInstanceOf(GoogleCloudToProdNameResolver.class); + } + + @Test + public void experimentalNewNameResolver() { + assertThat(new GoogleCloudToProdExperimentalNameResolverProvider() .newNameResolver(URI.create("google-c2p-experimental:///foo.googleapis.com"), args)) .isInstanceOf(GoogleCloudToProdNameResolver.class); } diff --git a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java similarity index 74% rename from xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java rename to googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java index 7c777b84bf7..52174c19a3b 100644 --- a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.googleapis; import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.mock; @@ -33,12 +33,10 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; +import io.grpc.googleapis.GoogleCloudToProdNameResolver.HttpConnectionProvider; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourceHolder.Resource; -import io.grpc.xds.GoogleCloudToProdNameResolver.HttpConnectionProvider; -import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.HttpURLConnection; @@ -50,7 +48,6 @@ import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -69,7 +66,7 @@ public class GoogleCloudToProdNameResolverTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final URI TARGET_URI = URI.create("google-c2p-experimental:///googleapis.com"); + private static final URI TARGET_URI = URI.create("google-c2p:///googleapis.com"); private static final String ZONE = "us-central1-a"; private static final int DEFAULT_PORT = 887; @@ -88,7 +85,7 @@ public void uncaughtException(Thread t, Throwable e) { .setChannelLogger(mock(ChannelLogger.class)) .build(); private final FakeClock fakeExecutor = new FakeClock(); - private final FakeXdsClientPoolFactory fakeXdsClientPoolFactory = new FakeXdsClientPoolFactory(); + private final FakeBootstrapSetter fakeBootstrapSetter = new FakeBootstrapSetter(); private final Resource fakeExecutorResource = new Resource() { @Override public Executor create() { @@ -144,7 +141,7 @@ public HttpURLConnection createConnection(String url) throws IOException { } }; resolver = new GoogleCloudToProdNameResolver( - TARGET_URI, args, fakeExecutorResource, random, fakeXdsClientPoolFactory, + TARGET_URI, args, fakeExecutorResource, random, fakeBootstrapSetter, nsRegistry.asFactory()); resolver.setHttpConnectionProvider(httpConnections); } @@ -178,7 +175,7 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { fakeExecutor.runDueTasks(); assertThat(delegatedResolver.keySet()).containsExactly("xds"); verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeXdsClientPoolFactory.bootstrapRef.get(); + Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); Map node = (Map) bootstrap.get("node"); assertThat(node).containsExactly( "id", "C2P-991614323", @@ -190,6 +187,55 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { "server_uri", "directpath-pa.googleapis.com", "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), "server_features", ImmutableList.of("xds_v3")); + Map authorities = (Map) bootstrap.get("authorities"); + assertThat(authorities).containsExactly( + "traffic-director-c2p.xds.googleapis.com", + ImmutableMap.of("xds_servers", ImmutableList.of(server))); + } + + @SuppressWarnings("unchecked") + @Test + public void onGcpAndNoProvidedBootstrapAndFederationEnabledDelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = true; + GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; + GoogleCloudToProdNameResolver.enableFederation = true; + createResolver(); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); + // check bootstrap + Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); + Map node = (Map) bootstrap.get("node"); + assertThat(node).containsExactly( + "id", "C2P-991614323", + "locality", ImmutableMap.of("zone", ZONE), + "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); + Map server = Iterables.getOnlyElement( + (List>) bootstrap.get("xds_servers")); + assertThat(server).containsExactly( + "server_uri", "directpath-pa.googleapis.com", + "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), + "server_features", ImmutableList.of("xds_v3")); + Map authorities = (Map) bootstrap.get("authorities"); + assertThat(authorities).containsExactly( + "traffic-director-c2p.xds.googleapis.com", + ImmutableMap.of("xds_servers", ImmutableList.of(server))); + } + + @SuppressWarnings("unchecked") + @Test + public void onGcpAndProvidedBootstrapAndFederationEnabledDontDelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = true; + GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; + GoogleCloudToProdNameResolver.enableFederation = true; + createResolver(); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); + // Bootstrapper should not have been set, since there was no user provided config. + assertThat(fakeBootstrapSetter.bootstrapRef.get()).isNull(); } @Test @@ -246,23 +292,13 @@ public String getDefaultScheme() { } } - private static final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { + private static final class FakeBootstrapSetter + implements GoogleCloudToProdNameResolver.BootstrapSetter { private final AtomicReference> bootstrapRef = new AtomicReference<>(); @Override - public void setBootstrapOverride(Map bootstrap) { + public void setBootstrap(Map bootstrap) { bootstrapRef.set(bootstrap); } - - @Override - @Nullable - public ObjectPool get() { - throw new UnsupportedOperationException("Should not be called"); - } - - @Override - public ObjectPool getOrCreate() { - throw new UnsupportedOperationException("Should not be called"); - } } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml new file mode 100644 index 00000000000..6e377bb4c3a --- /dev/null +++ b/gradle/libs.versions.toml @@ -0,0 +1,72 @@ +[versions] +# Breaks on upgrade: https://github.com/mojohaus/animal-sniffer/issues/131 +animalsniffer = "1.18" +autovalue = "1.9" +checkstyle = "8.28" +googleauth = "1.4.0" +guava = "31.1-android" +netty = '4.1.79.Final' +nettytcnative = '2.0.54.Final' +opencensus = "0.31.0" +protobuf = "3.21.7" + +[libraries] +android-annotations = "com.google.android:annotations:4.1.1.4" +androidx-annotation = "androidx.annotation:annotation:1.4.0" +androidx-core = "androidx.core:core:1.3.0" +androidx-lifecycle-common = "androidx.lifecycle:lifecycle-common:2.3.0" +androidx-lifecycle-service = "androidx.lifecycle:lifecycle-service:2.3.0" +androidx-test-core = "androidx.test:core:1.4.0" +androidx-test-ext-junit = "androidx.test.ext:junit:1.1.3" +androidx-test-rules = "androidx.test:rules:1.4.0" +animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.21" +auto-value = { module = "com.google.auto.value:auto-value", version.ref = "autovalue" } +auto-value-annotations = { module = "com.google.auto.value:auto-value-annotations", version.ref = "autovalue" } +commons-math3 = "org.apache.commons:commons-math3:3.6.1" +conscrypt = "org.conscrypt:conscrypt-openjdk-uber:2.5.2" +cronet-api = "org.chromium.net:cronet-api:92.4515.131" +cronet-embedded = "org.chromium.net:cronet-embedded:102.5005.125" +errorprone-annotations = "com.google.errorprone:error_prone_annotations:2.14.0" +errorprone-core = "com.google.errorprone:error_prone_core:2.10.0" +google-api-protos = "com.google.api.grpc:proto-google-common-protos:2.9.0" +google-auth-credentials = { module = "com.google.auth:google-auth-library-credentials", version.ref = "googleauth" } +google-auth-oauth2Http = { module = "com.google.auth:google-auth-library-oauth2-http", version.ref = "googleauth" } +gson = "com.google.code.gson:gson:2.9.0" +guava = { module = "com.google.guava:guava", version.ref = "guava" } +guava-betaChecker = "com.google.guava:guava-beta-checker:1.0" +guava-testlib = { module = "com.google.guava:guava-testlib", version.ref = "guava" } +hdrhistogram = "org.hdrhistogram:HdrHistogram:2.1.12" +javax-annotation = "org.apache.tomcat:annotations-api:6.0.53" +jetty-alpn-agent = "org.mortbay.jetty.alpn:jetty-alpn-agent:2.0.10" +jsr305 = "com.google.code.findbugs:jsr305:3.0.2" +junit = "junit:junit:4.13.2" +mockito-android = "org.mockito:mockito-android:3.8.0" +mockito-core = "org.mockito:mockito-core:3.3.3" +netty-codec-http2 = { module = "io.netty:netty-codec-http2", version.ref = "netty" } +netty-handler-proxy = { module = "io.netty:netty-handler-proxy", version.ref = "netty" } +# Keep the following references of tcnative version in sync whenever it's updated: +# SECURITY.md (multiple occurrences) +# examples/example-tls/build.gradle +# examples/example-tls/pom.xml +netty-tcnative = { module = "io.netty:netty-tcnative-boringssl-static", version.ref = "nettytcnative" } +netty-tcnative-classes = { module = "io.netty:netty-tcnative-classes", version.ref = "nettytcnative" } +netty-transport-epoll = { module = "io.netty:netty-transport-native-epoll", version.ref = "netty" } +netty-unix-common = { module = "io.netty:netty-transport-native-unix-common", version.ref = "netty" } +okhttp = "com.squareup.okhttp:okhttp:2.7.5" +okio = "com.squareup.okio:okio:1.17.5" +opencensus-api = { module = "io.opencensus:opencensus-api", version.ref = "opencensus" } +opencensus-contrib-grpc-metrics = { module = "io.opencensus:opencensus-contrib-grpc-metrics", version.ref = "opencensus" } +opencensus-exporter-stats-stackdriver = { module = "io.opencensus:opencensus-exporter-stats-stackdriver", version.ref = "opencensus" } +opencensus-exporter-trace-stackdriver = { module = "io.opencensus:opencensus-exporter-trace-stackdriver", version.ref = "opencensus" } +opencensus-impl = { module = "io.opencensus:opencensus-impl", version.ref = "opencensus" } +opencensus-proto = "io.opencensus:opencensus-proto:0.2.0" +perfmark-api = "io.perfmark:perfmark-api:0.25.0" +protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } +protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", version.ref = "protobuf" } +protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" } +protobuf-protoc = { module = "com.google.protobuf:protoc", version.ref = "protobuf" } +re2j = "com.google.re2j:re2j:1.6" +robolectric = "org.robolectric:robolectric:4.8.1" +signature-android = "net.sf.androidscents.signature:android-api-level-19:4.4.2_r4" +signature-java = "org.codehaus.mojo.signature:java18:1.0" +truth = "com.google.truth:truth:1.0.1" diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 2e6e5897b52..070cb702f09 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index a18795d0485..0ca4d695bb4 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -5,6 +5,9 @@ java_library( srcs = glob([ "src/main/java/io/grpc/grpclb/*.java", ]), + resources = glob([ + "src/main/resources/**", + ]), visibility = ["//visibility:public"], deps = [ ":load_balancer_java_grpc", diff --git a/grpclb/build.gradle b/grpclb/build.gradle index 58ff2f412d1..6fd64c84299 100644 --- a/grpclb/build.gradle +++ b/grpclb/build.gradle @@ -4,6 +4,7 @@ plugins { id "com.google.protobuf" id "me.champeau.gradle.japicmp" + id "ru.vyarus.animalsniffer" } description = "gRPC: GRPCLB LoadBalancer plugin" @@ -14,22 +15,24 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), - libraries.protobuf, - libraries.protobuf_util, + libraries.protobuf.java, + libraries.protobuf.java.util, libraries.guava - runtimeOnly libraries.errorprone - compileOnly libraries.javax_annotation + runtimeOnly libraries.errorprone.annotations + compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-core').sourceSets.test.output + + signature libraries.signature.java } configureProtoCompilation() -javadoc { +tasks.named("javadoc").configure { exclude 'io/grpc/grpclb/Internal*' } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { classDirectories.from = sourceSets.main.output.collect { fileTree(dir: it, exclude: [ diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java index 60f22a2e0e8..4395c8415dc 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java @@ -16,6 +16,7 @@ package io.grpc.grpclb; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; @@ -28,24 +29,31 @@ final class GrpclbConfig { private final Mode mode; @Nullable private final String serviceName; + private final long fallbackTimeoutMs; - private GrpclbConfig(Mode mode, @Nullable String serviceName) { + private GrpclbConfig(Mode mode, @Nullable String serviceName, long fallbackTimeoutMs) { this.mode = checkNotNull(mode, "mode"); this.serviceName = serviceName; + this.fallbackTimeoutMs = fallbackTimeoutMs; } static GrpclbConfig create(Mode mode) { - return create(mode, null); + return create(mode, null, GrpclbState.FALLBACK_TIMEOUT_MS); } - static GrpclbConfig create(Mode mode, @Nullable String serviceName) { - return new GrpclbConfig(mode, serviceName); + static GrpclbConfig create(Mode mode, @Nullable String serviceName, long fallbackTimeoutMs) { + checkArgument(fallbackTimeoutMs > 0, "Invalid timeout (%s)", fallbackTimeoutMs); + return new GrpclbConfig(mode, serviceName, fallbackTimeoutMs); } Mode getMode() { return mode; } + long getFallbackTimeoutMs() { + return fallbackTimeoutMs; + } + /** * If specified, it overrides the name of the sevice name to be sent to the balancer. if not, the * target to be sent to the balancer will continue to be obtained from the target URI passed @@ -65,12 +73,14 @@ public boolean equals(Object o) { return false; } GrpclbConfig that = (GrpclbConfig) o; - return mode == that.mode && Objects.equal(serviceName, that.serviceName); + return mode == that.mode + && Objects.equal(serviceName, that.serviceName) + && fallbackTimeoutMs == that.fallbackTimeoutMs; } @Override public int hashCode() { - return Objects.hashCode(mode, serviceName); + return Objects.hashCode(mode, serviceName, fallbackTimeoutMs); } @Override @@ -78,6 +88,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("mode", mode) .add("serviceName", serviceName) + .add("fallbackTimeoutMs", fallbackTimeoutMs) .toString(); } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java index bda6473b0cc..be09b1ce306 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConstants.java @@ -46,7 +46,7 @@ public final class GrpclbConstants { * Attribute key for gRPC LB server addresses. */ public static final Attributes.Key> ATTR_LB_ADDRS = - Attributes.Key.create("io.grpc.grpclb.lbAddrs"); + Attributes.Key.create("io.grpc.grpclb.GrpclbConstants.ATTR_LB_ADDRS"); /** * The naming authority of a gRPC LB server address. It is an address-group-level attribute, @@ -54,7 +54,7 @@ public final class GrpclbConstants { */ @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_LB_ADDR_AUTHORITY = - Attributes.Key.create("io.grpc.grpclb.lbAddrAuthority"); + Attributes.Key.create("io.grpc.grpclb.GrpclbConstants.ATTR_LB_ADDR_AUTHORITY"); /** * Whether this EquivalentAddressGroup was provided by a GRPCLB server. It would be rare for this @@ -62,7 +62,7 @@ public final class GrpclbConstants { */ @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_LB_PROVIDED_BACKEND = - Attributes.Key.create("io.grpc.grpclb.lbProvidedBackend"); + Attributes.Key.create("io.grpc.grpclb.GrpclbConstants.ATTR_LB_PROVIDED_BACKEND"); private GrpclbConstants() { } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 65293d24511..14b06a3b57c 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -76,7 +76,7 @@ class GrpclbLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { Attributes attributes = resolvedAddresses.getAttributes(); List newLbAddresses = attributes.get(GrpclbConstants.ATTR_LB_ADDRS); if (newLbAddresses == null) { @@ -85,7 +85,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (newLbAddresses.isEmpty() && resolvedAddresses.getAddresses().isEmpty()) { handleNameResolutionError( Status.UNAVAILABLE.withDescription("No backend or balancer addresses found")); - return; + return false; } List overrideAuthorityLbAddresses = new ArrayList<>(newLbAddresses.size()); @@ -114,6 +114,8 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { } grpclbState.handleAddresses(Collections.unmodifiableList(overrideAuthorityLbAddresses), newBackendServers); + + return true; } @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java index fa9b6963f33..abb3be77407 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java @@ -89,6 +89,13 @@ ConfigOrError parseLoadBalancingConfigPolicyInternal( } String serviceName = JsonUtil.getString(rawLoadBalancingPolicyConfig, "serviceName"); List rawChildPolicies = JsonUtil.getList(rawLoadBalancingPolicyConfig, "childPolicy"); + Long initialFallbackTimeoutNs = + JsonUtil.getStringAsDuration(rawLoadBalancingPolicyConfig, "initialFallbackTimeout"); + long timeoutMs = GrpclbState.FALLBACK_TIMEOUT_MS; + if (initialFallbackTimeoutNs != null) { + timeoutMs = initialFallbackTimeoutNs / 1000000; + } + List childPolicies = null; if (rawChildPolicies != null) { childPolicies = @@ -97,7 +104,8 @@ ConfigOrError parseLoadBalancingConfigPolicyInternal( } if (childPolicies == null || childPolicies.isEmpty()) { - return ConfigOrError.fromConfig(GrpclbConfig.create(DEFAULT_MODE, serviceName)); + return ConfigOrError.fromConfig( + GrpclbConfig.create(DEFAULT_MODE, serviceName, timeoutMs)); } List policiesTried = new ArrayList<>(); @@ -105,16 +113,18 @@ ConfigOrError parseLoadBalancingConfigPolicyInternal( String childPolicyName = childPolicy.getPolicyName(); switch (childPolicyName) { case "round_robin": - return ConfigOrError.fromConfig(GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName)); + return ConfigOrError.fromConfig( + GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName, timeoutMs)); case "pick_first": - return ConfigOrError.fromConfig(GrpclbConfig.create(Mode.PICK_FIRST, serviceName)); + return ConfigOrError.fromConfig( + GrpclbConfig.create(Mode.PICK_FIRST, serviceName, timeoutMs)); default: policiesTried.add(childPolicyName); } } return ConfigOrError.fromError( Status - .INVALID_ARGUMENT + .UNAVAILABLE .withDescription( "None of " + policiesTried + " specified child policies are available.")); } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 1eebaa63a8e..49b74645ec8 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -137,6 +137,7 @@ enum Mode { } private final String serviceName; + private final long fallbackTimeoutMs; private final Helper helper; private final Context context; private final SynchronizationContext syncContext; @@ -220,6 +221,7 @@ public void onSubchannelState( } else { this.serviceName = checkNotNull(helper.getAuthority(), "helper returns null authority"); } + this.fallbackTimeoutMs = config.getFallbackTimeoutMs(); this.logger = checkNotNull(helper.getChannelLogger(), "logger"); logger.log(ChannelLogLevel.INFO, "[grpclb-<{0}>] Created", serviceName); } @@ -260,7 +262,7 @@ void handleAddresses( List newBackendServers) { logger.log( ChannelLogLevel.DEBUG, - "[grpclb-<{0}>] Resolved addresses: lb addresses {0}, backends: {1}", + "[grpclb-<{0}>] Resolved addresses: lb addresses {1}, backends: {2}", serviceName, newLbAddressGroups, newBackendServers); @@ -290,9 +292,12 @@ void handleAddresses( // Start the fallback timer if it's never started and we are not already using fallback // backends. if (fallbackTimer == null && !usingFallbackBackends) { - fallbackTimer = syncContext.schedule( - new FallbackModeTask(BALANCER_TIMEOUT_STATUS), FALLBACK_TIMEOUT_MS, - TimeUnit.MILLISECONDS, timerService); + fallbackTimer = + syncContext.schedule( + new FallbackModeTask(BALANCER_TIMEOUT_STATUS), + fallbackTimeoutMs, + TimeUnit.MILLISECONDS, + timerService); } } if (usingFallbackBackends) { diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index bc25f28f94c..da5b7c3353e 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -22,7 +22,11 @@ import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.GrpcUtil; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. @@ -85,5 +89,10 @@ public int priority() { // Must be higher than DnsNameResolverProvider#priority. return 6; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerProviderTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerProviderTest.java index dda9700f64a..c291f2da9ad 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerProviderTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerProviderTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; import io.grpc.grpclb.GrpclbState.Mode; import io.grpc.internal.JsonParser; import java.util.Map; @@ -41,6 +42,7 @@ public void retrieveModeFromLbConfig_pickFirst() throws Exception { GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.PICK_FIRST); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test @@ -54,6 +56,54 @@ public void retrieveModeFromLbConfig_roundRobin() throws Exception { GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.ROUND_ROBIN); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); + } + + @Test + public void setTimeoutToLbConfig() throws Exception { + String lbConfig = + "{\"initialFallbackTimeout\" : \"123s\", \"childPolicy\" : [{\"pick_first\" : {}}," + + " {\"round_robin\" : {}}]}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + + assertThat(configOrError.getConfig()).isNotNull(); + GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); + assertThat(config.getMode()).isEqualTo(Mode.PICK_FIRST); + assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(123000); + } + + @Test + public void setInvalidTimeoutToLbConfig() throws Exception { + String lbConfig = + "{\"initialFallbackTimeout\" : \"-1s\", \"childPolicy\" : [{\"pick_first\" : {}}," + + " {\"round_robin\" : {}}]}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + + assertThat(configOrError.getConfig()).isNull(); + assertThat(configOrError.getError()).isNotNull(); + Status errorStatus = configOrError.getError(); + assertThat(errorStatus.getCause()).hasMessageThat().isEqualTo("Invalid timeout (-1000)"); + } + + @Test + public void setInvalidTimeoutDurationProtoToLbConfig() throws Exception { + String lbConfig = + "{\"initialFallbackTimeout\" : \"1000\", \"childPolicy\" : [{\"pick_first\" : {}}," + + " {\"round_robin\" : {}}]}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + + assertThat(configOrError.getError()).isNotNull(); + Status errorStatus = configOrError.getError(); + assertThat(errorStatus.getCause()) + .hasMessageThat() + .isEqualTo("java.text.ParseException: Invalid duration string: 1000"); } @Test @@ -65,6 +115,7 @@ public void retrieveModeFromLbConfig_nullConfigUseRoundRobin() throws Exception GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.ROUND_ROBIN); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test @@ -78,6 +129,7 @@ public void retrieveModeFromLbConfig_emptyConfigUseRoundRobin() throws Exception GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.ROUND_ROBIN); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test @@ -91,6 +143,7 @@ public void retrieveModeFromLbConfig_emptyChildPolicyUseRoundRobin() throws Exce GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.ROUND_ROBIN); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test @@ -117,6 +170,7 @@ public void retrieveModeFromLbConfig_skipUnsupportedChildPolicy() throws Excepti GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.PICK_FIRST); assertThat(config.getServiceName()).isNull(); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test @@ -131,6 +185,7 @@ public void retrieveModeFromLbConfig_skipUnsupportedChildPolicyWithTarget() thro GrpclbConfig config = (GrpclbConfig) configOrError.getConfig(); assertThat(config.getMode()).isEqualTo(Mode.PICK_FIRST); assertThat(config.getServiceName()).isEqualTo("foo.google.com"); + assertThat(config.getFallbackTimeoutMs()).isEqualTo(GrpclbState.FALLBACK_TIMEOUT_MS); } @Test diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 293c0aa0b82..66fd850802c 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -1179,24 +1179,34 @@ public void roundRobinMode_subchannelStayTransientFailureUntilReady() { @Test public void grpclbFallback_initialTimeout_serverListReceivedBeforeTimerExpires() { - subtestGrpclbFallbackInitialTimeout(false); + subtestGrpclbFallbackTimeout(false, GrpclbState.FALLBACK_TIMEOUT_MS); } @Test public void grpclbFallback_initialTimeout_timerExpires() { - subtestGrpclbFallbackInitialTimeout(true); + subtestGrpclbFallbackTimeout(true, GrpclbState.FALLBACK_TIMEOUT_MS); + } + + @Test + public void grpclbFallback_timeout_serverListReceivedBeforeTimerExpires() { + subtestGrpclbFallbackTimeout(false, 12345); + } + + @Test + public void grpclbFallback_timeout_timerExpires() { + subtestGrpclbFallbackTimeout(true, 12345); } // Fallback or not within the period of the initial timeout. - private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { + private void subtestGrpclbFallbackTimeout(boolean timerExpires, long timeout) { long loadReportIntervalMillis = 1983; InOrder inOrder = inOrder(helper, subchannelPool); // Create balancer and backend addresses List backendList = createResolvedBackendAddresses(2); List grpclbBalancerList = createResolvedBalancerAddresses(1); - deliverResolvedAddresses(backendList, grpclbBalancerList); - + deliverResolvedAddresses( + backendList, grpclbBalancerList, GrpclbConfig.create(Mode.ROUND_ROBIN, null, timeout)); inOrder.verify(helper).createOobChannel(eq(xattr(grpclbBalancerList)), eq(lbAuthority(0) + NO_USE_AUTHORITY_SUFFIX)); @@ -1220,7 +1230,7 @@ private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks(FALLBACK_MODE_TASK_FILTER)); - fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS - 1, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(timeout - 1, TimeUnit.MILLISECONDS); assertEquals(1, fakeClock.numPendingTasks(FALLBACK_MODE_TASK_FILTER)); ////////////////////////////////// @@ -1246,7 +1256,10 @@ private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { // Name resolver sends new resolution results without any backend addr ////////////////////////////////////////////////////////////////////// grpclbBalancerList = createResolvedBalancerAddresses(2); - deliverResolvedAddresses(Collections.emptyList(),grpclbBalancerList); + deliverResolvedAddresses( + Collections.emptyList(), + grpclbBalancerList, + GrpclbConfig.create(Mode.ROUND_ROBIN, null, timeout)); // New addresses are updated to the OobChannel inOrder.verify(helper).updateOobChannelAddresses( @@ -1276,7 +1289,8 @@ private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { subchannelPool.clear(); backendList = createResolvedBackendAddresses(2); grpclbBalancerList = createResolvedBalancerAddresses(1); - deliverResolvedAddresses(backendList, grpclbBalancerList); + deliverResolvedAddresses( + backendList, grpclbBalancerList, GrpclbConfig.create(Mode.ROUND_ROBIN, null, timeout)); // New LB address is updated to the OobChannel inOrder.verify(helper).updateOobChannelAddresses( @@ -1326,7 +1340,8 @@ private void subtestGrpclbFallbackInitialTimeout(boolean timerExpires) { /////////////////////////////////////////////////////////////// backendList = createResolvedBackendAddresses(1); grpclbBalancerList = createResolvedBalancerAddresses(1); - deliverResolvedAddresses(backendList, grpclbBalancerList); + deliverResolvedAddresses( + backendList, grpclbBalancerList, GrpclbConfig.create(Mode.ROUND_ROBIN, null, timeout)); // Will not affect the round robin list at all inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); @@ -2142,16 +2157,23 @@ private void subtestShutdownWithoutSubchannel(GrpclbConfig grpclbConfig) { } @Test - public void pickFirstMode_fallback() throws Exception { + public void pickFirstMode_defaultTimeout_fallback() throws Exception { + pickFirstModeFallback(GrpclbState.FALLBACK_TIMEOUT_MS); + } + + @Test + public void pickFirstMode_serviceConfigTimeout_fallback() throws Exception { + pickFirstModeFallback(12345); + } + + private void pickFirstModeFallback(long timeout) throws Exception { InOrder inOrder = inOrder(helper); // Name resolver returns balancer and backend addresses List backendList = createResolvedBackendAddresses(2); List grpclbBalancerList = createResolvedBalancerAddresses(1); deliverResolvedAddresses( - backendList, - grpclbBalancerList, - GrpclbConfig.create(Mode.PICK_FIRST)); + backendList, grpclbBalancerList, GrpclbConfig.create(Mode.PICK_FIRST, null, timeout)); // Attempted to connect to balancer assertEquals(1, fakeOobChannels.size()); @@ -2160,7 +2182,7 @@ public void pickFirstMode_fallback() throws Exception { assertEquals(1, lbRequestObservers.size()); // Fallback timer expires with no response - fakeClock.forwardTime(GrpclbState.FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(timeout, TimeUnit.MILLISECONDS); // Entering fallback mode inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); @@ -2401,7 +2423,7 @@ public void switchServiceName() throws Exception { deliverResolvedAddresses( Collections.emptyList(), grpclbBalancerList, - GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName)); + GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName, GrpclbState.FALLBACK_TIMEOUT_MS)); assertEquals(1, fakeOobChannels.size()); ManagedChannel oobChannel = fakeOobChannels.poll(); @@ -2443,7 +2465,7 @@ public void switchServiceName() throws Exception { deliverResolvedAddresses( Collections.emptyList(), newGrpclbResolutionList, - GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName)); + GrpclbConfig.create(Mode.ROUND_ROBIN, serviceName, GrpclbState.FALLBACK_TIMEOUT_MS)); // GrpclbState will be shutdown, and a new one will be created assertThat(oobChannel.isShutdown()).isTrue(); @@ -2713,7 +2735,7 @@ private void deliverResolvedAddresses( syncContext.execute(new Runnable() { @Override public void run() { - balancer.handleResolvedAddresses( + balancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(backendAddrs) .setAttributes(attrs) diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 5a13c616143..f8c5561891d 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -3,17 +3,20 @@ plugins { id "java" id "maven-publish" + id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "ru.vyarus.animalsniffer" } description = "gRPC: Integration Testing" -startScripts.enabled = false configurations { alpnagent } +evaluationDependsOn(project(':grpc-core').path) evaluationDependsOn(project(':grpc-context').path) +evaluationDependsOn(project(':grpc-api').path) dependencies { implementation project(path: ':grpc-alts', configuration: 'shadow'), @@ -26,128 +29,179 @@ dependencies { project(':grpc-services'), project(':grpc-stub'), project(':grpc-testing'), - project(path: ':grpc-xds', configuration: 'shadow'), libraries.hdrhistogram, libraries.junit, libraries.truth, - libraries.opencensus_contrib_grpc_metrics, - libraries.google_auth_oauth2_http - compileOnly libraries.javax_annotation + libraries.opencensus.contrib.grpc.metrics, + libraries.google.auth.oauth2Http + def xdsDependency = implementation project(':grpc-xds') + + compileOnly libraries.javax.annotation + shadow project(path: ':grpc-xds', configuration: 'shadow') // TODO(sergiitk): replace with com.google.cloud:google-cloud-logging // Used instead of google-cloud-logging because it's failing // due to a circular dependency on grpc. // https://cloud.google.com/logging/docs/setup/java#the_javautillogging_handler // Error example: "java.util.logging.ErrorManager: 1" // Latest failing version com.google.cloud:google-cloud-logging:2.1.2 - runtimeOnly group: 'io.github.devatherock', name: 'jul-jsonformatter', version: '1.1.0' - runtimeOnly libraries.opencensus_impl, - libraries.netty_tcnative, - project(':grpc-grpclb') + // TODO(ejona): These should be compileOnly, but that doesn't get picked up + // for the shadow runtime + implementation group: 'io.github.devatherock', name: 'jul-jsonformatter', version: '1.1.0' + implementation libraries.opencensus.impl, + libraries.netty.tcnative, + libraries.netty.tcnative.classes, + project(':grpc-grpclb'), + project(':grpc-rls') testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-core').sourceSets.test.output, - libraries.mockito - alpnagent libraries.jetty_alpn_agent + libraries.mockito.core, + libraries.okhttp + alpnagent libraries.jetty.alpn.agent + shadow configurations.implementation.getDependencies().minus(xdsDependency) + + signature libraries.signature.java + signature libraries.signature.android } configureProtoCompilation() import net.ltgt.gradle.errorprone.CheckSeverity -compileJava { +tasks.named("compileJava").configure { // This isn't a library; it can use beta APIs options.errorprone.check("BetaApi", CheckSeverity.OFF) } +tasks.named("jar").configure { + // Must use a different archiveClassifier to avoid conflicting with shadowJar + archiveClassifier = 'original' +} + +tasks.named("distZip").configure { + archiveClassifier = "original" +} + +tasks.named("distTar").configure { + archiveClassifier = "original" +} + +def xdsPrefixName = 'io.grpc.xds' +tasks.named("shadowJar").configure { + archiveClassifier = null + dependencies { + exclude(dependency {true}) + } + relocate 'com.github.xds', "${xdsPrefixName}.shaded.com.github.xds" +} -test { +tasks.named("test").configure { // For the automated tests, use Jetty ALPN. jvmArgs "-javaagent:" + configurations.alpnagent.asPath } +tasks.named("startShadowScripts").configure { + enabled = false +} +tasks.named("installDist").configure { + dependsOn installShadowDist + enabled = false +} + // For the generated scripts, use Netty tcnative (i.e. OpenSSL). // Note that OkHttp currently only supports ALPN, so OpenSSL version >= 1.0.2 is required. -task test_client(type: CreateStartScripts) { +def startScriptsClasspath = provider { + shadowJar.outputs.files + configurations.shadow +} + +def test_client = tasks.register("test_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.TestServiceClient" applicationName = "test-client" defaultJvmOpts = [ "-javaagent:JAVAAGENT_APP_HOME" + configurations.alpnagent.singleFile.name ] outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() doLast { unixScript.text = unixScript.text.replace('JAVAAGENT_APP_HOME', '\'"\$APP_HOME"\'/lib/') windowsScript.text = windowsScript.text.replace('JAVAAGENT_APP_HOME', '%APP_HOME%\\lib\\') } } -task test_server(type: CreateStartScripts) { +def test_server = tasks.register("test_server", CreateStartScripts) { mainClass = "io.grpc.testing.integration.TestServiceServer" applicationName = "test-server" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() } -task reconnect_test_client(type: CreateStartScripts) { +def reconnect_test_client = tasks.register("reconnect_test_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.ReconnectTestClient" applicationName = "reconnect-test-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() } -task stresstest_client(type: CreateStartScripts) { +def stresstest_client = tasks.register("stresstest_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.StressTestClient" applicationName = "stresstest-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() defaultJvmOpts = [ "-verbose:gc", "-XX:+PrintFlagsFinal" ] } -task http2_client(type: CreateStartScripts) { +def http2_client = tasks.register("http2_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.Http2Client" applicationName = "http2-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() } -task grpclb_long_lived_affinity_test_client(type: CreateStartScripts) { +def grpclb_long_lived_affinity_test_client = tasks.register("grpclb_long_lived_affinity_test_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.GrpclbLongLivedAffinityTestClient" applicationName = "grpclb-long-lived-affinity-test-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() defaultJvmOpts = [ "-Dio.grpc.internal.DnsNameResolverProvider.enable_service_config=true" ] } -task grpclb_fallback_test_client (type: CreateStartScripts) { +def grpclb_fallback_test_client = tasks.register("grpclb_fallback_test_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.GrpclbFallbackTestClient" applicationName = "grpclb-fallback-test-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() defaultJvmOpts = [ "-Dio.grpc.internal.DnsNameResolverProvider.enable_service_config=true" ] } -task xds_test_client(type: CreateStartScripts) { +def xds_test_client = tasks.register("xds_test_client", CreateStartScripts) { mainClass = "io.grpc.testing.integration.XdsTestClient" applicationName = "xds-test-client" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() } -task xds_test_server(type: CreateStartScripts) { +def xds_test_server = tasks.register("xds_test_server", CreateStartScripts) { mainClass = "io.grpc.testing.integration.XdsTestServer" applicationName = "xds-test-server" outputDir = new File(project.buildDir, 'tmp/scripts/' + name) - classpath = startScripts.classpath + classpath = startScriptsClasspath.get() } -applicationDistribution.into("bin") { +def xds_federation_test_client = tasks.register("xds_federation_test_client", CreateStartScripts) { + mainClass = "io.grpc.testing.integration.XdsFederationTestClient" + applicationName = "xds-federation-test-client" + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScriptsClasspath.get() +} + +distributions.shadow.contents.into("bin") { from(test_client) from(test_server) from(reconnect_test_client) @@ -157,18 +211,27 @@ applicationDistribution.into("bin") { from(grpclb_fallback_test_client) from(xds_test_client) from(xds_test_server) + from(xds_federation_test_client) fileMode = 0755 } -applicationDistribution.into("lib") { +distributions.shadow.contents.into("lib") { from(configurations.alpnagent) } +distributions.shadow.distributionBaseName = project.name +// to please shadowJar +mainClassName = 'io.grpc.testing.integration.TestServiceClient' + publishing { publications { maven(MavenPublication) { - artifact distZip - artifact distTar + // We want this to throw an exception if it isn't working + def originalJar = artifacts.find { dep -> dep.classifier == 'original'} + artifacts.remove(originalJar) + + artifact shadowDistZip + artifact shadowDistTar } } } diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index e9997a7e4f4..be88c76b2fd 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -18,27 +18,27 @@ private ReconnectServiceGrpc() {} public static final String SERVICE_NAME = "grpc.testing.ReconnectService"; // Static method descriptors that strictly reflect the proto. - private static volatile io.grpc.MethodDescriptor getStartMethod; @io.grpc.stub.annotations.RpcMethod( fullMethodName = SERVICE_NAME + '/' + "Start", - requestType = io.grpc.testing.integration.EmptyProtos.Empty.class, + requestType = io.grpc.testing.integration.Messages.ReconnectParams.class, responseType = io.grpc.testing.integration.EmptyProtos.Empty.class, methodType = io.grpc.MethodDescriptor.MethodType.UNARY) - public static io.grpc.MethodDescriptor getStartMethod() { - io.grpc.MethodDescriptor getStartMethod; + io.grpc.MethodDescriptor getStartMethod; if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { synchronized (ReconnectServiceGrpc.class) { if ((getStartMethod = ReconnectServiceGrpc.getStartMethod) == null) { ReconnectServiceGrpc.getStartMethod = getStartMethod = - io.grpc.MethodDescriptor.newBuilder() + io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Start")) .setSampledToLocalTracing(true) .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( - io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) + io.grpc.testing.integration.Messages.ReconnectParams.getDefaultInstance())) .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( io.grpc.testing.integration.EmptyProtos.Empty.getDefaultInstance())) .setSchemaDescriptor(new ReconnectServiceMethodDescriptorSupplier("Start")) @@ -133,7 +133,7 @@ public static abstract class ReconnectServiceImplBase implements io.grpc.Bindabl /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getStartMethod(), responseObserver); } @@ -151,7 +151,7 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, getStartMethod(), io.grpc.stub.ServerCalls.asyncUnaryCall( new MethodHandlers< - io.grpc.testing.integration.EmptyProtos.Empty, + io.grpc.testing.integration.Messages.ReconnectParams, io.grpc.testing.integration.EmptyProtos.Empty>( this, METHODID_START))) .addMethod( @@ -184,7 +184,7 @@ protected ReconnectServiceStub build( /** */ - public void start(io.grpc.testing.integration.EmptyProtos.Empty request, + public void start(io.grpc.testing.integration.Messages.ReconnectParams request, io.grpc.stub.StreamObserver responseObserver) { io.grpc.stub.ClientCalls.asyncUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request, responseObserver); @@ -218,7 +218,7 @@ protected ReconnectServiceBlockingStub build( /** */ - public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.EmptyProtos.Empty request) { + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.blockingUnaryCall( getChannel(), getStartMethod(), getCallOptions(), request); } @@ -251,7 +251,7 @@ protected ReconnectServiceFutureStub build( /** */ public com.google.common.util.concurrent.ListenableFuture start( - io.grpc.testing.integration.EmptyProtos.Empty request) { + io.grpc.testing.integration.Messages.ReconnectParams request) { return io.grpc.stub.ClientCalls.futureUnaryCall( getChannel().newCall(getStartMethod(), getCallOptions()), request); } @@ -286,7 +286,7 @@ private static final class MethodHandlers implements public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { switch (methodId) { case METHODID_START: - serviceImpl.start((io.grpc.testing.integration.EmptyProtos.Empty) request, + serviceImpl.start((io.grpc.testing.integration.Messages.ReconnectParams) request, (io.grpc.stub.StreamObserver) responseObserver); break; case METHODID_STOP: diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 8e107add3e3..76e333be4dd 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -40,6 +40,7 @@ import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; import com.google.protobuf.MessageLite; +import com.google.protobuf.StringValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -91,6 +92,7 @@ import io.grpc.testing.integration.Messages.StreamingInputCallResponse; import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import io.grpc.testing.integration.Messages.TestOrcaReport; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; import io.opencensus.stats.Measure; import io.opencensus.stats.Measure.MeasureDouble; @@ -100,7 +102,6 @@ import io.opencensus.trace.Span; import io.opencensus.trace.SpanContext; import io.opencensus.trace.Tracing; -import io.opencensus.trace.unsafe.ContextUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -114,8 +115,10 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; @@ -188,6 +191,11 @@ public abstract class AbstractInteropTest { private final LinkedBlockingQueue serverStreamTracers = new LinkedBlockingQueue<>(); + static final CallOptions.Key> + ORCA_RPC_REPORT_KEY = CallOptions.Key.create("orca-rpc-report"); + static final CallOptions.Key> + ORCA_OOB_REPORT_KEY = CallOptions.Key.create("orca-oob-report"); + private static final class ServerStreamTracerInfo { final String fullMethodName; final InteropServerStreamTracer tracer; @@ -236,11 +244,8 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata protected static final Empty EMPTY = Empty.getDefaultInstance(); - private void startServer() { - maybeStartHandshakerServer(); - ServerBuilder builder = getServerBuilder(); + private void configBuilder(@Nullable ServerBuilder builder) { if (builder == null) { - server = null; return; } testServiceExecutor = Executors.newScheduledThreadPool(2); @@ -258,6 +263,14 @@ private void startServer() { new TestServiceImpl(testServiceExecutor), allInterceptors)) .addStreamTracerFactory(serverStreamTracerFactory); + } + + protected void startServer(@Nullable ServerBuilder builder) { + maybeStartHandshakerServer(); + if (builder == null) { + server = null; + return; + } try { server = builder.build().start(); @@ -325,7 +338,9 @@ public ClientCall interceptCall( */ @Before public void setUp() { - startServer(); + ServerBuilder serverBuilder = getServerBuilder(); + configBuilder(serverBuilder); + startServer(serverBuilder); channel = createChannel(); blockingStub = @@ -1059,10 +1074,9 @@ public void veryLargeResponse() throws Exception { public void exchangeMetadataUnaryCall() throws Exception { // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); - // Send a context proto (as it's in the default extension registry) - Messages.SimpleContext contextValue = - Messages.SimpleContext.newBuilder().setValue("dog").build(); - fixedHeaders.put(Util.METADATA_KEY, contextValue); + // Send a metadata proto + StringValue metadataValue = StringValue.newBuilder().setValue("dog").build(); + fixedHeaders.put(Util.METADATA_KEY, metadataValue); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); @@ -1073,18 +1087,17 @@ public void exchangeMetadataUnaryCall() throws Exception { assertNotNull(stub.emptyCall(EMPTY)); // Assert that our side channel object is echoed back in both headers and trailers - Assert.assertEquals(contextValue, headersCapture.get().get(Util.METADATA_KEY)); - Assert.assertEquals(contextValue, trailersCapture.get().get(Util.METADATA_KEY)); + Assert.assertEquals(metadataValue, headersCapture.get().get(Util.METADATA_KEY)); + Assert.assertEquals(metadataValue, trailersCapture.get().get(Util.METADATA_KEY)); } @Test public void exchangeMetadataStreamingCall() throws Exception { // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); - // Send a context proto (as it's in the default extension registry) - Messages.SimpleContext contextValue = - Messages.SimpleContext.newBuilder().setValue("dog").build(); - fixedHeaders.put(Util.METADATA_KEY, contextValue); + // Send a metadata proto + StringValue metadataValue = StringValue.newBuilder().setValue("dog").build(); + fixedHeaders.put(Util.METADATA_KEY, metadataValue); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); @@ -1115,8 +1128,8 @@ public void exchangeMetadataStreamingCall() throws Exception { org.junit.Assert.assertEquals(responseSizes.size() * numRequests, recorder.getValues().size()); // Assert that our side channel object is echoed back in both headers and trailers - Assert.assertEquals(contextValue, headersCapture.get().get(Util.METADATA_KEY)); - Assert.assertEquals(contextValue, trailersCapture.get().get(Util.METADATA_KEY)); + Assert.assertEquals(metadataValue, headersCapture.get().get(Util.METADATA_KEY)); + Assert.assertEquals(metadataValue, trailersCapture.get().get(Util.METADATA_KEY)); } @Test @@ -1235,7 +1248,7 @@ public void deadlineInPast() throws Exception { } catch (StatusRuntimeException ex) { assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); assertThat(ex.getStatus().getDescription()) - .startsWith("ClientCall started after deadline exceeded"); + .startsWith("ClientCall started after CallOptions deadline was exceeded"); } // CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be @@ -1268,7 +1281,7 @@ public void deadlineInPast() throws Exception { } catch (StatusRuntimeException ex) { assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); assertThat(ex.getStatus().getDescription()) - .startsWith("ClientCall started after deadline exceeded"); + .startsWith("ClientCall started after CallOptions deadline was exceeded"); } if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); @@ -1547,6 +1560,7 @@ public void customMetadata() throws Exception { Collections.singleton(streamingRequest), Collections.singleton(goldenStreamingResponse)); } + @SuppressWarnings("deprecation") @Test(timeout = 10000) public void censusContextsPropagated() { Assume.assumeTrue("Skip the test because server is not in the same process.", server != null); @@ -1561,7 +1575,7 @@ public void censusContextsPropagated() { .emptyBuilder() .putLocal(StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")) .build()); - ctx = ContextUtils.withValue(ctx, clientParentSpan); + ctx = io.opencensus.trace.unsafe.ContextUtils.withValue(ctx, clientParentSpan); Context origCtx = ctx.attach(); try { blockingStub.unaryCall(SimpleRequest.getDefaultInstance()); @@ -1581,7 +1595,7 @@ public void censusContextsPropagated() { } assertTrue("tag not found", tagFound); - Span span = ContextUtils.getValue(serverCtx); + Span span = io.opencensus.trace.unsafe.ContextUtils.getValue(serverCtx); assertNotNull(span); SpanContext spanContext = span.getContext(); assertEquals(clientParentSpan.getContext().getTraceId(), spanContext.getTraceId()); @@ -1732,6 +1746,91 @@ public void getServerAddressAndLocalAddressFromClient() { assertNotNull(obtainLocalClientAddr()); } + /** + * Test backend metrics per query reporting: expect the test client LB policy to receive load + * reports. + */ + public void testOrcaPerRpc() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putRequestCost("cost", 3456.32) + .putUtilization("util", 0.30499) + .build(); + blockingStub.withOption(ORCA_RPC_REPORT_KEY, reportHolder).unaryCall( + SimpleRequest.newBuilder().setOrcaPerQueryReport(answer).build()); + assertThat(reportHolder.get()).isEqualTo(answer); + } + + /** + * Test backend metrics OOB reporting: expect the test client LB policy to receive load reports. + */ + public void testOrcaOob() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + final TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putUtilization("util", 0.30499) + .build(); + final TestOrcaReport answer2 = TestOrcaReport.newBuilder() + .setCpuUtilization(0.29309) + .setMemoryUtilization(0.2) + .putUtilization("util", 100.2039) + .build(); + + final int retryLimit = 5; + BlockingQueue queue = new LinkedBlockingQueue<>(); + final Object lastItem = new Object(); + StreamObserver streamObserver = + asyncStub.fullDuplexCall(new StreamObserver() { + + @Override + public void onNext(StreamingOutputCallResponse value) { + queue.add(value); + } + + @Override + public void onError(Throwable t) { + queue.add(t); + } + + @Override + public void onCompleted() { + queue.add(lastItem); + } + }); + + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + int i = 0; + for (; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (answer.equals(reportHolder.get())) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer2) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + + for (i = 0; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (reportHolder.get().equals(answer2)) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onCompleted(); + assertThat(queue.take()).isSameInstanceAs(lastItem); + } + /** Sends a large unary rpc with service account credentials. */ public void serviceAccountCreds(String jsonKey, InputStream credentialsStream, String authScope) throws Exception { @@ -1906,15 +2005,10 @@ public Status getStatus() { private Status status = Status.OK; } - private SoakIterationResult performOneSoakIteration(boolean resetChannel) throws Exception { + private SoakIterationResult performOneSoakIteration( + TestServiceGrpc.TestServiceBlockingStub soakStub) throws Exception { long startNs = System.nanoTime(); Status status = Status.OK; - ManagedChannel soakChannel = channel; - TestServiceGrpc.TestServiceBlockingStub soakStub = blockingStub; - if (resetChannel) { - soakChannel = createChannel(); - soakStub = TestServiceGrpc.newBlockingStub(soakChannel); - } try { final SimpleRequest request = SimpleRequest.newBuilder() @@ -1930,10 +2024,6 @@ private SoakIterationResult performOneSoakIteration(boolean resetChannel) throws status = e.getStatus(); } long elapsedNs = System.nanoTime() - startNs; - if (resetChannel) { - soakChannel.shutdownNow(); - soakChannel.awaitTermination(10, TimeUnit.SECONDS); - } return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); } @@ -1942,62 +2032,83 @@ private SoakIterationResult performOneSoakIteration(boolean resetChannel) throws * and channel creation behavior. */ public void performSoakTest( + String serverUri, boolean resetChannelPerIteration, int soakIterations, int maxFailures, int maxAcceptablePerIterationLatencyMs, + int minTimeMsBetweenRpcs, int overallTimeoutSeconds) throws Exception { int iterationsDone = 0; int totalFailures = 0; Histogram latencies = new Histogram(4 /* number of significant value digits */); long startNs = System.nanoTime(); + ManagedChannel soakChannel = createChannel(); + TestServiceGrpc.TestServiceBlockingStub soakStub = TestServiceGrpc + .newBlockingStub(soakChannel) + .withInterceptors(recordClientCallInterceptor(clientCallCapture)); for (int i = 0; i < soakIterations; i++) { if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { break; } - SoakIterationResult result = performOneSoakIteration(resetChannelPerIteration); - System.err.print( + long earliestNextStartNs = System.nanoTime() + + TimeUnit.MILLISECONDS.toNanos(minTimeMsBetweenRpcs); + if (resetChannelPerIteration) { + soakChannel.shutdownNow(); + soakChannel.awaitTermination(10, TimeUnit.SECONDS); + soakChannel = createChannel(); + soakStub = TestServiceGrpc + .newBlockingStub(soakChannel) + .withInterceptors(recordClientCallInterceptor(clientCallCapture)); + } + SoakIterationResult result = performOneSoakIteration(soakStub); + SocketAddress peer = clientCallCapture + .get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + StringBuilder logStr = new StringBuilder( String.format( - "soak iteration: %d elapsed: %d ms", i, result.getLatencyMs())); + Locale.US, + "soak iteration: %d elapsed_ms: %d peer: %s server_uri: %s", + i, result.getLatencyMs(), peer != null ? peer.toString() : "null", serverUri)); if (!result.getStatus().equals(Status.OK)) { totalFailures++; - System.err.println(String.format(" failed: %s", result.getStatus())); + logStr.append(String.format(" failed: %s", result.getStatus())); } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { totalFailures++; - System.err.println( - String.format( - " exceeds max acceptable latency: %d", maxAcceptablePerIterationLatencyMs)); + logStr.append( + " exceeds max acceptable latency: " + maxAcceptablePerIterationLatencyMs); } else { - System.err.println(" succeeded"); + logStr.append(" succeeded"); } + System.err.println(logStr.toString()); iterationsDone++; latencies.recordValue(result.getLatencyMs()); + long remainingNs = earliestNextStartNs - System.nanoTime(); + if (remainingNs > 0) { + TimeUnit.NANOSECONDS.sleep(remainingNs); + } } + soakChannel.shutdownNow(); + soakChannel.awaitTermination(10, TimeUnit.SECONDS); System.err.println( String.format( - "soak test ran: %d / %d iterations\n" - + "total failures: %d\n" - + "max failures threshold: %d\n" - + "max acceptable per iteration latency ms: %d\n" - + " p50 soak iteration latency: %d ms\n" - + " p90 soak iteration latency: %d ms\n" - + "p100 soak iteration latency: %d ms\n" - + "See breakdown above for which iterations succeeded, failed, and " - + "why for more info.", + Locale.US, + "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. " + + "p50: %d ms, p90: %d ms, p100: %d ms", + serverUri, iterationsDone, soakIterations, totalFailures, - maxFailures, - maxAcceptablePerIterationLatencyMs, latencies.getValueAtPercentile(50), latencies.getValueAtPercentile(90), latencies.getValueAtPercentile(100))); // check if we timed out String timeoutErrorMessage = String.format( - "soak test consumed all %d seconds of time and quit early, only " - + "having ran %d out of desired %d iterations.", + Locale.US, + "(server_uri: %s) soak test consumed all %d seconds of time and quit early, " + + "only having ran %d out of desired %d iterations.", + serverUri, overallTimeoutSeconds, iterationsDone, soakIterations); @@ -2005,8 +2116,10 @@ public void performSoakTest( // check if we had too many failures String tooManyFailuresErrorMessage = String.format( - "soak test total failures: %d exceeds max failures threshold: %d.", - totalFailures, maxFailures); + Locale.US, + "(server_uri: %s) soak test total failures: %d exceeds max failures " + + "threshold: %d.", + serverUri, totalFailures, maxFailures); assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); } @@ -2016,7 +2129,7 @@ protected static void assertSuccess(StreamRecorder recorder) { } } - /** Helper for getting remote address from {@link io.grpc.ServerCall#getAttributes()} */ + /** Helper for getting remote address from {@link io.grpc.ServerCall#getAttributes()}. */ protected SocketAddress obtainRemoteClientAddr() { TestServiceGrpc.TestServiceBlockingStub stub = blockingStub.withDeadlineAfter(5, TimeUnit.SECONDS); @@ -2026,7 +2139,7 @@ protected SocketAddress obtainRemoteClientAddr() { return serverCallCapture.get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); } - /** Helper for getting remote address from {@link io.grpc.ClientCall#getAttributes()} */ + /** Helper for getting remote address from {@link io.grpc.ClientCall#getAttributes()}. */ protected SocketAddress obtainRemoteServerAddr() { TestServiceGrpc.TestServiceBlockingStub stub = blockingStub .withInterceptors(recordClientCallInterceptor(clientCallCapture)) @@ -2037,7 +2150,7 @@ protected SocketAddress obtainRemoteServerAddr() { return clientCallCapture.get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); } - /** Helper for getting local address from {@link io.grpc.ServerCall#getAttributes()} */ + /** Helper for getting local address from {@link io.grpc.ServerCall#getAttributes()}. */ protected SocketAddress obtainLocalServerAddr() { TestServiceGrpc.TestServiceBlockingStub stub = blockingStub.withDeadlineAfter(5, TimeUnit.SECONDS); @@ -2047,7 +2160,7 @@ protected SocketAddress obtainLocalServerAddr() { return serverCallCapture.get().getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); } - /** Helper for getting local address from {@link io.grpc.ClientCall#getAttributes()} */ + /** Helper for getting local address from {@link io.grpc.ClientCall#getAttributes()}. */ protected SocketAddress obtainLocalClientAddr() { TestServiceGrpc.TestServiceBlockingStub stub = blockingStub .withInterceptors(recordClientCallInterceptor(clientCallCapture)) @@ -2058,7 +2171,7 @@ protected SocketAddress obtainLocalClientAddr() { return clientCallCapture.get().getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); } - /** Helper for asserting TLS info in SSLSession {@link io.grpc.ServerCall#getAttributes()} */ + /** Helper for asserting TLS info in SSLSession {@link io.grpc.ServerCall#getAttributes()}. */ protected void assertX500SubjectDn(String tlsInfo) { TestServiceGrpc.TestServiceBlockingStub stub = blockingStub.withDeadlineAfter(5, TimeUnit.SECONDS); @@ -2247,9 +2360,10 @@ private void checkTracers( long uncompressedSentSize = 0; int seqNo = 0; for (MessageLite msg : sentMessages) { - assertThat(tracer.nextOutboundEvent()).isEqualTo(String.format("outboundMessage(%d)", seqNo)); + assertThat(tracer.nextOutboundEvent()) + .isEqualTo(String.format(Locale.US, "outboundMessage(%d)", seqNo)); assertThat(tracer.nextOutboundEvent()).matches( - String.format("outboundMessageSent\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)); + String.format(Locale.US, "outboundMessageSent\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)); seqNo++; uncompressedSentSize += msg.getSerializedSize(); } @@ -2257,9 +2371,10 @@ private void checkTracers( long uncompressedReceivedSize = 0; seqNo = 0; for (MessageLite msg : receivedMessages) { - assertThat(tracer.nextInboundEvent()).isEqualTo(String.format("inboundMessage(%d)", seqNo)); + assertThat(tracer.nextInboundEvent()) + .isEqualTo(String.format(Locale.US, "inboundMessage(%d)", seqNo)); assertThat(tracer.nextInboundEvent()).matches( - String.format("inboundMessageRead\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)); + String.format(Locale.US, "inboundMessageRead\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)); uncompressedReceivedSize += msg.getSerializedSize(); seqNo++; } @@ -2286,10 +2401,10 @@ private void checkCensus(MetricsRecord record, boolean isServer, if (isServer) { assertEquals( requests.size(), - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_REQUEST_COUNT)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_RECEIVED_MESSAGES_PER_RPC)); assertEquals( responses.size(), - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_COUNT)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_SERVER_SENT_MESSAGES_PER_RPC)); assertEquals( uncompressedRequestsSize, record.getMetricAsLongOrFail( @@ -2298,18 +2413,18 @@ private void checkCensus(MetricsRecord record, boolean isServer, uncompressedResponsesSize, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_SERVER_LATENCY)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_SERVER_LATENCY)); // It's impossible to get the expected wire sizes because it may be compressed, so we just // check if they are recorded. - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_REQUEST_BYTES)); - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_SERVER_RESPONSE_BYTES)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_RECEIVED_BYTES_PER_RPC)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_SERVER_SENT_BYTES_PER_RPC)); } else { assertEquals( requests.size(), - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)); assertEquals( responses.size(), - record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_RPC)); assertEquals( uncompressedRequestsSize, record.getMetricAsLongOrFail( @@ -2318,11 +2433,11 @@ private void checkCensus(MetricsRecord record, boolean isServer, uncompressedResponsesSize, record.getMetricAsLongOrFail( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)); // It's impossible to get the expected wire sizes because it may be compressed, so we just // check if they are recorded. - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); - assertNotNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_RPC)); + assertNotNull(record.getMetric(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_RPC)); } } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java new file mode 100644 index 00000000000..1864afd3c42 --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java @@ -0,0 +1,154 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static io.grpc.testing.integration.AbstractInteropTest.ORCA_OOB_REPORT_KEY; +import static io.grpc.testing.integration.AbstractInteropTest.ORCA_RPC_REPORT_KEY; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.services.MetricReport; +import io.grpc.testing.integration.Messages.TestOrcaReport; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.xds.orca.OrcaOobUtil; +import io.grpc.xds.orca.OrcaPerRequestUtil; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Implements a test LB policy that receives ORCA load reports. + */ +final class CustomBackendMetricsLoadBalancerProvider extends LoadBalancerProvider { + + static final String TEST_ORCA_LB_POLICY_NAME = "test_backend_metrics_load_balancer"; + private volatile TestOrcaReport latestOobReport; + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new CustomBackendMetricsLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 0; + } + + @Override + public String getPolicyName() { + return TEST_ORCA_LB_POLICY_NAME; + } + + private final class CustomBackendMetricsLoadBalancer extends ForwardingLoadBalancer { + private LoadBalancer delegate; + + public CustomBackendMetricsLoadBalancer(Helper helper) { + this.delegate = LoadBalancerRegistry.getDefaultRegistry() + .getProvider("pick_first") + .newLoadBalancer(new CustomBackendMetricsLoadBalancerHelper(helper)); + } + + @Override + public LoadBalancer delegate() { + return delegate; + } + + private final class CustomBackendMetricsLoadBalancerHelper + extends ForwardingLoadBalancerHelper { + private final Helper orcaHelper; + + public CustomBackendMetricsLoadBalancerHelper(Helper helper) { + this.orcaHelper = OrcaOobUtil.newOrcaReportingHelper(helper); + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args); + OrcaOobUtil.setListener(subchannel, new OrcaOobUtil.OrcaOobReportListener() { + @Override + public void onLoadReport(MetricReport orcaLoadReport) { + latestOobReport = fromCallMetricReport(orcaLoadReport); + } + }, + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(1, TimeUnit.SECONDS) + .build() + ); + return subchannel; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new MayReportLoadPicker(newPicker)); + } + + @Override + public Helper delegate() { + return orcaHelper; + } + } + + private final class MayReportLoadPicker extends SubchannelPicker { + private SubchannelPicker delegate; + + public MayReportLoadPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult result = delegate.pickSubchannel(args); + if (result.getSubchannel() == null) { + return result; + } + AtomicReference reportRef = + args.getCallOptions().getOption(ORCA_OOB_REPORT_KEY); + if (reportRef != null) { + reportRef.set(latestOobReport); + } + + return PickResult.withSubchannel( + result.getSubchannel(), + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + new OrcaPerRequestUtil.OrcaPerRequestReportListener() { + @Override + public void onLoadReport(MetricReport callMetricReport) { + AtomicReference reportRef = + args.getCallOptions().getOption(ORCA_RPC_REPORT_KEY); + reportRef.set(fromCallMetricReport(callMetricReport)); + } + })); + } + } + } + + private static TestOrcaReport fromCallMetricReport(MetricReport callMetricReport) { + return TestOrcaReport.newBuilder() + .setCpuUtilization(callMetricReport.getCpuUtilization()) + .setMemoryUtilization(callMetricReport.getMemoryUtilization()) + .putAllRequestCost(callMetricReport.getRequestCostMetrics()) + .putAllUtilization(callMetricReport.getUtilizationMetrics()) + .build(); + } +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java index 52c9e8238b5..9fc017c0e35 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java @@ -69,12 +69,13 @@ public void run() { System.exit(0); } - private String unrouteLbAndBackendAddrsCmd = "exit 1"; - private String blackholeLbAndBackendAddrsCmd = "exit 1"; + private String induceFallbackCmd = "exit 1"; private String serverUri; private String customCredentialsType; private String testCase; private Boolean skipNetCmd = false; + private int numWarmupRpcs; + private int fallbackDeadlineSeconds = 1; private ManagedChannel channel; private TestServiceGrpc.TestServiceBlockingStub blockingStub; @@ -103,14 +104,16 @@ private void parseArgs(String[] args) { serverUri = value; } else if ("test_case".equals(key)) { testCase = value; - } else if ("unroute_lb_and_backend_addrs_cmd".equals(key)) { - unrouteLbAndBackendAddrsCmd = value; - } else if ("blackhole_lb_and_backend_addrs_cmd".equals(key)) { - blackholeLbAndBackendAddrsCmd = value; + } else if ("induce_fallback_cmd".equals(key)) { + induceFallbackCmd = value; } else if ("custom_credentials_type".equals(key)) { customCredentialsType = value; } else if ("skip_net_cmd".equals(key)) { skipNetCmd = Boolean.valueOf(value); + } else if ("num_warmup_rpcs".equals(key)) { + numWarmupRpcs = Integer.valueOf(value); + } else if ("fallback_deadline_seconds".equals(key)) { + fallbackDeadlineSeconds = Integer.valueOf(value); } else { System.err.println("Unknown argument: " + key); usage = true; @@ -126,24 +129,25 @@ private void parseArgs(String[] args) { + c.serverUri + "\n --custom_credentials_type Name of Credentials to use. " + "Default: " + c.customCredentialsType - + "\n --unroute_lb_and_backend_addrs_cmd Shell command used to make " - + "LB and backend addresses unroutable. Default: " - + c.unrouteLbAndBackendAddrsCmd - + "\n --blackhole_lb_and_backend_addrs_cmd Shell command used to make " - + "LB and backend addresses black holed. Default: " - + c.blackholeLbAndBackendAddrsCmd + + "\n --induce_fallback_cmd Shell command to induce fallback, e.g. by " + + "making LB and/or backend addresses unroutable or black holed. Default: " + + c.induceFallbackCmd + "\n --skip_net_cmd Skip unroute and blackhole " + "shell command to allow setting the net config outside of the test " + "client. Default: " + c.skipNetCmd + + "\n --num_warmup_rpcs Number of RPCs to perform " + + "on a separate warmup channel before the actual test runs (each warmup " + + "RPC uses a 1 second deadline). Default: " + + c.numWarmupRpcs + + "\n --fallback_deadline_seconds Number of seconds to wait " + + "for fallback to occur after inducing fallback. Default: " + + c.fallbackDeadlineSeconds + "\n --test_case=TEST_CASE Test case to run. Valid options are:" - + "\n fast_fallback_before_startup : fallback before LB connection" - + "\n fast_fallback_after_startup : fallback after startup due to " - + "LB/backend addresses becoming unroutable" - + "\n slow_fallback_before_startup : fallback before LB connection " - + "due to LB/backend addresses being blackholed" - + "\n slow_fallback_after_startup : fallback after startup due to " - + "LB/backend addresses becoming blackholed" + + "\n fallback_before_startup : fallback before startup e.g. due to " + + "LB/backend addresses being unreachable" + + "\n fallback_after_startup : fallback after startup e.g. due to " + + "LB/backend addresses becoming unreachable" + "\n Default: " + c.testCase ); System.exit(1); @@ -197,14 +201,15 @@ private void runShellCmd(String cmd) throws Exception { assertEquals(0, exitCode); } - private GrpclbRouteType doRpcAndGetPath(Deadline deadline) { + private GrpclbRouteType doRpcAndGetPath( + TestServiceGrpc.TestServiceBlockingStub stub, Deadline deadline) { logger.info("doRpcAndGetPath deadline: " + deadline); final SimpleRequest request = SimpleRequest.newBuilder() .setFillGrpclbRouteType(true) .build(); GrpclbRouteType result = GrpclbRouteType.GRPCLB_ROUTE_TYPE_UNKNOWN; try { - SimpleResponse response = blockingStub + SimpleResponse response = stub .withDeadline(deadline) .unaryCall(request); result = response.getGrpclbRouteType(); @@ -226,7 +231,7 @@ private void waitForFallbackAndDoRpcs(Deadline fallbackDeadline) throws Exceptio boolean fallBack = false; while (!fallbackDeadline.isExpired()) { GrpclbRouteType grpclbRouteType = doRpcAndGetPath( - Deadline.after(1, TimeUnit.SECONDS)); + blockingStub, Deadline.after(1, TimeUnit.SECONDS)); if (grpclbRouteType == GrpclbRouteType.GRPCLB_ROUTE_TYPE_BACKEND) { throw new AssertionError("Got grpclb route type backend. Backends are " + "supposed to be unreachable, so this test is broken"); @@ -247,55 +252,57 @@ private void waitForFallbackAndDoRpcs(Deadline fallbackDeadline) throws Exceptio for (int i = 0; i < 30; i++) { assertEquals( GrpclbRouteType.GRPCLB_ROUTE_TYPE_FALLBACK, - doRpcAndGetPath(Deadline.after(20, TimeUnit.SECONDS))); + doRpcAndGetPath(blockingStub, Deadline.after(20, TimeUnit.SECONDS))); Thread.sleep(1000); } } - private void runFastFallbackBeforeStartup() throws Exception { - runShellCmd(unrouteLbAndBackendAddrsCmd); - final Deadline fallbackDeadline = Deadline.after(5, TimeUnit.SECONDS); + private void runFallbackBeforeStartup() throws Exception { + runShellCmd(induceFallbackCmd); + final Deadline fallbackDeadline = Deadline.after( + fallbackDeadlineSeconds, TimeUnit.SECONDS); initStub(); waitForFallbackAndDoRpcs(fallbackDeadline); } - private void runSlowFallbackBeforeStartup() throws Exception { - runShellCmd(blackholeLbAndBackendAddrsCmd); - final Deadline fallbackDeadline = Deadline.after(20, TimeUnit.SECONDS); - initStub(); - waitForFallbackAndDoRpcs(fallbackDeadline); - } - - private void runFastFallbackAfterStartup() throws Exception { + private void runFallbackAfterStartup() throws Exception { initStub(); assertEquals( GrpclbRouteType.GRPCLB_ROUTE_TYPE_BACKEND, - doRpcAndGetPath(Deadline.after(20, TimeUnit.SECONDS))); - runShellCmd(unrouteLbAndBackendAddrsCmd); - final Deadline fallbackDeadline = Deadline.after(40, TimeUnit.SECONDS); + doRpcAndGetPath(blockingStub, Deadline.after(20, TimeUnit.SECONDS))); + runShellCmd(induceFallbackCmd); + final Deadline fallbackDeadline = Deadline.after( + fallbackDeadlineSeconds, TimeUnit.SECONDS); waitForFallbackAndDoRpcs(fallbackDeadline); } - private void runSlowFallbackAfterStartup() throws Exception { - initStub(); - assertEquals( - GrpclbRouteType.GRPCLB_ROUTE_TYPE_BACKEND, - doRpcAndGetPath(Deadline.after(20, TimeUnit.SECONDS))); - runShellCmd(blackholeLbAndBackendAddrsCmd); - final Deadline fallbackDeadline = Deadline.after(40, TimeUnit.SECONDS); - waitForFallbackAndDoRpcs(fallbackDeadline); + // The purpose of this warmup method is to get potentially expensive one-per-process + // initialization out of the way, so that we can use aggressive timeouts in the actual + // test cases. Note that the warmup phase is done using a separate channel from the + // actual test cases, so that we don't affect the states of LB policies in the channel + // of the actual test case. + private void warmup() throws Exception { + logger.info("Begin warmup, performing " + numWarmupRpcs + " RPCs on the warmup channel"); + ManagedChannel channel = createChannel(); + TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel); + for (int i = 0; i < numWarmupRpcs; i++) { + doRpcAndGetPath(stub, Deadline.after(1, TimeUnit.SECONDS)); + } + try { + channel.shutdownNow(); + channel.awaitTermination(1, TimeUnit.SECONDS); + } catch (Exception ex) { + throw new RuntimeException(ex); + } } private void run() throws Exception { + warmup(); logger.info("Begin test case: " + testCase); - if (testCase.equals("fast_fallback_before_startup")) { - runFastFallbackBeforeStartup(); - } else if (testCase.equals("slow_fallback_before_startup")) { - runSlowFallbackBeforeStartup(); - } else if (testCase.equals("fast_fallback_after_startup")) { - runFastFallbackAfterStartup(); - } else if (testCase.equals("slow_fallback_after_startup")) { - runSlowFallbackAfterStartup(); + if (testCase.equals("fallback_before_startup")) { + runFallbackBeforeStartup(); + } else if (testCase.equals("fallback_after_startup")) { + runFallbackAfterStartup(); } else { throw new RuntimeException("invalid testcase: " + testCase); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/ReconnectTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/ReconnectTestClient.java index a89e8788abe..f548a57b270 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/ReconnectTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/ReconnectTestClient.java @@ -25,6 +25,7 @@ import io.grpc.okhttp.OkHttpChannelBuilder; import io.grpc.testing.integration.EmptyProtos.Empty; import io.grpc.testing.integration.Messages.ReconnectInfo; +import io.grpc.testing.integration.Messages.ReconnectParams; /** * Verifies the client is reconnecting the server with correct backoffs @@ -79,12 +80,12 @@ private void runTest() throws Exception { .negotiationType(NegotiationType.TLS).build(); } retryStub = ReconnectServiceGrpc.newBlockingStub(retryChannel); - controlStub.start(Empty.getDefaultInstance()); + controlStub.start(ReconnectParams.getDefaultInstance()); long startTimeStamp = System.currentTimeMillis(); while ((System.currentTimeMillis() - startTimeStamp) < TEST_TIME_MS) { try { - retryStub.start(Empty.getDefaultInstance()); + retryStub.start(ReconnectParams.getDefaultInstance()); } catch (StatusRuntimeException expected) { // Make CheckStyle happy. } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java new file mode 100644 index 00000000000..83c416765ec --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Metadata; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import java.util.Map; +import javax.annotation.Nonnull; + +/** + * Provides a xDS interop test {@link LoadBalancer} designed to work with {@link XdsTestServer}. It + * looks for an "rpc_behavior" field in its configuration and includes the value in the + * "rpc-behavior" metadata entry that is sent to the server. This will cause the test server to + * behave in a predefined way. Endpoint picking logic is delegated to the + * io.grpc.util.RoundRobinLoadBalancer. + * + *

Initial use case is to prove that a custom load balancer can be configured by the control + * plane via xDS. An interop test will configure this LB and then verify it has been correctly + * configured by observing a specific RPC behavior by the server(s). + * + *

For more details on what behaviors can be specified, please see: + * https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#server + */ +public class RpcBehaviorLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { + String rpcBehavior = JsonUtil.getString(rawLoadBalancingPolicyConfig, "rpcBehavior"); + if (rpcBehavior == null) { + return ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription("no 'rpcBehavior' defined")); + } + return ConfigOrError.fromConfig(new RpcBehaviorConfig(rpcBehavior)); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + RpcBehaviorHelper rpcBehaviorHelper = new RpcBehaviorHelper(helper); + return new RpcBehaviorLoadBalancer(rpcBehaviorHelper, + LoadBalancerRegistry.getDefaultRegistry().getProvider("round_robin") + .newLoadBalancer(rpcBehaviorHelper)); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "test.RpcBehaviorLoadBalancer"; + } + + static class RpcBehaviorConfig { + + final String rpcBehavior; + + RpcBehaviorConfig(String rpcBehavior) { + this.rpcBehavior = rpcBehavior; + } + } + + /** + * Delegates all calls to another LB and wraps the given helper in {@link RpcBehaviorHelper} that + * assures that the rpc-behavior metadata header gets added to all calls. + */ + static class RpcBehaviorLoadBalancer extends ForwardingLoadBalancer { + + private final RpcBehaviorHelper helper; + private final LoadBalancer delegateLb; + + RpcBehaviorLoadBalancer(RpcBehaviorHelper helper, LoadBalancer delegateLb) { + this.helper = helper; + this.delegateLb = delegateLb; + } + + @Override + protected LoadBalancer delegate() { + return delegateLb; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.setRpcBehavior( + ((RpcBehaviorConfig) resolvedAddresses.getLoadBalancingPolicyConfig()).rpcBehavior); + delegateLb.handleResolvedAddresses(resolvedAddresses); + } + } + + /** + * Wraps the picker that is provided when the balancing change updates with the {@link + * RpcBehaviorPicker} that injects the rpc-behavior metadata entry. + */ + static class RpcBehaviorHelper extends ForwardingLoadBalancerHelper { + + private final Helper delegateHelper; + private String rpcBehavior; + + RpcBehaviorHelper(Helper delegateHelper) { + this.delegateHelper = delegateHelper; + } + + void setRpcBehavior(String rpcBehavior) { + this.rpcBehavior = rpcBehavior; + } + + @Override + protected Helper delegate() { + return delegateHelper; + } + + @Override + public void updateBalancingState(@Nonnull ConnectivityState newState, + @Nonnull SubchannelPicker newPicker) { + delegateHelper.updateBalancingState(newState, new RpcBehaviorPicker(newPicker, rpcBehavior)); + } + } + + /** + * Includes the rpc-behavior metadata entry on each subchannel pick. + */ + static class RpcBehaviorPicker extends SubchannelPicker { + + private static final String RPC_BEHAVIOR_HEADER_KEY = "rpc-behavior"; + + private final SubchannelPicker delegatePicker; + private final String rpcBehavior; + + RpcBehaviorPicker(SubchannelPicker delegatePicker, String rpcBehavior) { + this.delegatePicker = delegatePicker; + this.rpcBehavior = rpcBehavior; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getHeaders() + .put(Metadata.Key.of(RPC_BEHAVIOR_HEADER_KEY, Metadata.ASCII_STRING_MARSHALLER), + rpcBehavior); + return delegatePicker.pickSubchannel(args); + } + } +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java index 6669d2700d6..0aa5b04b3c3 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java @@ -54,6 +54,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; @@ -231,8 +232,8 @@ void runStressTest() throws Exception { ManagedChannel channel = createChannel(address); channels.add(channel); for (int j = 0; j < stubsPerChannel; j++) { - String gaugeName = - String.format("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j); + String gaugeName = String.format( + Locale.US, "/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j); Worker worker = new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java index 39afaa99d6e..85e5c31a4cb 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -56,7 +56,9 @@ public enum TestCases { VERY_LARGE_REQUEST("very large request"), PICK_FIRST_UNARY("all requests are sent to one server despite multiple servers are resolved"), RPC_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on the same channel"), - CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"); + CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"), + ORCA_PER_RPC("report backend metrics per query"), + ORCA_OOB("report backend metrics out-of-band"); private final String description; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index 914db12e5a8..4e3cb90232d 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -22,6 +22,8 @@ import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.ServerBuilder; @@ -60,6 +62,8 @@ public static void main(String[] args) throws Exception { TestUtils.installConscryptIfAvailable(); final TestServiceClient client = new TestServiceClient(); client.parseArgs(args); + customBackendMetricsLoadBalancerProvider = new CustomBackendMetricsLoadBalancerProvider(); + LoadBalancerRegistry.getDefaultRegistry().register(customBackendMetricsLoadBalancerProvider); client.setUp(); try { @@ -89,8 +93,10 @@ public static void main(String[] args) throws Exception { private int soakIterations = 10; private int soakMaxFailures = 0; private int soakPerIterationMaxAcceptableLatencyMs = 1000; + private int soakMinTimeMsBetweenRpcs = 0; private int soakOverallTimeoutSeconds = soakIterations * soakPerIterationMaxAcceptableLatencyMs / 1000; + private static LoadBalancerProvider customBackendMetricsLoadBalancerProvider; private Tester tester = new Tester(); @@ -161,6 +167,8 @@ void parseArgs(String[] args) throws Exception { soakMaxFailures = Integer.parseInt(value); } else if ("soak_per_iteration_max_acceptable_latency_ms".equals(key)) { soakPerIterationMaxAcceptableLatencyMs = Integer.parseInt(value); + } else if ("soak_min_time_ms_between_rpcs".equals(key)) { + soakMinTimeMsBetweenRpcs = Integer.parseInt(value); } else if ("soak_overall_timeout_seconds".equals(key)) { soakOverallTimeoutSeconds = Integer.parseInt(value); } else { @@ -221,6 +229,11 @@ void parseArgs(String[] args) throws Exception { + "\n two soak tests (rpc_soak and channel_soak) should " + "\n take. Default " + c.soakPerIterationMaxAcceptableLatencyMs + + "\n --soak_min_time_ms_between_rpcs " + + "\n The minimum time in milliseconds between consecutive " + + "\n RPCs in a soak test (rpc_soak or channel_soak), " + + "\n useful for limiting QPS. Default: " + + c.soakMinTimeMsBetweenRpcs + "\n --soak_overall_timeout_seconds " + "\n The overall number of seconds after which a soak test " + "\n should stop and fail, if the desired number of " @@ -239,6 +252,10 @@ void setUp() { private synchronized void tearDown() { try { tester.tearDown(); + if (customBackendMetricsLoadBalancerProvider != null) { + LoadBalancerRegistry.getDefaultRegistry() + .deregister(customBackendMetricsLoadBalancerProvider); + } } catch (RuntimeException ex) { throw ex; } catch (Exception ex) { @@ -444,22 +461,37 @@ private void runTest(TestCases testCase) throws Exception { case RPC_SOAK: { tester.performSoakTest( + serverHost, false /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds); break; } case CHANNEL_SOAK: { tester.performSoakTest( + serverHost, true /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds); break; + + } + + case ORCA_PER_RPC: { + tester.testOrcaPerRpc(); + break; + } + + case ORCA_OOB: { + tester.testOrcaOob(); + break; } default: @@ -599,6 +631,11 @@ protected ServerBuilder getHandshakerServerBuilder() { return null; } } + + @Override + protected int operationTimeoutMillis() { + return 15000; + } } private static String validTestCasesHelpText() { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java index 5fe7248b2bd..ea77f13892a 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java @@ -25,7 +25,10 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.internal.LogExceptionRunnable; +import io.grpc.services.CallMetricRecorder; +import io.grpc.services.MetricRecorder; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.integration.Messages.Payload; @@ -36,15 +39,19 @@ import io.grpc.testing.integration.Messages.StreamingInputCallResponse; import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import io.grpc.testing.integration.Messages.TestOrcaReport; import java.util.ArrayDeque; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.Random; import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; @@ -57,13 +64,20 @@ public class TestServiceImpl extends TestServiceGrpc.TestServiceImplBase { private final ScheduledExecutorService executor; private final ByteString compressableBuffer; + private final MetricRecorder metricRecorder; + final Semaphore lock = new Semaphore(1); /** * Constructs a controller using the given executor for scheduling response stream chunks. */ - public TestServiceImpl(ScheduledExecutorService executor) { + public TestServiceImpl(ScheduledExecutorService executor, MetricRecorder metricRecorder) { this.executor = executor; this.compressableBuffer = ByteString.copyFrom(new byte[1024]); + this.metricRecorder = metricRecorder; + } + + public TestServiceImpl(ScheduledExecutorService executor) { + this(executor, MetricRecorder.newInstance()); } @Override @@ -112,10 +126,34 @@ public void unaryCall(SimpleRequest req, StreamObserver response return; } + if (req.hasOrcaPerQueryReport()) { + echoCallMetricsFromPayload(req.getOrcaPerQueryReport()); + } responseObserver.onNext(responseBuilder.build()); responseObserver.onCompleted(); } + private static void echoCallMetricsFromPayload(TestOrcaReport report) { + CallMetricRecorder recorder = CallMetricRecorder.getCurrent() + .recordCpuUtilizationMetric(report.getCpuUtilization()) + .recordMemoryUtilizationMetric(report.getMemoryUtilization()); + for (Map.Entry entry : report.getUtilizationMap().entrySet()) { + recorder.recordUtilizationMetric(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : report.getRequestCostMap().entrySet()) { + recorder.recordRequestCostMetric(entry.getKey(), entry.getValue()); + } + } + + private void echoMetricsFromPayload(TestOrcaReport report) { + metricRecorder.setCpuUtilizationMetric(report.getCpuUtilization()); + metricRecorder.setMemoryUtilizationMetric(report.getMemoryUtilization()); + metricRecorder.setAllUtilizationMetrics(new HashMap<>()); + for (Map.Entry entry : report.getUtilizationMap().entrySet()) { + metricRecorder.putUtilizationMetric(entry.getKey(), entry.getValue()); + } + } + /** * Given a request that specifies chunk size and interval between responses, creates and schedules * the response stream. @@ -165,8 +203,25 @@ public StreamObserver fullDuplexCall( final StreamObserver responseObserver) { final ResponseDispatcher dispatcher = new ResponseDispatcher(responseObserver); return new StreamObserver() { + boolean oobTestLocked; + @Override public void onNext(StreamingOutputCallRequest request) { + // to facilitate multiple clients running orca_oob test in parallel, the server allows + // only one orca_oob test client to run at a time to avoid conflicts. + if (request.hasOrcaOobReport()) { + if (!oobTestLocked) { + try { + lock.acquire(); + } catch (InterruptedException ex) { + responseObserver.onError(new StatusRuntimeException( + Status.ABORTED.withDescription("server service interrupted").withCause(ex))); + return; + } + oobTestLocked = true; + } + echoMetricsFromPayload(request.getOrcaOobReport()); + } if (request.hasResponseStatus()) { dispatcher.cancel(); dispatcher.onError(Status.fromCodeValue(request.getResponseStatus().getCode()) @@ -179,6 +234,10 @@ public void onNext(StreamingOutputCallRequest request) { @Override public void onCompleted() { + if (oobTestLocked) { + lock.release(); + oobTestLocked = false; + } if (!dispatcher.isCancelled()) { // Tell the dispatcher that all input has been received. dispatcher.completeInput(); @@ -187,6 +246,10 @@ public void onCompleted() { @Override public void onError(Throwable cause) { + if (oobTestLocked) { + lock.release(); + oobTestLocked = false; + } dispatcher.onError(cause); } }; @@ -351,7 +414,7 @@ private void scheduleNextChunk() { // Schedule the next response chunk if there is one. Chunk nextChunk = chunks.peek(); - if (nextChunk != null) { + if (nextChunk != null && !executor.isShutdown()) { scheduled = true; // TODO(ejona): cancel future if RPC is cancelled Future unused = executor.schedule(new LogExceptionRunnable(dispatchTask), diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index 19946ec4a7d..a2966685f33 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.BindableService; import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; import io.grpc.Server; @@ -26,6 +27,9 @@ import io.grpc.TlsServerCredentials; import io.grpc.alts.AltsServerCredentials; import io.grpc.internal.testing.TestUtils; +import io.grpc.services.MetricRecorder; +import io.grpc.xds.orca.OrcaMetricReportingServerInterceptor; +import io.grpc.xds.orca.OrcaServiceImpl; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -151,11 +155,16 @@ void start() throws Exception { } else { serverCreds = InsecureServerCredentials.create(); } + MetricRecorder metricRecorder = MetricRecorder.newInstance(); + BindableService orcaOobService = + OrcaServiceImpl.createService(executor, metricRecorder, 1, TimeUnit.SECONDS); server = Grpc.newServerBuilderForPort(port, serverCreds) .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .addService( ServerInterceptors.intercept( - new TestServiceImpl(executor), TestServiceImpl.interceptors())) + new TestServiceImpl(executor, metricRecorder), TestServiceImpl.interceptors())) + .addService(orcaOobService) + .intercept(OrcaMetricReportingServerInterceptor.getInstance()) .build() .start(); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/Util.java b/interop-testing/src/main/java/io/grpc/testing/integration/Util.java index d75661132a9..b66114f12c0 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/Util.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/Util.java @@ -17,6 +17,7 @@ package io.grpc.testing.integration; import com.google.protobuf.MessageLite; +import com.google.protobuf.StringValue; import io.grpc.Metadata; import io.grpc.protobuf.lite.ProtoLiteUtils; import java.util.List; @@ -27,10 +28,10 @@ */ public class Util { - public static final Metadata.Key METADATA_KEY = + public static final Metadata.Key METADATA_KEY = Metadata.Key.of( - "grpc.testing.SimpleContext" + Metadata.BINARY_HEADER_SUFFIX, - ProtoLiteUtils.metadataMarshaller(Messages.SimpleContext.getDefaultInstance())); + "google.protobuf.StringValue" + Metadata.BINARY_HEADER_SUFFIX, + ProtoLiteUtils.metadataMarshaller(StringValue.getDefaultInstance())); public static final Metadata.Key ECHO_INITIAL_METADATA_KEY = Metadata.Key.of("x-grpc-test-echo-initial", Metadata.ASCII_STRING_MARSHALLER); public static final Metadata.Key ECHO_TRAILING_METADATA_KEY diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java new file mode 100644 index 00000000000..8f166b6affa --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java @@ -0,0 +1,296 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertTrue; + +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannelBuilder; +import io.grpc.alts.ComputeEngineChannelCredentials; +import io.grpc.netty.NettyChannelBuilder; +import java.util.ArrayList; +import java.util.logging.Logger; + +/** + * Test client that can be used to verify that XDS federation works. A list of + * server URIs (which can each be load balanced by different XDS servers), can + * be configured via flags. A separate thread is created for each of these clients + * and the configured test (either rpc_soak or channel_soak) is ran for each client + * on each thread. + */ +public final class XdsFederationTestClient { + private static final Logger logger = + Logger.getLogger(XdsFederationTestClient.class.getName()); + + /** + * Entry point. + */ + public static void main(String[] args) throws Exception { + final XdsFederationTestClient client = new XdsFederationTestClient(); + client.parseArgs(args); + Runtime.getRuntime() + .addShutdownHook( + new Thread() { + @Override + @SuppressWarnings("CatchAndPrintStackTrace") + public void run() { + System.out.println("Shutting down"); + try { + client.tearDown(); + } catch (RuntimeException e) { + e.printStackTrace(); + } + } + }); + client.setUp(); + try { + client.run(); + } finally { + client.tearDown(); + } + System.exit(0); + } + + private String serverUris = ""; + private String credentialsTypes = ""; + private int soakIterations = 10; + private int soakMaxFailures = 0; + private int soakPerIterationMaxAcceptableLatencyMs = 1000; + private int soakOverallTimeoutSeconds = 10; + private int soakMinTimeMsBetweenRpcs = 0; + private String testCase = "rpc_soak"; + private final ArrayList clients = new ArrayList<>(); + + private void parseArgs(String[] args) { + boolean usage = false; + for (String arg : args) { + if (!arg.startsWith("--")) { + System.err.println("All arguments must start with '--': " + arg); + usage = true; + break; + } + String[] parts = arg.substring(2).split("=", 2); + String key = parts[0]; + if (key.equals("help")) { + usage = true; + break; + } + if (parts.length != 2) { + System.err.println("All arguments must be of the form --arg=value"); + usage = true; + break; + } + String value = parts[1]; + switch (key) { + case "server_uris": + serverUris = value; + break; + case "credentials_types": + credentialsTypes = value; + break; + case "test_case": + testCase = value; + break; + case "soak_iterations": + soakIterations = Integer.parseInt(value); + break; + case "soak_max_failures": + soakMaxFailures = Integer.parseInt(value); + break; + case "soak_per_iteration_max_acceptable_latency_ms": + soakPerIterationMaxAcceptableLatencyMs = Integer.parseInt(value); + break; + case "soak_overall_timeout_seconds": + soakOverallTimeoutSeconds = Integer.parseInt(value); + break; + case "soak_min_time_ms_between_rpcs": + soakMinTimeMsBetweenRpcs = Integer.parseInt(value); + break; + default: + System.err.println("Unknown argument: " + key); + usage = true; + break; + } + } + if (usage) { + XdsFederationTestClient c = new XdsFederationTestClient(); + System.out.println( + "Usage: [ARGS...]" + + "\n" + + "\n --server_uris Comma separated list of server " + + "URIs to make RPCs to. Default: " + + c.serverUris + + "\n --credentials_types Comma-separated list of " + + "\n credentials, each entry is used " + + "\n for the server of the " + + "\n corresponding index in server_uris. " + + "\n Supported values: " + + "compute_engine_channel_creds,INSECURE_CREDENTIALS. Default: " + + c.credentialsTypes + + "\n --soak_iterations The number of iterations to use " + + "\n for the two tests: rpc_soak and " + + "\n channel_soak. Default: " + + c.soakIterations + + "\n --soak_max_failures The number of iterations in soak " + + "\n tests that are allowed to fail " + + "\n (either due to non-OK status code " + + "\n or exceeding the per-iteration max " + + "\n acceptable latency). Default: " + + c.soakMaxFailures + + "\n --soak_per_iteration_max_acceptable_latency_ms" + + "\n The number of milliseconds a " + + "\n single iteration in the two soak " + + "\n tests (rpc_soak and channel_soak) " + + "\n should take. Default: " + + c.soakPerIterationMaxAcceptableLatencyMs + + "\n --soak_overall_timeout_seconds The overall number of seconds " + + "\n after which a soak test should " + + "\n stop and fail, if the desired " + + "\n number of iterations have not yet " + + "\n completed. Default: " + + c.soakOverallTimeoutSeconds + + "\n --soak_min_time_ms_between_rpcs The minimum time in milliseconds " + + "\n between consecutive RPCs in a soak " + + "\n test (rpc_soak or channel_soak), " + + "\n useful for limiting QPS. Default: " + + c.soakMinTimeMsBetweenRpcs + + "\n --test_case=TEST_CASE Test case to run. Valid options are:" + + "\n rpc_soak: sends --soak_iterations large_unary RPCs" + + "\n channel_soak: sends --soak_iterations RPCs, rebuilding the channel " + + "each time." + + "\n Default: " + c.testCase + ); + System.exit(1); + } + } + + void setUp() { + String[] uris = serverUris.split(",", -1); + String[] creds = credentialsTypes.split(",", -1); + if (uris.length == 0) { + throw new IllegalArgumentException("--server_uris is empty"); + } + if (uris.length != creds.length) { + throw new IllegalArgumentException("Number of entries in --server_uris " + + "does not match number of entries in --credentials_types"); + } + for (int i = 0; i < uris.length; i++) { + clients.add(new InnerClient(creds[i], uris[i])); + } + for (InnerClient c : clients) { + c.setUp(); + } + } + + private synchronized void tearDown() { + for (InnerClient c : clients) { + c.tearDown(); + } + } + + /** + * Wraps a single client stub configuration and executes a + * soak test case with that configuration. + */ + class InnerClient extends AbstractInteropTest { + private final String credentialsType; + private final String serverUri; + private boolean runSucceeded = false; + + public InnerClient(String credentialsType, String serverUri) { + this.credentialsType = credentialsType; + this.serverUri = serverUri; + } + + /** + * Indicates whether run succeeded or not. This must only be called + * after run() has finished. + */ + public boolean runSucceeded() { + return runSucceeded; + } + + /** + * Run the intended soak test. + */ + public void run() { + boolean resetChannelPerIteration; + switch (testCase) { + case "rpc_soak": + resetChannelPerIteration = false; + break; + case "channel_soak": + resetChannelPerIteration = true; + break; + default: + throw new RuntimeException("invalid testcase: " + testCase); + } + try { + performSoakTest( + serverUri, + resetChannelPerIteration, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, + soakOverallTimeoutSeconds); + logger.info("Test case: " + testCase + " done for server: " + serverUri); + runSucceeded = true; + } catch (Exception e) { + logger.info("Test case: " + testCase + " failed for server: " + serverUri); + throw new RuntimeException(e); + } + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + ChannelCredentials channelCredentials; + switch (credentialsType) { + case "compute_engine_channel_creds": + channelCredentials = ComputeEngineChannelCredentials.create(); + break; + case "INSECURE_CREDENTIALS": + channelCredentials = InsecureChannelCredentials.create(); + break; + default: + throw new IllegalArgumentException("Unknown custom credentials: " + credentialsType); + } + return NettyChannelBuilder.forTarget(serverUri, channelCredentials) + .keepAliveTime(3600, SECONDS) + .keepAliveTimeout(20, SECONDS); + } + } + + private void run() throws Exception { + logger.info("Begin test case: " + testCase); + ArrayList threads = new ArrayList<>(); + for (InnerClient c : clients) { + Thread t = new Thread(c::run); + t.start(); + threads.add(t); + } + for (Thread t : threads) { + t.join(); + } + for (InnerClient c : clients) { + assertTrue(c.runSucceeded()); + } + logger.info("Test case: " + testCase + " done for all clients!"); + } +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java index 6b3e7213cfb..19a3c44259f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java @@ -59,9 +59,11 @@ public final class XdsTestServer { "succeed-on-retry-attempt-"; private static final String CALL_BEHAVIOR_ERROR_CODE = "error-code-"; + private static final String CALL_BEHAVIOR_HOSTNAME = "hostname="; private static final Splitter HEADER_VALUE_SPLITTER = Splitter.on(',') .trimResults() .omitEmptyStrings(); + private static final Splitter HEADER_HOSTNAME_SPLITTER = Splitter.on(' '); private static Logger logger = Logger.getLogger(XdsTestServer.class.getName()); @@ -300,8 +302,38 @@ public void sendHeaders(Metadata responseHeaders) { }; ServerCall.Listener noopListener = new ServerCall.Listener() {}; - // sleep if instructed by rpc-behavior + int attemptNum = 0; + String attemptNumHeader = requestHeaders.get(ATTEMPT_NUM); + if (attemptNumHeader != null) { + try { + attemptNum = Integer.valueOf(attemptNumHeader); + } catch (NumberFormatException e) { + newCall.close( + Status.INVALID_ARGUMENT.withDescription( + "Invalid format for grpc-previous-rpc-attempts header: " + attemptNumHeader), + new Metadata()); + return noopListener; + } + } + for (String callBehavior : callBehaviors) { + if (callBehavior.startsWith(CALL_BEHAVIOR_HOSTNAME)) { + List splitHeader = HEADER_HOSTNAME_SPLITTER.splitToList(callBehavior); + if (splitHeader.size() > 1) { + if (!splitHeader.get(0).substring(CALL_BEHAVIOR_HOSTNAME.length()).equals(host)) { + continue; + } + callBehavior = splitHeader.get(1); + } else { + newCall.close( + Status.INVALID_ARGUMENT.withDescription( + "Invalid format for rpc-behavior header: " + callBehavior), + new Metadata() + ); + return noopListener; + } + } + if (callBehavior.startsWith(CALL_BEHAVIOR_SLEEP_VALUE)) { try { int timeout = Integer.parseInt( @@ -310,7 +342,7 @@ public void sendHeaders(Metadata responseHeaders) { } catch (NumberFormatException e) { newCall.close( Status.INVALID_ARGUMENT.withDescription( - String.format("Invalid format for rpc-behavior header (%s)", callBehavior)), + "Invalid format for rpc-behavior header: " + callBehavior), new Metadata()); return noopListener; } catch (InterruptedException e) { @@ -321,51 +353,29 @@ public void sendHeaders(Metadata responseHeaders) { return noopListener; } } - } - // succeed the retry attempt if instructed by rpc-behavior - int succeedOnAttemptNum = Integer.MAX_VALUE; - for (String callBehavior : callBehaviors) { if (callBehavior.startsWith(CALL_BEHAVIOR_SUCCEED_ON_RETRY_ATTEMPT_VALUE)) { + int succeedOnAttemptNum = Integer.MAX_VALUE; try { succeedOnAttemptNum = Integer.parseInt( callBehavior.substring(CALL_BEHAVIOR_SUCCEED_ON_RETRY_ATTEMPT_VALUE.length())); } catch (NumberFormatException e) { newCall.close( Status.INVALID_ARGUMENT.withDescription( - String.format("Invalid format for rpc-behavior header (%s)", callBehavior)), + "Invalid format for rpc-behavior header: " + callBehavior), new Metadata()); return noopListener; } - break; + if (attemptNum == succeedOnAttemptNum) { + return next.startCall(newCall, requestHeaders); + } } - } - int attemptNum = 0; - String attemptNumHeader = requestHeaders.get(ATTEMPT_NUM); - if (attemptNumHeader != null) { - try { - attemptNum = Integer.valueOf(attemptNumHeader); - } catch (NumberFormatException e) { - newCall.close( - Status.INVALID_ARGUMENT.withDescription( - String.format( - "Invalid format for grpc-previous-rpc-attempts header (%s)", - attemptNumHeader)), - new Metadata()); + + // hang if instructed by rpc-behavior + if (callBehavior.equals(CALL_BEHAVIOR_KEEP_OPEN_VALUE)) { return noopListener; } - } - if (attemptNum == succeedOnAttemptNum) { - return next.startCall(newCall, requestHeaders); - } - - // hang if instructed by rpc-behavior - if (callBehaviors.contains(CALL_BEHAVIOR_KEEP_OPEN_VALUE)) { - return noopListener; - } - // fail if instructed by rpc-behavior - for (String callBehavior : callBehaviors) { if (callBehavior.startsWith(CALL_BEHAVIOR_ERROR_CODE)) { try { int codeValue = Integer.valueOf( @@ -378,7 +388,7 @@ public void sendHeaders(Metadata responseHeaders) { } catch (NumberFormatException e) { newCall.close( Status.INVALID_ARGUMENT.withDescription( - String.format("Invalid format for rpc-behavior header (%s)", callBehavior)), + "Invalid format for rpc-behavior header: " + callBehavior), new Metadata()); return noopListener; } diff --git a/interop-testing/src/main/proto/grpc/testing/empty.proto b/interop-testing/src/main/proto/grpc/testing/empty.proto index bd626abe522..43779012dc4 100644 --- a/interop-testing/src/main/proto/grpc/testing/empty.proto +++ b/interop-testing/src/main/proto/grpc/testing/empty.proto @@ -1,4 +1,5 @@ -// Copyright 2015 The gRPC Authors + +// Copyright 2015 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,7 +12,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -syntax = "proto2"; + +syntax = "proto3"; package grpc.testing; diff --git a/interop-testing/src/main/proto/grpc/testing/messages.proto b/interop-testing/src/main/proto/grpc/testing/messages.proto index 9ac53cd89e1..fbcb6b4ce9b 100644 --- a/interop-testing/src/main/proto/grpc/testing/messages.proto +++ b/interop-testing/src/main/proto/grpc/testing/messages.proto @@ -1,4 +1,5 @@ -// Copyright 2015 The gRPC Authors + +// Copyright 2015-2016 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,16 +21,24 @@ package grpc.testing; option java_package = "io.grpc.testing.integration"; -// TODO(jihuncho): Use well-known types once all languages are synced. +// TODO(dgq): Go back to using well-known types once +// https://github.com/grpc/grpc/issues/6980 has been fixed. +// import "google/protobuf/wrappers.proto"; message BoolValue { // The bool value. bool value = 1; } +// The type of payload that should be returned. +enum PayloadType { + // Compressable text format. + COMPRESSABLE = 0; +} + // A block of data, to simply increase gRPC message size. message Payload { - reserved 1; - + // The type of data in body. + PayloadType type = 1; // Primary contents of payload. bytes body = 2; } @@ -58,7 +67,9 @@ enum GrpclbRouteType { // Unary request. message SimpleRequest { - reserved 1; + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + PayloadType response_type = 1; // Desired payload size in the response from the server. int32 response_size = 2; @@ -89,6 +100,9 @@ message SimpleRequest { // Whether SimpleResponse should include grpclb_route_type. bool fill_grpclb_route_type = 10; + + // If set the server should record this metrics report data for the current RPC. + TestOrcaReport orca_per_query_report = 11; } // Unary response, as configured by the request. @@ -106,14 +120,11 @@ message SimpleResponse { string server_id = 4; // gRPCLB Path. GrpclbRouteType grpclb_route_type = 5; + // Server hostname. string hostname = 6; } -message SimpleContext { - string value = 1; -} - // Client-streaming request. message StreamingInputCallRequest { // Optional input payload sent along with the request. @@ -152,7 +163,11 @@ message ResponseParameters { // Server-streaming request. message StreamingOutputCallRequest { - reserved 1; + // Desired payload type in the response from the server. + // If response_type is RANDOM, the payload from each response in the stream + // might be of different types. This is to simulate a mixed type of payload + // stream. + PayloadType response_type = 1; // Configuration for each expected response message. repeated ResponseParameters response_parameters = 2; @@ -162,6 +177,9 @@ message StreamingOutputCallRequest { // Whether server should return a given status EchoStatus response_status = 7; + + // If set the server should update this metrics report data at the OOB server. + TestOrcaReport orca_oob_report = 8; } // Server-streaming response, as configured by the request and parameters. @@ -258,3 +276,13 @@ message ClientConfigureRequest { // Response for updating a test client's configuration. message ClientConfigureResponse {} + +// Metrics data the server will update and send to the client. It mirrors orca load report +// https://github.com/cncf/xds/blob/eded343319d09f30032952beda9840bbd3dcf7ac/xds/data/orca/v3/orca_load_report.proto#L15, +// but avoids orca dependency. Used by both per-query and out-of-band reporting tests. +message TestOrcaReport { + double cpu_utilization = 1; + double memory_utilization = 2; + map request_cost = 3; + map utilization = 4; +} diff --git a/interop-testing/src/main/proto/grpc/testing/test.proto b/interop-testing/src/main/proto/grpc/testing/test.proto index f8b0927cc98..efe83b88261 100644 --- a/interop-testing/src/main/proto/grpc/testing/test.proto +++ b/interop-testing/src/main/proto/grpc/testing/test.proto @@ -1,4 +1,5 @@ -// Copyright 2015 The gRPC Authors + +// Copyright 2015-2016 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,8 +12,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + // An integration test service that covers all the method signature permutations // of unary/streaming requests/responses. + syntax = "proto3"; import "grpc/testing/empty.proto"; @@ -73,7 +76,7 @@ service UnimplementedService { // A service used to control reconnect server. service ReconnectService { - rpc Start(grpc.testing.Empty) returns (grpc.testing.Empty); + rpc Start(grpc.testing.ReconnectParams) returns (grpc.testing.Empty); rpc Stop(grpc.testing.Empty) returns (grpc.testing.ReconnectInfo); } @@ -81,11 +84,11 @@ service ReconnectService { service LoadBalancerStatsService { // Gets the backend distribution for RPCs sent by a test client. rpc GetClientStats(LoadBalancerStatsRequest) - returns (LoadBalancerStatsResponse) {} + returns (LoadBalancerStatsResponse) {} // Gets the accumulated stats for RPCs sent by a test client. rpc GetClientAccumulatedStats(LoadBalancerAccumulatedStatsRequest) - returns (LoadBalancerAccumulatedStatsResponse) {} + returns (LoadBalancerAccumulatedStatsResponse) {} } // A service to remotely control health status of an xDS test server. diff --git a/interop-testing/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/interop-testing/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider new file mode 100644 index 00000000000..3a60a58e533 --- /dev/null +++ b/interop-testing/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -0,0 +1 @@ +io.grpc.testing.integration.RpcBehaviorLoadBalancerProvider diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java index b566a8c888f..fab130a4543 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ConcurrencyTest.java @@ -16,8 +16,10 @@ package io.grpc.testing.integration; -import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.ManagedChannel; @@ -32,7 +34,8 @@ import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import java.io.File; import java.io.IOException; -import java.util.concurrent.CountDownLatch; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -40,9 +43,7 @@ import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -58,28 +59,29 @@ @RunWith(JUnit4.class) public class ConcurrencyTest { - @Rule public final Timeout globalTimeout = Timeout.seconds(60); - /** - * A response observer that signals a {@code CountDownLatch} when the proper number of responses - * arrives and the server signals that the RPC is complete. + * A response observer that completes a {@code ListenableFuture} when the proper number of + * responses arrives and the server signals that the RPC is complete. */ private static class SignalingResponseObserver implements StreamObserver { - public SignalingResponseObserver(CountDownLatch responsesDoneSignal) { - this.responsesDoneSignal = responsesDoneSignal; + public SignalingResponseObserver(SettableFuture completionFuture) { + this.completionFuture = completionFuture; } @Override public void onCompleted() { - Preconditions.checkState(numResponsesReceived == NUM_RESPONSES_PER_REQUEST); - responsesDoneSignal.countDown(); + if (numResponsesReceived != NUM_RESPONSES_PER_REQUEST) { + completionFuture.setException( + new IllegalStateException("Wrong number of responses: " + numResponsesReceived)); + } else { + completionFuture.set(null); + } } @Override public void onError(Throwable error) { - // This should never happen. If it does happen, ensure that the error is visible. - error.printStackTrace(); + completionFuture.setException(error); } @Override @@ -87,19 +89,19 @@ public void onNext(StreamingOutputCallResponse response) { numResponsesReceived++; } - private final CountDownLatch responsesDoneSignal; + private final SettableFuture completionFuture; private int numResponsesReceived = 0; } /** * A client worker task that waits until all client workers are ready, then sends a request for a - * server-streaming RPC and arranges for a {@code CountDownLatch} to be signaled when the RPC is + * server-streaming RPC and arranges for a {@code ListenableFuture} to be signaled when the RPC is * complete. */ private class ClientWorker implements Runnable { - public ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal) { + public ClientWorker(CyclicBarrier startBarrier, SettableFuture completionFuture) { this.startBarrier = startBarrier; - this.responsesDoneSignal = responsesDoneSignal; + this.completionFuture = completionFuture; } @Override @@ -117,14 +119,17 @@ public void run() { // Wait until all client worker threads are poised & ready, then send the request. This way // all clients send their requests at approximately the same time. startBarrier.await(); - clientStub.streamingOutputCall(request, new SignalingResponseObserver(responsesDoneSignal)); - } catch (Exception e) { - throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e); + clientStub.streamingOutputCall(request, new SignalingResponseObserver(completionFuture)); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + completionFuture.setException(ex); + } catch (Throwable t) { + completionFuture.setException(t); } } private final CyclicBarrier startBarrier; - private final CountDownLatch responsesDoneSignal; + private final SettableFuture completionFuture; } private static final int NUM_SERVER_THREADS = 10; @@ -168,14 +173,15 @@ public void tearDown() { @Test public void serverStreamingTest() throws Exception { CyclicBarrier startBarrier = new CyclicBarrier(NUM_CONCURRENT_REQUESTS); - CountDownLatch responsesDoneSignal = new CountDownLatch(NUM_CONCURRENT_REQUESTS); + List> workerFutures = new ArrayList<>(NUM_CONCURRENT_REQUESTS); for (int i = 0; i < NUM_CONCURRENT_REQUESTS; i++) { - clientExecutor.execute(new ClientWorker(startBarrier, responsesDoneSignal)); + SettableFuture future = SettableFuture.create(); + clientExecutor.execute(new ClientWorker(startBarrier, future)); + workerFutures.add(future); } - // Wait until the clients all receive their complete RPC response streams. - responsesDoneSignal.await(); + Futures.allAsList(workerFutures).get(60, TimeUnit.SECONDS); } /** diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java index aead88e3135..2c741a9420d 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java @@ -16,11 +16,7 @@ package io.grpc.testing.integration; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; - import io.grpc.ChannelCredentials; -import io.grpc.Metadata; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.TlsChannelCredentials; @@ -30,9 +26,7 @@ import io.grpc.netty.InternalNettyServerBuilder; import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyServerBuilder; -import io.grpc.stub.MetadataUtils; import java.io.IOException; -import java.net.InetAddress; import java.net.InetSocketAddress; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,36 +78,8 @@ protected NettyChannelBuilder createChannelBuilder() { } } - @Test - public void remoteAddr() { - InetSocketAddress isa = (InetSocketAddress) obtainRemoteClientAddr(); - assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); - // It should not be the same as the server - assertNotEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); - } - - @Test - public void localAddr() throws Exception { - InetSocketAddress isa = (InetSocketAddress) obtainLocalServerAddr(); - assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); - assertEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); - } - @Test public void tlsInfo() { assertX500SubjectDn("CN=testclient, O=Internet Widgits Pty Ltd, ST=Some-State, C=AU"); } - - @Test - public void contentLengthPermitted() throws Exception { - // Some third-party gRPC implementations (e.g., ServiceTalk) include Content-Length. The HTTP/2 - // code starting in Netty 4.1.60.Final has special-cased handling of Content-Length, and may - // call uncommon methods on our custom headers implementation. - // https://github.com/grpc/grpc-java/issues/7953 - Metadata contentLength = new Metadata(); - contentLength.put(Metadata.Key.of("content-length", Metadata.ASCII_STRING_MARSHALLER), "5"); - blockingStub - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(contentLength)) - .emptyCall(EMPTY); - } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java new file mode 100644 index 00000000000..198836bf552 --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import io.grpc.ChannelCredentials; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.testing.TestUtils; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.InternalNettyServerBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.okhttp.InternalOkHttpChannelBuilder; +import io.grpc.okhttp.InternalOkHttpServerBuilder; +import io.grpc.okhttp.OkHttpChannelBuilder; +import io.grpc.okhttp.OkHttpServerBuilder; +import io.grpc.stub.MetadataUtils; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Arrays; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Integration tests for GRPC over the various HTTP2 transports. + */ +@RunWith(Parameterized.class) +public class Http2Test extends AbstractInteropTest { + @BeforeClass + public static void loadConscrypt() throws Exception { + // Load conscrypt if it is available. Either Conscrypt or Jetty ALPN needs to be available for + // OkHttp to negotiate. + TestUtils.installConscryptIfAvailable(); + } + + enum Transport { + NETTY, OKHTTP; + } + + /** Parameterized test cases. */ + @Parameters(name = "client={0},server={1}") + public static Iterable data() { + return Arrays.asList(new Object[][] { + {Transport.NETTY, Transport.NETTY}, + {Transport.OKHTTP, Transport.OKHTTP}, + {Transport.OKHTTP, Transport.NETTY}, + {Transport.NETTY, Transport.OKHTTP}, + }); + } + + private final Transport clientType; + private final Transport serverType; + + public Http2Test(Transport clientType, Transport serverType) { + this.clientType = clientType; + this.serverType = serverType; + } + + @Override + protected ServerBuilder getServerBuilder() { + // Starts the server with HTTPS. + ServerCredentials serverCreds; + try { + serverCreds = TlsServerCredentials.create( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + ServerBuilder builder; + if (serverType == Transport.NETTY) { + NettyServerBuilder nettyBuilder = NettyServerBuilder.forPort(0, serverCreds) + .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW); + // Disable the default census stats tracer, use testing tracer instead. + InternalNettyServerBuilder.setStatsEnabled(nettyBuilder, false); + builder = nettyBuilder; + } else { + OkHttpServerBuilder okHttpBuilder = OkHttpServerBuilder.forPort(0, serverCreds) + .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW); + // Disable the default census stats tracer, use testing tracer instead. + InternalOkHttpServerBuilder.setStatsEnabled(okHttpBuilder, false); + builder = okHttpBuilder; + } + return builder + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .addStreamTracerFactory(createCustomCensusTracerFactory()); + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + ChannelCredentials channelCreds; + try { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(TestUtils.loadCert("ca.pem")) + .build(); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + int port = ((InetSocketAddress) getListenAddress()).getPort(); + ManagedChannelBuilder builder; + if (clientType == Transport.NETTY) { + NettyChannelBuilder nettyBuilder = NettyChannelBuilder + .forAddress("localhost", port, channelCreds) + .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW); + // Disable the default census stats interceptor, use testing interceptor instead. + InternalNettyChannelBuilder.setStatsEnabled(nettyBuilder, false); + builder = nettyBuilder; + } else { + OkHttpChannelBuilder okHttpBuilder = OkHttpChannelBuilder + .forAddress("localhost", port, channelCreds) + .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW); + // Disable the default census stats interceptor, use testing interceptor instead. + InternalOkHttpChannelBuilder.setStatsEnabled(okHttpBuilder, false); + builder = okHttpBuilder; + } + return builder + .overrideAuthority(TestUtils.TEST_SERVER_HOST) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .intercept(createCensusStatsClientInterceptor()); + } + + @Test + public void remoteAddr() { + InetSocketAddress isa = (InetSocketAddress) obtainRemoteClientAddr(); + assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); + // It should not be the same as the server + assertNotEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); + } + + @Test + public void localAddr() throws Exception { + InetSocketAddress isa = (InetSocketAddress) obtainLocalServerAddr(); + assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); + assertEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); + } + + @Test + public void contentLengthPermitted() throws Exception { + // Some third-party gRPC implementations (e.g., ServiceTalk) include Content-Length. The HTTP/2 + // code starting in Netty 4.1.60.Final has special-cased handling of Content-Length, and may + // call uncommon methods on our custom headers implementation. + // https://github.com/grpc/grpc-java/issues/7953 + Metadata contentLength = new Metadata(); + contentLength.put(Metadata.Key.of("content-length", Metadata.ASCII_STRING_MARSHALLER), "5"); + blockingStub + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(contentLength)) + .emptyCall(EMPTY); + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/NettyFlowControlTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/NettyFlowControlTest.java index c86bd8070a0..c94e95704ac 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/NettyFlowControlTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/NettyFlowControlTest.java @@ -147,10 +147,11 @@ private void doTest(int bandwidth, int latency) throws InterruptedException { // deal with cases that either don't cause a window update or hit max window expectedWindow = Math.min(MAX_WINDOW, Math.max(expectedWindow, REGULAR_WINDOW)); - // Range looks large, but this allows for only one extra/missed window update + // Range looks large, but this allows for only one extra/missed window update plus + // bdpPing variations. // (one extra update causes a 2x difference and one missed update causes a .5x difference) assertTrue("Window was " + lastWindow + " expecting " + expectedWindow, - lastWindow < 2 * expectedWindow); + lastWindow < 2.2 * expectedWindow); assertTrue("Window was " + lastWindow + " expecting " + expectedWindow, expectedWindow < 2 * lastWindow); } @@ -194,6 +195,7 @@ private static class TestStreamObserver implements StreamObserver grpcHandlerRef, long window) { @@ -206,9 +208,18 @@ public TestStreamObserver( public void onNext(StreamingOutputCallResponse value) { GrpcHttp2ConnectionHandler grpcHandler = grpcHandlerRef.get(); Http2Stream connectionStream = grpcHandler.connection().connectionStream(); - lastWindow = grpcHandler.decoder().flowController().initialWindowSize(connectionStream); - if (lastWindow >= expectedWindow) { - onCompleted(); + int curWindow = grpcHandler.decoder().flowController().initialWindowSize(connectionStream); + synchronized (this) { + if (curWindow >= expectedWindow) { + if (wasCompleted) { + return; + } + wasCompleted = true; + lastWindow = curWindow; + onCompleted(); + } else if (!wasCompleted) { + lastWindow = curWindow; + } } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index 229d873571a..72ed8bf975b 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -209,7 +209,7 @@ private void elapseBackoff(long time, TimeUnit unit) throws Exception { private void assertRpcStartedRecorded() throws Exception { MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); - assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)) + assertThat(record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_STARTED_RPCS)) .isEqualTo(1); } @@ -249,9 +249,9 @@ private void assertRpcStatusRecorded( assertThat(statusTag.asString()).isEqualTo(code.toString()); assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)) .isEqualTo(1); - assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)) + assertThat(record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_ROUNDTRIP_LATENCY)) .isEqualTo(roundtripLatencyMs); - assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)) + assertThat(record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_RPC)) .isEqualTo(outboundMessages); } @@ -276,7 +276,7 @@ public void retryUntilBufferLimitExceeded() throws Exception { .put("maxBackoff", "10s") .put("backoffMultiplier", 1D) .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) - .build(); + .buildOrThrow(); createNewChannel(); ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); @@ -300,7 +300,7 @@ public void retryUntilBufferLimitExceeded() throws Exception { Status.UNAVAILABLE.withDescription("2nd attempt failed"), new Metadata()); // no more retry - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); } @@ -314,7 +314,7 @@ public void statsRecorded() throws Exception { .put("maxBackoff", "10s") .put("backoffMultiplier", 1D) .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) - .build(); + .buildOrThrow(); createNewChannel(); ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); @@ -347,12 +347,12 @@ public void statsRecorded() throws Exception { fakeClock.forwardTime(2, SECONDS); serverCall.sendHeaders(new Metadata()); serverCall.sendMessage(3); + serverCall.close(Status.OK, new Metadata()); call.request(1); assertInboundMessageRecorded(); assertInboundWireSizeRecorded(1); - serverCall.close(Status.OK, new Metadata()); - assertRpcStatusRecorded(Status.Code.OK, 2000, 2); - assertRetryStatsRecorded(1, 0, 10_000); + assertRpcStatusRecorded(Status.Code.OK, 12000, 2); + assertRetryStatsRecorded(1, 0, 0); } @Test @@ -364,7 +364,7 @@ public void statsRecorde_callCancelledBeforeCommit() throws Exception { .put("maxBackoff", "10s") .put("backoffMultiplier", 1D) .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) - .build(); + .buildOrThrow(); createNewChannel(); // We will have streamClosed return at a particular moment that we want. @@ -410,13 +410,14 @@ public void streamClosed(Status status) { serverCall.request(2); assertOutboundWireSizeRecorded(message.length()); fakeClock.forwardTime(7, SECONDS); - call.cancel("Cancelled before commit", null); // A noop substream will commit. - // The call listener is closed, but the netty substream listener is not yet closed. - verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); + // A noop substream will commit. But call is not yet closed. + call.cancel("Cancelled before commit", null); // Let the netty substream listener be closed. streamClosedLatch.countDown(); - assertRetryStatsRecorded(1, 0, 10_000); - assertRpcStatusRecorded(Code.CANCELLED, 7_000, 1); + // The call listener is closed. + verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); + assertRpcStatusRecorded(Code.CANCELLED, 17_000, 1); + assertRetryStatsRecorded(1, 0, 0); } @Test diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java new file mode 100644 index 00000000000..e19208b8883 --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.testing.integration.RpcBehaviorLoadBalancerProvider.RpcBehaviorConfig; +import io.grpc.testing.integration.RpcBehaviorLoadBalancerProvider.RpcBehaviorHelper; +import io.grpc.testing.integration.RpcBehaviorLoadBalancerProvider.RpcBehaviorLoadBalancer; +import io.grpc.testing.integration.RpcBehaviorLoadBalancerProvider.RpcBehaviorPicker; +import java.net.SocketAddress; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Unit tests for {@link RpcBehaviorLoadBalancerProvider}. + */ +@RunWith(JUnit4.class) +public class RpcBehaviorLoadBalancerProviderTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private LoadBalancer mockDelegateLb; + + @Mock + private Helper mockHelper; + + @Mock + private SubchannelPicker mockPicker; + + @Test + public void parseValidConfig() { + assertThat(buildConfig().rpcBehavior).isEqualTo("error-code-15"); + } + + @Test + public void parseInvalidConfig() { + Status status = new RpcBehaviorLoadBalancerProvider().parseLoadBalancingPolicyConfig( + ImmutableMap.of("foo", "bar")).getError(); + assertThat(status.getDescription()).contains("rpcBehavior"); + } + + @Test + public void handleResolvedAddressesDelegated() { + RpcBehaviorLoadBalancer lb = new RpcBehaviorLoadBalancer(new RpcBehaviorHelper(mockHelper), + mockDelegateLb); + ResolvedAddresses resolvedAddresses = buildResolvedAddresses(buildConfig()); + lb.handleResolvedAddresses(resolvedAddresses); + verify(mockDelegateLb).handleResolvedAddresses(resolvedAddresses); + } + + @Test + public void helperWrapsPicker() { + RpcBehaviorHelper helper = new RpcBehaviorHelper(mockHelper); + helper.setRpcBehavior("error-code-15"); + helper.updateBalancingState(ConnectivityState.READY, mockPicker); + + verify(mockHelper).updateBalancingState(eq(ConnectivityState.READY), + isA(RpcBehaviorPicker.class)); + } + + @Test + public void pickerAddsRpcBehaviorMetadata() { + PickSubchannelArgsImpl args = new PickSubchannelArgsImpl(TestMethodDescriptors.voidMethod(), + new Metadata(), CallOptions.DEFAULT); + new RpcBehaviorPicker(mockPicker, "error-code-15").pickSubchannel(args); + + assertThat(args.getHeaders() + .get(Metadata.Key.of("rpc-behavior", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo( + "error-code-15"); + } + + private RpcBehaviorConfig buildConfig() { + RpcBehaviorConfig config = (RpcBehaviorConfig) new RpcBehaviorLoadBalancerProvider() + .parseLoadBalancingPolicyConfig( + ImmutableMap.of("rpcBehavior", "error-code-15")).getConfig(); + return config; + } + + private ResolvedAddresses buildResolvedAddresses(RpcBehaviorConfig config) { + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setLoadBalancingPolicyConfig(config) + .setAddresses(ImmutableList.of( + new EquivalentAddressGroup(new SocketAddress() { + }))) + .setAttributes(Attributes.newBuilder().build()).build(); + return resolvedAddresses; + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java index 14a98514918..ab32d584e7c 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java @@ -65,7 +65,9 @@ public void testCaseNamesShouldMapToEnums() { "unimplemented_service", "cancel_after_begin", "cancel_after_first_response", - "timeout_on_sleeping_server" + "timeout_on_sleeping_server", + "orca_per_rpc", + "orca_oob" }; // additional test cases diff --git a/istio-interop-testing/build.gradle b/istio-interop-testing/build.gradle new file mode 100644 index 00000000000..7ba342c6884 --- /dev/null +++ b/istio-interop-testing/build.gradle @@ -0,0 +1,72 @@ +plugins { + id "application" + id "java" + + id "com.google.protobuf" + id 'com.google.cloud.tools.jib' + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: Istio Interop testing" + +configurations { + alpnagent +} + +evaluationDependsOn(project(':grpc-context').path) + +dependencies { + implementation project(':grpc-core'), + project(':grpc-netty'), + project(':grpc-protobuf'), + project(':grpc-services'), + project(':grpc-stub'), + project(':grpc-testing'), + project(':grpc-xds') + + compileOnly libraries.javax.annotation + + runtimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testImplementation project(':grpc-context').sourceSets.test.output, + project(':grpc-api').sourceSets.test.output, + project(':grpc-core').sourceSets.test.output, + libraries.mockito.core, + libraries.junit, + libraries.truth + alpnagent libraries.jetty.alpn.agent + + signature libraries.signature.java +} + +sourceSets { + main { + proto { + srcDir 'third_party/istio/src/main/proto' + } + } +} + +configureProtoCompilation() + +import net.ltgt.gradle.errorprone.CheckSeverity + +tasks.named("compileJava").configure { + // This isn't a library; it can use beta APIs + options.errorprone.check("BetaApi", CheckSeverity.OFF) +} + + +// For releasing to Docker Hub +jib { + from.image = "gcr.io/distroless/java:8" + container { + ports = ['50051'] + mainClass="io.grpc.testing.istio.EchoTestServer" + } + outputPaths { + tar = 'build/istio-echo-server.tar' + digest = 'build/istio-echo-server.digest' + imageId = 'build/istio-echo-server.id' + } +} diff --git a/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java new file mode 100644 index 00000000000..1c71469e01a --- /dev/null +++ b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java @@ -0,0 +1,350 @@ +package io.istio.test; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + */ +@javax.annotation.Generated( + value = "by gRPC proto compiler", + comments = "Source: test/echo/proto/echo.proto") +@io.grpc.stub.annotations.GrpcGenerated +public final class EchoTestServiceGrpc { + + private EchoTestServiceGrpc() {} + + public static final String SERVICE_NAME = "proto.EchoTestService"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getEchoMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "Echo", + requestType = io.istio.test.Echo.EchoRequest.class, + responseType = io.istio.test.Echo.EchoResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.UNARY) + public static io.grpc.MethodDescriptor getEchoMethod() { + io.grpc.MethodDescriptor getEchoMethod; + if ((getEchoMethod = EchoTestServiceGrpc.getEchoMethod) == null) { + synchronized (EchoTestServiceGrpc.class) { + if ((getEchoMethod = EchoTestServiceGrpc.getEchoMethod) == null) { + EchoTestServiceGrpc.getEchoMethod = getEchoMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Echo")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.istio.test.Echo.EchoRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.istio.test.Echo.EchoResponse.getDefaultInstance())) + .setSchemaDescriptor(new EchoTestServiceMethodDescriptorSupplier("Echo")) + .build(); + } + } + } + return getEchoMethod; + } + + private static volatile io.grpc.MethodDescriptor getForwardEchoMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "ForwardEcho", + requestType = io.istio.test.Echo.ForwardEchoRequest.class, + responseType = io.istio.test.Echo.ForwardEchoResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.UNARY) + public static io.grpc.MethodDescriptor getForwardEchoMethod() { + io.grpc.MethodDescriptor getForwardEchoMethod; + if ((getForwardEchoMethod = EchoTestServiceGrpc.getForwardEchoMethod) == null) { + synchronized (EchoTestServiceGrpc.class) { + if ((getForwardEchoMethod = EchoTestServiceGrpc.getForwardEchoMethod) == null) { + EchoTestServiceGrpc.getForwardEchoMethod = getForwardEchoMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "ForwardEcho")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.istio.test.Echo.ForwardEchoRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.istio.test.Echo.ForwardEchoResponse.getDefaultInstance())) + .setSchemaDescriptor(new EchoTestServiceMethodDescriptorSupplier("ForwardEcho")) + .build(); + } + } + } + return getForwardEchoMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static EchoTestServiceStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public EchoTestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceStub(channel, callOptions); + } + }; + return EchoTestServiceStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static EchoTestServiceBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public EchoTestServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingStub(channel, callOptions); + } + }; + return EchoTestServiceBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static EchoTestServiceFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public EchoTestServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceFutureStub(channel, callOptions); + } + }; + return EchoTestServiceFutureStub.newStub(factory, channel); + } + + /** + */ + public static abstract class EchoTestServiceImplBase implements io.grpc.BindableService { + + /** + */ + public void echo(io.istio.test.Echo.EchoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getEchoMethod(), responseObserver); + } + + /** + */ + public void forwardEcho(io.istio.test.Echo.ForwardEchoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getForwardEchoMethod(), responseObserver); + } + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getEchoMethod(), + io.grpc.stub.ServerCalls.asyncUnaryCall( + new MethodHandlers< + io.istio.test.Echo.EchoRequest, + io.istio.test.Echo.EchoResponse>( + this, METHODID_ECHO))) + .addMethod( + getForwardEchoMethod(), + io.grpc.stub.ServerCalls.asyncUnaryCall( + new MethodHandlers< + io.istio.test.Echo.ForwardEchoRequest, + io.istio.test.Echo.ForwardEchoResponse>( + this, METHODID_FORWARD_ECHO))) + .build(); + } + } + + /** + */ + public static final class EchoTestServiceStub extends io.grpc.stub.AbstractAsyncStub { + private EchoTestServiceStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected EchoTestServiceStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceStub(channel, callOptions); + } + + /** + */ + public void echo(io.istio.test.Echo.EchoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncUnaryCall( + getChannel().newCall(getEchoMethod(), getCallOptions()), request, responseObserver); + } + + /** + */ + public void forwardEcho(io.istio.test.Echo.ForwardEchoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncUnaryCall( + getChannel().newCall(getForwardEchoMethod(), getCallOptions()), request, responseObserver); + } + } + + /** + */ + public static final class EchoTestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { + private EchoTestServiceBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected EchoTestServiceBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingStub(channel, callOptions); + } + + /** + */ + public io.istio.test.Echo.EchoResponse echo(io.istio.test.Echo.EchoRequest request) { + return io.grpc.stub.ClientCalls.blockingUnaryCall( + getChannel(), getEchoMethod(), getCallOptions(), request); + } + + /** + */ + public io.istio.test.Echo.ForwardEchoResponse forwardEcho(io.istio.test.Echo.ForwardEchoRequest request) { + return io.grpc.stub.ClientCalls.blockingUnaryCall( + getChannel(), getForwardEchoMethod(), getCallOptions(), request); + } + } + + /** + */ + public static final class EchoTestServiceFutureStub extends io.grpc.stub.AbstractFutureStub { + private EchoTestServiceFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected EchoTestServiceFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceFutureStub(channel, callOptions); + } + + /** + */ + public com.google.common.util.concurrent.ListenableFuture echo( + io.istio.test.Echo.EchoRequest request) { + return io.grpc.stub.ClientCalls.futureUnaryCall( + getChannel().newCall(getEchoMethod(), getCallOptions()), request); + } + + /** + */ + public com.google.common.util.concurrent.ListenableFuture forwardEcho( + io.istio.test.Echo.ForwardEchoRequest request) { + return io.grpc.stub.ClientCalls.futureUnaryCall( + getChannel().newCall(getForwardEchoMethod(), getCallOptions()), request); + } + } + + private static final int METHODID_ECHO = 0; + private static final int METHODID_FORWARD_ECHO = 1; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final EchoTestServiceImplBase serviceImpl; + private final int methodId; + + MethodHandlers(EchoTestServiceImplBase serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_ECHO: + serviceImpl.echo((io.istio.test.Echo.EchoRequest) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; + case METHODID_FORWARD_ECHO: + serviceImpl.forwardEcho((io.istio.test.Echo.ForwardEchoRequest) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + } + + private static abstract class EchoTestServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + EchoTestServiceBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.istio.test.Echo.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("EchoTestService"); + } + } + + private static final class EchoTestServiceFileDescriptorSupplier + extends EchoTestServiceBaseDescriptorSupplier { + EchoTestServiceFileDescriptorSupplier() {} + } + + private static final class EchoTestServiceMethodDescriptorSupplier + extends EchoTestServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final String methodName; + + EchoTestServiceMethodDescriptorSupplier(String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (EchoTestServiceGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new EchoTestServiceFileDescriptorSupplier()) + .addMethod(getEchoMethod()) + .addMethod(getForwardEchoMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/istio-interop-testing/src/main/java/io/grpc/testing/istio/EchoTestServer.java b/istio-interop-testing/src/main/java/io/grpc/testing/istio/EchoTestServer.java new file mode 100644 index 00000000000..ae6f60098ac --- /dev/null +++ b/istio-interop-testing/src/main/java/io/grpc/testing/istio/EchoTestServer.java @@ -0,0 +1,506 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.istio; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.CharMatcher; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ChannelCredentials; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerCredentials; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsServerCredentials; +import io.grpc.services.AdminInterface; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.XdsChannelCredentials; +import io.grpc.xds.XdsServerCredentials; +import io.istio.test.Echo.EchoRequest; +import io.istio.test.Echo.EchoResponse; +import io.istio.test.Echo.ForwardEchoRequest; +import io.istio.test.Echo.ForwardEchoResponse; +import io.istio.test.Echo.Header; +import io.istio.test.EchoTestServiceGrpc; +import io.istio.test.EchoTestServiceGrpc.EchoTestServiceFutureStub; +import io.istio.test.EchoTestServiceGrpc.EchoTestServiceImplBase; +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * This class implements the Istio echo server functionality similar to + * https://github.com/istio/istio/blob/master/pkg/test/echo/server/endpoint/grpc.go . + * Please see Istio framework docs https://github.com/istio/istio/wiki/Istio-Test-Framework . + */ +public final class EchoTestServer { + + private static final Logger logger = Logger.getLogger(EchoTestServer.class.getName()); + + static final Context.Key CLIENT_ADDRESS_CONTEXT_KEY = + Context.key("io.grpc.testing.istio.ClientAddress"); + static final Context.Key AUTHORITY_CONTEXT_KEY = + Context.key("io.grpc.testing.istio.Authority"); + static final Context.Key> REQUEST_HEADERS_CONTEXT_KEY = + Context.key("io.grpc.testing.istio.RequestHeaders"); + + private static final String REQUEST_ID = "x-request-id"; + private static final String STATUS_CODE = "StatusCode"; + private static final String HOST = "Host"; + private static final String HOSTNAME = "Hostname"; + private static final String REQUEST_HEADER = "RequestHeader"; + private static final String IP = "IP"; + + @VisibleForTesting List servers; + + /** + * Preprocess args, for: + * - merging duplicate flags. So "--grpc=8080 --grpc=9090" becomes + * "--grpc=8080,9090". + **/ + @VisibleForTesting + static Map> preprocessArgs(String[] args) { + HashMap> argsMap = new HashMap<>(); + for (String arg : args) { + List keyValue = Splitter.on('=').limit(2).splitToList(arg); + + if (keyValue.size() == 2) { + String key = keyValue.get(0); + String value = keyValue.get(1); + List oldValue = argsMap.get(key); + if (oldValue == null) { + oldValue = new ArrayList<>(); + } + oldValue.add(value); + argsMap.put(key, oldValue); + } + } + return ImmutableMap.>builder().putAll(argsMap).build(); + } + + /** Turn ports from a string list to an int list. */ + @VisibleForTesting + static Set getPorts(Map> args, String flagName) { + List grpcPorts = args.get(flagName); + Set grpcPortsInt = new HashSet<>(grpcPorts.size()); + + for (String port : grpcPorts) { + port = CharMatcher.is('\"').trimFrom(port); + grpcPortsInt.add(Integer.parseInt(port)); + } + return grpcPortsInt; + } + + private static String determineHostname() { + try { + return InetAddress.getLocalHost().getHostName(); + } catch (IOException ex) { + logger.log(Level.INFO, "Failed to determine hostname. Will generate one", ex); + } + // let's make an identifier for ourselves. + return "generated-" + new Random().nextInt(); + } + + /** + * The main application allowing this program to be launched from the command line. + */ + public static void main(String[] args) throws Exception { + Map> processedArgs = preprocessArgs(args); + Set grpcPorts = getPorts(processedArgs, "--grpc"); + Set xdsPorts = getPorts(processedArgs, "--xds-grpc-server"); + // If an xds port does not exist in gRPC ports set, add it. + grpcPorts.addAll(xdsPorts); + // which ports are supposed to use tls + Set tlsPorts = getPorts(processedArgs, "--tls"); + List forwardingAddress = processedArgs.get("--forwarding-address"); + if (forwardingAddress.size() > 1) { + logger.severe("More than one value for --forwarding-address not allowed"); + System.exit(1); + } + if (forwardingAddress.size() == 0) { + forwardingAddress.add("0.0.0.0:7072"); + } + List key = processedArgs.get("key"); + List crt = processedArgs.get("crt"); + + if (key.size() > 1 || crt.size() > 1) { + logger.severe("More than one value for --key or --crt not allowed"); + System.exit(1); + } + if (key.size() != crt.size()) { + logger.severe("Both --key or --crt should be present or absent"); + System.exit(1); + } + ServerCredentials tlsServerCredentials = null; + if (key.size() == 1) { + tlsServerCredentials = TlsServerCredentials.create(new File(crt.get(0)), + new File(key.get(0))); + } else if (!tlsPorts.isEmpty()) { + logger.severe("Both --key or --crt should be present if tls ports used"); + System.exit(1); + } + + String hostname = determineHostname(); + EchoTestServer echoTestServer = new EchoTestServer(); + echoTestServer.runServers(hostname, grpcPorts, xdsPorts, tlsPorts, forwardingAddress.get(0), + tlsServerCredentials); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("Shutting down"); + echoTestServer.stopServers(); + } catch (Exception e) { + logger.log(Level.SEVERE, "stopServers", e); + throw e; + } + })); + echoTestServer.blockUntilShutdown(); + } + + void runServers(String hostname, Collection grpcPorts, Collection xdsPorts, + Collection tlsPorts, String forwardingAddress, + ServerCredentials tlsServerCredentials) + throws IOException { + ServerServiceDefinition service = ServerInterceptors.intercept( + new EchoTestServiceImpl(hostname, forwardingAddress), new EchoTestServerInterceptor()); + servers = new ArrayList<>(grpcPorts.size() + 1); + boolean runAdminServices = Boolean.getBoolean("EXPOSE_GRPC_ADMIN"); + for (int port : grpcPorts) { + ServerCredentials serverCredentials = InsecureServerCredentials.create(); + String credTypeString = "over plaintext"; + if (xdsPorts.contains(port)) { + serverCredentials = XdsServerCredentials.create(InsecureServerCredentials.create()); + credTypeString = "over xDS-configured mTLS"; + } else if (tlsPorts.contains(port)) { + serverCredentials = tlsServerCredentials; + credTypeString = "over TLS"; + } + servers.add(runServer(port, service, serverCredentials, credTypeString, runAdminServices)); + } + } + + static Server runServer( + int port, ServerServiceDefinition service, ServerCredentials serverCredentials, + String credTypeString, boolean runAdminServices) + throws IOException { + logger.log(Level.INFO, "Listening GRPC ({0}) on {1}", new Object[]{credTypeString, port}); + ServerBuilder builder = Grpc.newServerBuilderForPort(port, serverCredentials) + .addService(service); + if (runAdminServices) { + builder = builder.addServices(AdminInterface.getStandardServices()); + } + return builder.build().start(); + } + + void stopServers() { + for (Server server : servers) { + server.shutdownNow(); + } + } + + void blockUntilShutdown() throws InterruptedException { + for (Server server : servers) { + if (!server.awaitTermination(5, TimeUnit.SECONDS)) { + System.err.println("Timed out waiting for server shutdown"); + } + } + } + + private static class EchoTestServerInterceptor implements ServerInterceptor { + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + final Metadata requestHeaders, ServerCallHandler next) { + final String methodName = call.getMethodDescriptor().getBareMethodName(); + + // we need this processing only for Echo + if (!"Echo".equals(methodName)) { + return next.startCall(call, requestHeaders); + } + final SocketAddress peerAddress = call.getAttributes() + .get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + + Context ctx = Context.current(); + if (peerAddress instanceof InetSocketAddress) { + InetSocketAddress inetPeerAddress = (InetSocketAddress) peerAddress; + ctx = ctx.withValue(CLIENT_ADDRESS_CONTEXT_KEY, + inetPeerAddress.getAddress().getHostAddress()); + } + ctx = ctx.withValue(AUTHORITY_CONTEXT_KEY, call.getAuthority()); + Map requestHeadersCopy = new HashMap<>(); + for (String key : requestHeaders.keys()) { + if (!key.endsWith("-bin")) { + requestHeadersCopy.put(key, + requestHeaders.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER))); + } + } + ctx = ctx.withValue(REQUEST_HEADERS_CONTEXT_KEY, requestHeadersCopy); + return Contexts.interceptCall( + ctx, + call, + requestHeaders, + next); + } + } + + private static class EchoTestServiceImpl extends EchoTestServiceImplBase { + + private final String hostname; + private final String forwardingAddress; + private final EchoTestServiceGrpc.EchoTestServiceBlockingStub forwardingStub; + + EchoTestServiceImpl(String hostname, String forwardingAddress) { + this.hostname = hostname; + this.forwardingAddress = forwardingAddress; + this.forwardingStub = EchoTestServiceGrpc.newBlockingStub( + Grpc.newChannelBuilder(forwardingAddress, InsecureChannelCredentials.create()).build()); + } + + @Override + public void echo(EchoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + + EchoMessage echoMessage = new EchoMessage(); + echoMessage.writeKeyValue(HOSTNAME, hostname); + echoMessage.writeKeyValue("Echo", request.getMessage()); + String clientAddress = CLIENT_ADDRESS_CONTEXT_KEY.get(); + if (clientAddress != null) { + echoMessage.writeKeyValue(IP, clientAddress); + } + Map requestHeadersCopy = REQUEST_HEADERS_CONTEXT_KEY.get(); + for (Map.Entry entry : requestHeadersCopy.entrySet()) { + echoMessage.writeKeyValueForRequest(REQUEST_HEADER, entry.getKey(), entry.getValue()); + } + echoMessage.writeKeyValue(STATUS_CODE, "200"); + echoMessage.writeKeyValue(HOST, AUTHORITY_CONTEXT_KEY.get()); + EchoResponse echoResponse = EchoResponse.newBuilder() + .setMessage(echoMessage.toString()) + .build(); + + responseObserver.onNext(echoResponse); + responseObserver.onCompleted(); + } + + @Override + public void forwardEcho(ForwardEchoRequest request, + StreamObserver responseObserver) { + try { + responseObserver.onNext(buildEchoResponse(request)); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + responseObserver.onError(e); + Thread.currentThread().interrupt(); + } catch (Exception e) { + responseObserver.onError(e); + } + } + + private static final class EchoCall { + EchoResponse response; + Status status; + } + + private ForwardEchoResponse buildEchoResponse(ForwardEchoRequest request) + throws InterruptedException { + ForwardEchoResponse.Builder forwardEchoResponseBuilder + = ForwardEchoResponse.newBuilder(); + String rawUrl = request.getUrl(); + List urlParts = Splitter.on(':').limit(2).splitToList(rawUrl); + if (urlParts.size() < 2) { + throw new StatusRuntimeException( + Status.INVALID_ARGUMENT.withDescription("No protocol configured for url " + rawUrl)); + } + ChannelCredentials creds; + String target = null; + if ("grpc".equals(urlParts.get(0))) { + // We don't really want to test this but the istio test infrastructure needs + // this to be supported. If we ever decide to add support for this properly, + // we would need to add support for TLS creds here. + creds = InsecureChannelCredentials.create(); + target = urlParts.get(1).substring(2); + } else if ("xds".equals(urlParts.get(0))) { + creds = XdsChannelCredentials.create(InsecureChannelCredentials.create()); + target = rawUrl; + } else { + logger.log(Level.INFO, "Protocol {0} not supported. Forwarding to {1}", + new String[]{urlParts.get(0), forwardingAddress}); + return forwardingStub.withDeadline(Context.current().getDeadline()).forwardEcho(request); + } + + ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder(target, creds); + ManagedChannel channel = channelBuilder.build(); + + List
requestHeaders = request.getHeadersList(); + Metadata metadata = new Metadata(); + + for (Header header : requestHeaders) { + metadata.put(Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER), + header.getValue()); + } + + int count = request.getCount() == 0 ? 1 : request.getCount(); + // Calculate the amount of time to sleep after each call. + Duration durationPerQuery = Duration.ZERO; + if (request.getQps() > 0) { + durationPerQuery = Duration.ofNanos( + Duration.ofSeconds(1).toNanos() / request.getQps()); + } + logger.info("qps=" + request.getQps()); + logger.info("durationPerQuery=" + durationPerQuery); + EchoRequest echoRequest = EchoRequest.newBuilder() + .setMessage(request.getMessage()) + .build(); + Instant start = Instant.now(); + logger.info("starting instant=" + start); + Duration expected = Duration.ZERO; + final CountDownLatch latch = new CountDownLatch(count); + EchoCall[] echoCalls = new EchoCall[count]; + for (int i = 0; i < count; i++) { + Metadata currentMetadata = new Metadata(); + currentMetadata.merge(metadata); + currentMetadata.put( + Metadata.Key.of(REQUEST_ID, Metadata.ASCII_STRING_MARSHALLER), "" + i); + EchoTestServiceGrpc.EchoTestServiceFutureStub stub + = EchoTestServiceGrpc.newFutureStub(channel).withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(currentMetadata)) + .withDeadlineAfter(request.getTimeoutMicros(), TimeUnit.MICROSECONDS); + + echoCalls[i] = new EchoCall(); + callEcho(stub, echoRequest, echoCalls[i], latch); + Instant current = Instant.now(); + logger.info("after rpc instant=" + current); + Duration elapsed = Duration.between(start, current); + expected = expected.plus(durationPerQuery); + Duration timeLeft = expected.minus(elapsed); + logger.info("elapsed=" + elapsed + ", expected=" + expected + ", timeLeft=" + timeLeft); + if (!timeLeft.isNegative()) { + logger.info("sleeping for ms =" + timeLeft); + Thread.sleep(timeLeft.toMillis()); + } + } + latch.await(); + for (int i = 0; i < count; i++) { + if (Status.OK.equals(echoCalls[i].status)) { + forwardEchoResponseBuilder.addOutput( + buildForwardEchoStruct(i, echoCalls, request.getMessage())); + } else { + logger.log(Level.SEVERE, "RPC {0} failed {1}: {2}", + new Object[]{i, echoCalls[i].status.getCode(), echoCalls[i].status.getDescription()}); + forwardEchoResponseBuilder.clear(); + throw echoCalls[i].status.asRuntimeException(); + } + } + return forwardEchoResponseBuilder.build(); + } + + private static String buildForwardEchoStruct(int i, EchoCall[] echoCalls, + String requestMessage) { + // The test infrastructure might expect the entire struct instead of + // just the message. + StringBuilder sb = new StringBuilder(); + sb.append(String.format("[%d] grpcecho.Echo(%s)\n", i, requestMessage)); + Iterable iterable = Splitter.on('\n').split(echoCalls[i].response.getMessage()); + for (String line : iterable) { + if (!line.isEmpty()) { + sb.append(String.format("[%d body] %s\n", i, line)); + } + } + return sb.toString(); + } + + private void callEcho(EchoTestServiceFutureStub stub, + EchoRequest echoRequest, final EchoCall echoCall, CountDownLatch latch) { + + ListenableFuture response = stub.echo(echoRequest); + Futures.addCallback( + response, + new FutureCallback() { + @Override + public void onSuccess(@Nullable EchoResponse result) { + echoCall.response = result; + echoCall.status = Status.OK; + latch.countDown(); + } + + @Override + public void onFailure(Throwable t) { + echoCall.status = Status.fromThrowable(t); + latch.countDown(); + } + }, + MoreExecutors.directExecutor()); + } + } + + private static class EchoMessage { + private final StringBuilder sb = new StringBuilder(); + + void writeKeyValue(String key, String value) { + sb.append(key).append("=").append(value).append("\n"); + } + + void writeKeyValueForRequest(String requestHeader, String key, String value) { + if (value != null) { + writeKeyValue(requestHeader, key + ":" + value); + } + } + + void writeMessage(String message) { + sb.append(message); + } + + @Override + public String toString() { + return sb.toString(); + } + } +} diff --git a/istio-interop-testing/src/test/java/io/grpc/testing/istio/EchoTestServerTest.java b/istio-interop-testing/src/test/java/io/grpc/testing/istio/EchoTestServerTest.java new file mode 100644 index 00000000000..091a300b874 --- /dev/null +++ b/istio-interop-testing/src/test/java/io/grpc/testing/istio/EchoTestServerTest.java @@ -0,0 +1,388 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.istio; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; +import io.istio.test.Echo.EchoRequest; +import io.istio.test.Echo.EchoResponse; +import io.istio.test.Echo.ForwardEchoRequest; +import io.istio.test.Echo.ForwardEchoResponse; +import io.istio.test.Echo.Header; +import io.istio.test.EchoTestServiceGrpc; +import io.istio.test.EchoTestServiceGrpc.EchoTestServiceImplBase; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link EchoTestServer}. + */ +@RunWith(JUnit4.class) +public class EchoTestServerTest { + + private static final String[] EXPECTED_KEY_SET = { + "--server_first", "--forwarding-address", + "--bind_ip", "--istio-version", "--bind_localhost", "--grpc", "--tls", + "--cluster", "--key", "--tcp", "--crt", "--metrics", "--port", "--version" + }; + + private static final String TEST_ARGS = + "--metrics=15014 --cluster=\"cluster-0\" --port=\"18080\" --grpc=\"17070\" --port=\"18085\"" + + " --tcp=\"19090\" --port=\"18443\" --tls=18443 --tcp=\"16060\" --server_first=16060" + + " --tcp=\"19091\" --tcp=\"16061\" --server_first=16061 --port=\"18081\"" + + " --grpc=\"17071\" --port=\"19443\" --tls=19443 --port=\"18082\" --bind_ip=18082" + + " --port=\"18084\" --bind_localhost=18084 --tcp=\"19092\" --port=\"18083\"" + + " --port=\"8080\" --port=\"3333\" --version=\"v1\" --istio-version=3 --crt=/cert.crt" + + " --key=/cert.key --forwarding-address=192.168.1.10:7072"; + + private static final String TEST_ARGS_PORTS = + "--metrics=15014 --cluster=\"cluster-0\" --port=\"18080\" --grpc=17070 --port=18085" + + " --tcp=\"19090\" --port=\"18443\" --tls=18443 --tcp=16060 --server_first=16060" + + " --tcp=\"19091\" --tcp=\"16061\" --server_first=16061 --port=\"18081\"" + + " --grpc=\"17071\" --port=\"19443\" --tls=\"19443\" --port=\"18082\" --bind_ip=18082" + + " --port=\"18084\" --bind_localhost=18084 --tcp=\"19092\" --port=\"18083\"" + + " --port=\"8080\" --port=3333 --version=\"v1\" --istio-version=3 --crt=/cert.crt" + + " --key=/cert.key --xds-grpc-server=12034 --xds-grpc-server=\"34012\""; + + @Test + public void preprocessArgsTest() { + String[] splitArgs = TEST_ARGS.split(" "); + Map> processedArgs = EchoTestServer.preprocessArgs(splitArgs); + + assertEquals(processedArgs.keySet(), ImmutableSet.copyOf(EXPECTED_KEY_SET)); + assertEquals(processedArgs.get("--server_first"), ImmutableList.of("16060", "16061")); + assertEquals(processedArgs.get("--bind_ip"), ImmutableList.of("18082")); + assertEquals(processedArgs.get("--bind_localhost"), ImmutableList.of("18084")); + assertEquals(processedArgs.get("--grpc"), ImmutableList.of("\"17070\"", "\"17071\"")); + assertEquals(processedArgs.get("--tls"), ImmutableList.of("18443", "19443")); + assertEquals(processedArgs.get("--cluster"), ImmutableList.of("\"cluster-0\"")); + assertEquals(processedArgs.get("--key"), ImmutableList.of("/cert.key")); + assertEquals(processedArgs.get("--tcp"), ImmutableList.of("\"19090\"", "\"16060\"", + "\"19091\"","\"16061\"","\"19092\"")); + assertEquals(processedArgs.get("--istio-version"), ImmutableList.of("3")); + assertEquals(processedArgs.get("--crt"), ImmutableList.of("/cert.crt")); + assertEquals(processedArgs.get("--metrics"), ImmutableList.of("15014")); + assertEquals(ImmutableList.of("192.168.1.10:7072"), processedArgs.get("--forwarding-address")); + assertEquals( + processedArgs.get("--port"), + ImmutableList.of( + "\"18080\"", + "\"18085\"", + "\"18443\"", + "\"18081\"", + "\"19443\"", + "\"18082\"", + "\"18084\"", + "\"18083\"", + "\"8080\"", + "\"3333\"")); + } + + @Test + public void preprocessArgsPortsTest() { + String[] splitArgs = TEST_ARGS_PORTS.split(" "); + Map> processedArgs = EchoTestServer.preprocessArgs(splitArgs); + + Set ports = EchoTestServer.getPorts(processedArgs, "--port"); + assertThat(ports).containsExactly(18080, 8080, 18081, 18082, 19443, 18083, 18084, 18085, + 3333, 18443); + ports = EchoTestServer.getPorts(processedArgs, "--grpc"); + assertThat(ports).containsExactly(17070, 17071); + ports = EchoTestServer.getPorts(processedArgs, "--tls"); + assertThat(ports).containsExactly(18443, 19443); + ports = EchoTestServer.getPorts(processedArgs, "--xds-grpc-server"); + assertThat(ports).containsExactly(34012, 12034); + } + + + @Test + public void echoTest() throws IOException, InterruptedException { + EchoTestServer echoTestServer = new EchoTestServer(); + + echoTestServer.runServers( + "test-host", + ImmutableList.of(0, 0), + ImmutableList.of(), + ImmutableList.of(), + "0.0.0.0:7072", + null); + assertEquals(2, echoTestServer.servers.size()); + int port = echoTestServer.servers.get(0).getPort(); + assertNotEquals(0, port); + assertNotEquals(0, echoTestServer.servers.get(1).getPort()); + + ManagedChannelBuilder channelBuilder = + Grpc.newChannelBuilderForAddress("localhost", port, InsecureChannelCredentials.create()); + ManagedChannel channel = channelBuilder.build(); + + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("header1", Metadata.ASCII_STRING_MARSHALLER), "value1"); + metadata.put(Metadata.Key.of("header2", Metadata.ASCII_STRING_MARSHALLER), "value2"); + + EchoTestServiceGrpc.EchoTestServiceBlockingStub stub = + EchoTestServiceGrpc.newBlockingStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + + EchoRequest echoRequest = EchoRequest.newBuilder() + .setMessage("test-message1") + .build(); + EchoResponse echoResponse = stub.echo(echoRequest); + String echoMessage = echoResponse.getMessage(); + Set lines = ImmutableSet.copyOf(echoMessage.split("\n")); + + assertThat(lines).contains("RequestHeader=header1:value1"); + assertThat(lines).contains("RequestHeader=header2:value2"); + assertThat(lines).contains("Echo=test-message1"); + assertThat(lines).contains("Hostname=test-host"); + assertThat(lines).contains("Host=localhost:" + port); + assertThat(lines).contains("StatusCode=200"); + + echoTestServer.stopServers(); + echoTestServer.blockUntilShutdown(); + } + + static final int COUNT_OF_REQUESTS_TO_FORWARD = 60; + + @Test + public void forwardEchoTest() throws IOException, InterruptedException { + EchoTestServer echoTestServer = new EchoTestServer(); + + echoTestServer.runServers( + "test-host", + ImmutableList.of(0, 0), + ImmutableList.of(), + ImmutableList.of(), + "0.0.0.0:7072", + null); + assertEquals(2, echoTestServer.servers.size()); + int port1 = echoTestServer.servers.get(0).getPort(); + int port2 = echoTestServer.servers.get(1).getPort(); + + ManagedChannelBuilder channelBuilder = + Grpc.newChannelBuilderForAddress("localhost", port1, InsecureChannelCredentials.create()); + ManagedChannel channel = channelBuilder.build(); + + ForwardEchoRequest forwardEchoRequest = + ForwardEchoRequest.newBuilder() + .setCount(COUNT_OF_REQUESTS_TO_FORWARD) + .setQps(100) + .setTimeoutMicros(5000_000L) // 5000 millis + .setUrl("grpc://localhost:" + port2) + .addHeaders( + Header.newBuilder().setKey("test-key1").setValue("test-value1").build()) + .addHeaders( + Header.newBuilder().setKey("test-key2").setValue("test-value2").build()) + .setMessage("forward-echo-test-message") + .build(); + + EchoTestServiceGrpc.EchoTestServiceBlockingStub stub = + EchoTestServiceGrpc.newBlockingStub(channel); + + Instant start = Instant.now(); + ForwardEchoResponse forwardEchoResponse = stub.forwardEcho(forwardEchoRequest); + Instant end = Instant.now(); + List outputs = forwardEchoResponse.getOutputList(); + assertEquals(COUNT_OF_REQUESTS_TO_FORWARD, outputs.size()); + for (int i = 0; i < COUNT_OF_REQUESTS_TO_FORWARD; i++) { + validateOutput(outputs.get(i), i); + } + long duration = Duration.between(start, end).toMillis(); + assertThat(duration).isAtLeast(COUNT_OF_REQUESTS_TO_FORWARD * 10L); + echoTestServer.stopServers(); + echoTestServer.blockUntilShutdown(); + } + + private static void validateOutput(String output, int i) { + List content = Splitter.on('\n').splitToList(output); + assertThat(content.size()).isAtLeast(7); // see echo implementation + assertThat(content.get(0)) + .isEqualTo(String.format("[%d] grpcecho.Echo(forward-echo-test-message)", i)); + String prefix = "[" + i + " body] "; + assertThat(content).contains(prefix + "RequestHeader=x-request-id:" + i); + assertThat(content).contains(prefix + "RequestHeader=test-key1:test-value1"); + assertThat(content).contains(prefix + "RequestHeader=test-key2:test-value2"); + assertThat(content).contains(prefix + "Hostname=test-host"); + assertThat(content).contains(prefix + "StatusCode=200"); + } + + @Test + public void nonGrpcForwardEchoTest() throws IOException, InterruptedException { + ForwardServiceForNonGrpcImpl forwardServiceForNonGrpc = new ForwardServiceForNonGrpcImpl(); + forwardServiceForNonGrpc.receivedRequests = new ArrayList<>(); + forwardServiceForNonGrpc.responsesToReturn = new ArrayList<>(); + Server nonGrpcEchoServer = + EchoTestServer.runServer( + 0, forwardServiceForNonGrpc.bindService(), InsecureServerCredentials.create(), + "", false); + int nonGrpcEchoServerPort = nonGrpcEchoServer.getPort(); + + EchoTestServer echoTestServer = new EchoTestServer(); + + echoTestServer.runServers( + "test-host", + ImmutableList.of(0), + ImmutableList.of(), + ImmutableList.of(), + "0.0.0.0:" + nonGrpcEchoServerPort, + null); + assertEquals(1, echoTestServer.servers.size()); + int port1 = echoTestServer.servers.get(0).getPort(); + + ManagedChannelBuilder channelBuilder = + Grpc.newChannelBuilderForAddress("localhost", port1, InsecureChannelCredentials.create()); + ManagedChannel channel = channelBuilder.build(); + + EchoTestServiceGrpc.EchoTestServiceBlockingStub stub = + EchoTestServiceGrpc.newBlockingStub(channel); + + forwardServiceForNonGrpc.responsesToReturn.add( + ForwardEchoResponse.newBuilder().addOutput("line 1").addOutput("line 2").build()); + + ForwardEchoRequest forwardEchoRequest = + ForwardEchoRequest.newBuilder() + .setCount(COUNT_OF_REQUESTS_TO_FORWARD) + .setQps(100) + .setTimeoutMicros(2000_000L) // 2000 millis + .setUrl("http://www.example.com") // non grpc protocol + .addHeaders( + Header.newBuilder().setKey("test-key1").setValue("test-value1").build()) + .addHeaders( + Header.newBuilder().setKey("test-key2").setValue("test-value2").build()) + .setMessage("non-grpc-forward-echo-test-message1") + .build(); + + ForwardEchoResponse forwardEchoResponse = stub.forwardEcho(forwardEchoRequest); + List outputs = forwardEchoResponse.getOutputList(); + assertEquals(2, outputs.size()); + assertThat(outputs.get(0)).isEqualTo("line 1"); + assertThat(outputs.get(1)).isEqualTo("line 2"); + + assertThat(forwardServiceForNonGrpc.receivedRequests).hasSize(1); + ForwardEchoRequest receivedRequest = forwardServiceForNonGrpc.receivedRequests.remove(0); + assertThat(receivedRequest.getUrl()).isEqualTo("http://www.example.com"); + assertThat(receivedRequest.getMessage()).isEqualTo("non-grpc-forward-echo-test-message1"); + assertThat(receivedRequest.getCount()).isEqualTo(COUNT_OF_REQUESTS_TO_FORWARD); + assertThat(receivedRequest.getQps()).isEqualTo(100); + + forwardServiceForNonGrpc.responsesToReturn.add( + Status.UNIMPLEMENTED.asRuntimeException()); + forwardEchoRequest = + ForwardEchoRequest.newBuilder() + .setCount(1) + .setQps(100) + .setTimeoutMicros(2000_000L) // 2000 millis + .setUrl("redis://192.168.1.1") // non grpc protocol + .addHeaders( + Header.newBuilder().setKey("test-key1").setValue("test-value1").build()) + .setMessage("non-grpc-forward-echo-test-message2") + .build(); + + try { + ForwardEchoResponse unused = stub.forwardEcho(forwardEchoRequest); + fail("exception expected"); + } catch (StatusRuntimeException e) { + assertThat(e.getStatus()).isEqualTo(Status.UNIMPLEMENTED); + } + + assertThat(forwardServiceForNonGrpc.receivedRequests).hasSize(1); + receivedRequest = forwardServiceForNonGrpc.receivedRequests.remove(0); + assertThat(receivedRequest.getUrl()).isEqualTo("redis://192.168.1.1"); + assertThat(receivedRequest.getMessage()).isEqualTo("non-grpc-forward-echo-test-message2"); + assertThat(receivedRequest.getCount()).isEqualTo(1); + + forwardServiceForNonGrpc.responsesToReturn.add( + ForwardEchoResponse.newBuilder().addOutput("line 3").build()); + + forwardEchoRequest = + ForwardEchoRequest.newBuilder() + .setCount(1) + .setQps(100) + .setTimeoutMicros(2000_000L) // 2000 millis + .setUrl("http2://192.168.1.1") // non grpc protocol + .addHeaders( + Header.newBuilder().setKey("test-key3").setValue("test-value3").build()) + .setMessage("non-grpc-forward-echo-test-message3") + .build(); + forwardEchoResponse = stub.forwardEcho(forwardEchoRequest); + outputs = forwardEchoResponse.getOutputList(); + assertEquals(1, outputs.size()); + assertThat(outputs.get(0)).isEqualTo("line 3"); + + assertThat(forwardServiceForNonGrpc.receivedRequests).hasSize(1); + receivedRequest = forwardServiceForNonGrpc.receivedRequests.remove(0); + assertThat(receivedRequest.getUrl()).isEqualTo("http2://192.168.1.1"); + assertThat(receivedRequest.getMessage()).isEqualTo("non-grpc-forward-echo-test-message3"); + List
headers = receivedRequest.getHeadersList(); + assertThat(headers).hasSize(1); + assertThat(headers.get(0).getKey()).isEqualTo("test-key3"); + assertThat(headers.get(0).getValue()).isEqualTo("test-value3"); + + echoTestServer.stopServers(); + echoTestServer.blockUntilShutdown(); + nonGrpcEchoServer.shutdown(); + nonGrpcEchoServer.awaitTermination(5, TimeUnit.SECONDS); + } + + /** + * Emulate the Go Echo server that receives the non-grpc protocol requests. + */ + private static class ForwardServiceForNonGrpcImpl extends EchoTestServiceImplBase { + + List receivedRequests; + List responsesToReturn; + + @Override + public void forwardEcho(ForwardEchoRequest request, + StreamObserver responseObserver) { + receivedRequests.add(request); + Object response = responsesToReturn.remove(0); + if (response instanceof Throwable) { + responseObserver.onError((Throwable) response); + } else if (response instanceof ForwardEchoResponse) { + responseObserver.onNext((ForwardEchoResponse) response); + responseObserver.onCompleted(); + } + responseObserver.onError(new IllegalArgumentException("Unknown type in responsesToReturn")); + } + } +} diff --git a/xds/third_party/istio/LICENSE b/istio-interop-testing/third_party/istio/LICENSE similarity index 100% rename from xds/third_party/istio/LICENSE rename to istio-interop-testing/third_party/istio/LICENSE diff --git a/xds/third_party/istio/import.sh b/istio-interop-testing/third_party/istio/import.sh similarity index 83% rename from xds/third_party/istio/import.sh rename to istio-interop-testing/third_party/istio/import.sh index b31725083ea..739c3c20fce 100755 --- a/xds/third_party/istio/import.sh +++ b/istio-interop-testing/third_party/istio/import.sh @@ -17,22 +17,22 @@ set -e BRANCH=master - -VERSION=e0ce39487b4806bdd9062b2c0d0cae0bebbbac7b +# import VERSION from the istio repository +VERSION=cbee1999ad8b0f1ec790ec47f9ea33fed887f4a7 GIT_REPO="https://github.com/istio/istio.git" GIT_BASE_DIR=istio -SOURCE_PROTO_BASE_DIR=istio +SOURCE_PROTO_BASE_DIR=istio/pkg TARGET_PROTO_BASE_DIR=src/main/proto # Sorted alphabetically. FILES=( -security/proto/providers/google/meshca.proto +test/echo/proto/echo.proto ) -pushd `git rev-parse --show-toplevel`/xds/third_party/istio +pushd `git rev-parse --show-toplevel`/istio-interop-testing/third_party/istio # clone the istio github repo in a tmp directory tmpdir="$(mktemp -d)" -trap "rm -rf $tmpdir" EXIT +trap "rm -rf ${tmpdir}" EXIT pushd "${tmpdir}" git clone -b $BRANCH $GIT_REPO diff --git a/istio-interop-testing/third_party/istio/src/main/proto/test/echo/proto/echo.proto b/istio-interop-testing/third_party/istio/src/main/proto/test/echo/proto/echo.proto new file mode 100644 index 00000000000..7e931b13b99 --- /dev/null +++ b/istio-interop-testing/third_party/istio/src/main/proto/test/echo/proto/echo.proto @@ -0,0 +1,93 @@ +// Copyright Istio Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +// Generate with protoc --go_out=. echo.proto -I /work/common-protos/ -I. +package proto; +option go_package="../proto"; +option java_package = "io.istio.test"; +option java_outer_classname = "Echo"; + +service EchoTestService { + rpc Echo (EchoRequest) returns (EchoResponse); + rpc ForwardEcho (ForwardEchoRequest) returns (ForwardEchoResponse); +} + +message EchoRequest { + string message = 1; +} + +message EchoResponse { + string message = 1; +} + +message Header { + string key = 1; + string value = 2; +} + +message ForwardEchoRequest { + int32 count = 1; + int32 qps = 2; + int64 timeout_micros = 3; + string url = 4; + repeated Header headers = 5; + string message = 6; + // Method for the request. Valid only for HTTP + string method = 9; + // If true, requests will be sent using h2c prior knowledge + bool http2 = 7; + // If true, requests will be sent using http3 + bool http3 = 15; + // If true, requests will not be sent until magic string is received + bool serverFirst = 8; + // If true, 301 redirects will be followed + bool followRedirects = 14; + // If non-empty, make the request with the corresponding cert and key. + string cert = 10; + string key = 11; + // If non-empty, verify the server CA + string caCert = 12; + // If non-empty, make the request with the corresponding cert and key file. + string certFile = 16; + string keyFile = 17; + // If non-empty, verify the server CA with the ca cert file. + string caCertFile = 18; + // Skip verifying peer's certificate. + bool insecureSkipVerify = 19; + // List of ALPNs to present. If not set, this will be automatically be set based on the protocol + Alpn alpn = 13; + // Server name (SNI) to present in TLS connections. If not set, Host will be used for http requests. + string serverName = 20; + // Expected response determines what string to look for in the response to validate TCP requests succeeded. + // If not set, defaults to "StatusCode=200" + google.protobuf.StringValue expectedResponse = 21; + // If set, a new connection will be made to the server for each individual request. If false, an attempt + // will be made to re-use the connection for the life of the forward request. This is automatically + // set for DNS, TCP, TLS, and WebSocket protocols. + bool newConnectionPerRequest = 22; + // If set, each request will force a DNS lookup. Only applies if newConnectionPerRequest is set. + bool forceDNSLookup = 23; +} + +message Alpn { + repeated string value = 1; +} + +message ForwardEchoResponse { + repeated string output = 1; +} diff --git a/java_grpc_library.bzl b/java_grpc_library.bzl index 913e905a711..11d6e393e98 100644 --- a/java_grpc_library.bzl +++ b/java_grpc_library.bzl @@ -30,12 +30,12 @@ java_rpc_toolchain = rule( providers = [JavaInfo], ), "plugin": attr.label( - cfg = "host", + cfg = "exec", executable = True, ), "plugin_arg": attr.string(), "_protoc": attr.label( - cfg = "host", + cfg = "exec", default = Label("@com_google_protobuf//:protoc"), executable = True, ), @@ -85,7 +85,7 @@ def _java_rpc_library_impl(ctx): args = ctx.actions.args() args.add(toolchain.plugin, format = "--plugin=protoc-gen-rpc-plugin=%s") args.add("--rpc-plugin_out={0}:{1}".format(toolchain.plugin_arg, srcjar.path)) - args.add_joined("--descriptor_set_in", descriptor_set_in, join_with = ctx.host_configuration.host_path_separator) + args.add_joined("--descriptor_set_in", descriptor_set_in, join_with = ctx.configuration.host_path_separator) args.add_all(srcs, map_each = _path_ignoring_repository) ctx.actions.run( @@ -93,6 +93,7 @@ def _java_rpc_library_impl(ctx): outputs = [srcjar], executable = toolchain.protoc, arguments = [args], + use_default_shell_env = True, ) deps_java_info = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps]) diff --git a/netty/BUILD.bazel b/netty/BUILD.bazel index 668abc3ca30..d2497d065ec 100644 --- a/netty/BUILD.bazel +++ b/netty/BUILD.bazel @@ -24,15 +24,17 @@ java_library( "@io_netty_netty_handler_proxy//jar", "@io_netty_netty_resolver//jar", "@io_netty_netty_transport//jar", + "@io_netty_netty_transport_native_unix_common//jar", "@io_perfmark_perfmark_api//jar", ], ) # Mirrors the dependencies included in the artifact on Maven Central for usage -# with maven_install's override_targets. Purposefully does not export any -# symbols, as it should only be used as a dep for pre-compiled binaries on -# Maven Central. Not actually shaded; libraries should not be referencing -# unstable APIs so there should not be any references to the shaded package. +# with maven_install's override_targets. Should only be used as a dep for +# pre-compiled binaries on Maven Central. +# +# Not actually shaded; libraries should not be referencing unstable APIs so +# there should not be any references to the shaded package. java_library( name = "shaded_maven", visibility = ["//visibility:public"], diff --git a/netty/build.gradle b/netty/build.gradle index 5fdc8f20f08..013d962ed37 100644 --- a/netty/build.gradle +++ b/netty/build.gradle @@ -17,32 +17,65 @@ evaluationDependsOn(project(':grpc-core').path) dependencies { api project(':grpc-core'), - libraries.netty - implementation libraries.netty_proxy_handler, + libraries.netty.codec.http2 + implementation libs.netty.handler.proxy, libraries.guava, - libraries.errorprone, - libraries.perfmark + libraries.errorprone.annotations, + libraries.perfmark.api, + libraries.netty.unix.common // Tests depend on base class defined by core module. testImplementation project(':grpc-core').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-testing'), - project(':grpc-testing-proto') - testRuntimeOnly libraries.netty_tcnative, + project(':grpc-testing-proto'), libraries.conscrypt, - libraries.netty_epoll - signature "org.codehaus.mojo.signature:java17:1.0@signature" - alpnagent libraries.jetty_alpn_agent + libraries.netty.transport.epoll + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + testRuntimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + signature libraries.signature.java + signature libraries.signature.android + alpnagent libraries.jetty.alpn.agent } import net.ltgt.gradle.errorprone.CheckSeverity -[compileJava, compileTestJava].each() { +[tasks.named("compileJava"), tasks.named("compileTestJava")]*.configure { // Netty retuns a lot of futures that we mostly don't care about. - it.options.errorprone.check("FutureReturnValueIgnored", CheckSeverity.OFF) + options.errorprone.check("FutureReturnValueIgnored", CheckSeverity.OFF) } -javadoc { +tasks.named("javadoc").configure { options.links 'http://netty.io/4.1/api/' exclude 'io/grpc/netty/Internal*' } @@ -51,17 +84,17 @@ project.sourceSets { main { java { srcDir "${projectDir}/third_party/netty/java" } } } -test { +tasks.named("test").configure { // Allow testing Jetty ALPN in TlsTest jvmArgs "-javaagent:" + configurations.alpnagent.asPath } -jmh { +tasks.named("jmh").configure { // Workaround // https://github.com/melix/jmh-gradle-plugin/issues/97#issuecomment-316664026 includeTests = true } -checkstyleMain { +tasks.named("checkstyleMain").configure { source = source.minus(fileTree(dir: "src/main", include: "**/Http2ControlFrameLimitEncoder.java")) } diff --git a/netty/shaded/BUILD.bazel b/netty/shaded/BUILD.bazel index bb23cfdf489..657bf6aafa9 100644 --- a/netty/shaded/BUILD.bazel +++ b/netty/shaded/BUILD.bazel @@ -5,6 +5,8 @@ java_library( runtime_deps = [ "//netty", "@io_netty_netty_tcnative_boringssl_static//jar", + "@io_netty_netty_tcnative_classes//jar", + "@io_netty_netty_transport_native_unix_common//jar", "@io_netty_netty_transport_native_epoll_linux_x86_64//jar", ], ) diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index e2409a86d99..0a9be812d7b 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -9,6 +9,7 @@ plugins { id "maven-publish" id "com.github.johnrengelman.shadow" + id "ru.vyarus.animalsniffer" } description = "gRPC: Netty Shaded" @@ -17,9 +18,43 @@ sourceSets { testShadow {} } dependencies { implementation project(':grpc-netty') - runtimeOnly libraries.netty_tcnative, - libraries.netty_epoll, - libraries.netty_epoll_arm64 + runtimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + runtimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + runtimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + runtimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + runtimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + runtimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + runtimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + runtimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-aarch_64" + } + } testShadowImplementation files(shadowJar), project(':grpc-testing-proto'), project(':grpc-testing'), @@ -27,14 +62,16 @@ dependencies { shadow project(':grpc-netty').configurations.runtimeClasspath.allDependencies.matching { it.group != 'io.netty' } + signature libraries.signature.java + signature libraries.signature.android } -jar { +tasks.named("jar").configure { // Must use a different archiveClassifier to avoid conflicting with shadowJar archiveClassifier = 'original' } -shadowJar { +tasks.named("shadowJar").configure { archiveClassifier = null dependencies { include(project(':grpc-netty')) @@ -86,14 +123,18 @@ publishing { } } -task testShadow(type: Test) { +tasks.register("testShadow", Test) { testClassesDirs = sourceSets.testShadow.output.classesDirs classpath = sourceSets.testShadow.runtimeClasspath } -compileTestShadowJava.options.compilerArgs = compileTestJava.options.compilerArgs -compileTestShadowJava.options.encoding = compileTestJava.options.encoding +tasks.named("compileTestShadowJava").configure { + options.compilerArgs = compileTestJava.options.compilerArgs + options.encoding = compileTestJava.options.encoding +} -test.dependsOn testShadow +tasks.named("test").configure { + dependsOn tasks.named("testShadow") +} /** * A Transformer which updates the Netty JAR META-INF/ resources to accurately diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index 5c2ff317ccd..e1f46720e05 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -39,7 +39,6 @@ import io.grpc.testing.protobuf.SimpleServiceGrpc; import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceBlockingStub; import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceImplBase; - import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; diff --git a/netty/src/jmh/java/io/grpc/netty/InboundHeadersBenchmark.java b/netty/src/jmh/java/io/grpc/netty/InboundHeadersBenchmark.java index aef75a1a26c..504f4ec27d7 100644 --- a/netty/src/jmh/java/io/grpc/netty/InboundHeadersBenchmark.java +++ b/netty/src/jmh/java/io/grpc/netty/InboundHeadersBenchmark.java @@ -147,39 +147,39 @@ private static void clientHandler(Blackhole bh, Http2Headers headers) { bh.consume(Utils.convertHeaders(headers)); } -// /** -// * Prints the size of the header objects in bytes. Needs JOL (Java Object Layout) as a -// * dependency. -// */ -// public static void main(String... args) { -// Http2Headers grpcRequestHeaders = new GrpcHttp2RequestHeaders(4); -// Http2Headers defaultRequestHeaders = new DefaultHttp2Headers(true, 9); -// for (int i = 0; i < requestHeaders.length; i += 2) { -// grpcRequestHeaders.add(requestHeaders[i], requestHeaders[i + 1]); -// defaultRequestHeaders.add(requestHeaders[i], requestHeaders[i + 1]); -// } -// long c = 10L; -// int m = ((int) c) / 20; -// -// long grpcRequestHeadersBytes = GraphLayout.parseInstance(grpcRequestHeaders).totalSize(); -// long defaultRequestHeadersBytes = -// GraphLayout.parseInstance(defaultRequestHeaders).totalSize(); -// -// System.out.printf("gRPC Request Headers: %d bytes%nNetty Request Headers: %d bytes%n", -// grpcRequestHeadersBytes, defaultRequestHeadersBytes); -// -// Http2Headers grpcResponseHeaders = new GrpcHttp2RequestHeaders(4); -// Http2Headers defaultResponseHeaders = new DefaultHttp2Headers(true, 9); -// for (int i = 0; i < responseHeaders.length; i += 2) { -// grpcResponseHeaders.add(responseHeaders[i], responseHeaders[i + 1]); -// defaultResponseHeaders.add(responseHeaders[i], responseHeaders[i + 1]); -// } -// -// long grpcResponseHeadersBytes = GraphLayout.parseInstance(grpcResponseHeaders).totalSize(); -// long defaultResponseHeadersBytes = -// GraphLayout.parseInstance(defaultResponseHeaders).totalSize(); -// -// System.out.printf("gRPC Response Headers: %d bytes%nNetty Response Headers: %d bytes%n", -// grpcResponseHeadersBytes, defaultResponseHeadersBytes); -// } + ///** + // * Prints the size of the header objects in bytes. Needs JOL (Java Object Layout) as a + // * dependency. + // */ + //public static void main(String... args) { + // Http2Headers grpcRequestHeaders = new GrpcHttp2RequestHeaders(4); + // Http2Headers defaultRequestHeaders = new DefaultHttp2Headers(true, 9); + // for (int i = 0; i < requestHeaders.length; i += 2) { + // grpcRequestHeaders.add(requestHeaders[i], requestHeaders[i + 1]); + // defaultRequestHeaders.add(requestHeaders[i], requestHeaders[i + 1]); + // } + // long c = 10L; + // int m = ((int) c) / 20; + + // long grpcRequestHeadersBytes = GraphLayout.parseInstance(grpcRequestHeaders).totalSize(); + // long defaultRequestHeadersBytes = + // GraphLayout.parseInstance(defaultRequestHeaders).totalSize(); + + // System.out.printf("gRPC Request Headers: %d bytes%nNetty Request Headers: %d bytes%n", + // grpcRequestHeadersBytes, defaultRequestHeadersBytes); + + // Http2Headers grpcResponseHeaders = new GrpcHttp2RequestHeaders(4); + // Http2Headers defaultResponseHeaders = new DefaultHttp2Headers(true, 9); + // for (int i = 0; i < responseHeaders.length; i += 2) { + // grpcResponseHeaders.add(responseHeaders[i], responseHeaders[i + 1]); + // defaultResponseHeaders.add(responseHeaders[i], responseHeaders[i + 1]); + // } + + // long grpcResponseHeadersBytes = GraphLayout.parseInstance(grpcResponseHeaders).totalSize(); + // long defaultResponseHeadersBytes = + // GraphLayout.parseInstance(defaultResponseHeaders).totalSize(); + + // System.out.printf("gRPC Response Headers: %d bytes%nNetty Response Headers: %d bytes%n", + // grpcResponseHeadersBytes, defaultResponseHeadersBytes); + //} } diff --git a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java index ab66472105a..7f088509c04 100644 --- a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java +++ b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java @@ -16,10 +16,12 @@ package io.grpc.netty; +import static com.google.common.base.Preconditions.checkNotNull; import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Ticker; import io.grpc.ChannelLogger; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -44,6 +46,7 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { private boolean autoTuneFlowControlOn; private ChannelHandlerContext ctx; private boolean initialWindowSent = false; + private final Ticker ticker; private static final long BDP_MEASUREMENT_PING = 1234; @@ -54,7 +57,8 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { Http2Settings initialSettings, ChannelLogger negotiationLogger, boolean autoFlowControl, - PingLimiter pingLimiter) { + PingLimiter pingLimiter, + Ticker ticker) { super(channelUnused, decoder, encoder, initialSettings, negotiationLogger); // During a graceful shutdown, wait until all streams are closed. @@ -62,12 +66,13 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { // Extract the connection window from the settings if it was set. this.initialConnectionWindow = initialSettings.initialWindowSize() == null ? -1 : - initialSettings.initialWindowSize(); + initialSettings.initialWindowSize(); this.autoTuneFlowControlOn = autoFlowControl; if (pingLimiter == null) { pingLimiter = new AllowPingLimiter(); } this.flowControlPing = new FlowControlPinger(pingLimiter); + this.ticker = checkNotNull(ticker, "ticker"); } @Override @@ -131,14 +136,17 @@ void setAutoTuneFlowControl(boolean isOn) { final class FlowControlPinger { private static final int MAX_WINDOW_SIZE = 8 * 1024 * 1024; + public static final int MAX_BACKOFF = 10; private final PingLimiter pingLimiter; private int pingCount; private int pingReturn; private boolean pinging; private int dataSizeSincePing; - private float lastBandwidth; // bytes per second + private long lastBandwidth; // bytes per nanosecond private long lastPingTime; + private int lastTargetWindow; + private int pingFrequencyMultiplier; public FlowControlPinger(PingLimiter pingLimiter) { Preconditions.checkNotNull(pingLimiter, "pingLimiter"); @@ -157,10 +165,24 @@ public void onDataRead(int dataLength, int paddingLength) { if (!autoTuneFlowControlOn) { return; } - if (!isPinging() && pingLimiter.isPingAllowed()) { + + // Note that we are double counting around the ping initiation as the current data will be + // added at the end of this method, so will be available in the next check. This at worst + // causes us to send a ping one data packet earlier, but makes startup faster if there are + // small packets before big ones. + int dataForCheck = getDataSincePing() + dataLength + paddingLength; + // Need to double the data here to account for targetWindow being set to twice the data below + if (!isPinging() && pingLimiter.isPingAllowed() + && dataForCheck * 2 >= lastTargetWindow * pingFrequencyMultiplier) { setPinging(true); sendPing(ctx()); } + + if (lastTargetWindow == 0) { + lastTargetWindow = + decoder().flowController().initialWindowSize(connection().connectionStream()); + } + incrementDataSincePing(dataLength + paddingLength); } @@ -169,25 +191,32 @@ public void updateWindow() throws Http2Exception { return; } pingReturn++; - long elapsedTime = (System.nanoTime() - lastPingTime); + setPinging(false); + + long elapsedTime = (ticker.read() - lastPingTime); if (elapsedTime == 0) { elapsedTime = 1; } + long bandwidth = (getDataSincePing() * TimeUnit.SECONDS.toNanos(1)) / elapsedTime; - Http2LocalFlowController fc = decoder().flowController(); // Calculate new window size by doubling the observed BDP, but cap at max window int targetWindow = Math.min(getDataSincePing() * 2, MAX_WINDOW_SIZE); - setPinging(false); + Http2LocalFlowController fc = decoder().flowController(); int currentWindow = fc.initialWindowSize(connection().connectionStream()); - if (targetWindow > currentWindow && bandwidth > lastBandwidth) { - lastBandwidth = bandwidth; - int increase = targetWindow - currentWindow; - fc.incrementWindowSize(connection().connectionStream(), increase); - fc.initialWindowSize(targetWindow); - Http2Settings settings = new Http2Settings(); - settings.initialWindowSize(targetWindow); - frameWriter().writeSettings(ctx(), settings, ctx().newPromise()); + if (bandwidth <= lastBandwidth || targetWindow <= currentWindow) { + pingFrequencyMultiplier = Math.min(pingFrequencyMultiplier + 1, MAX_BACKOFF); + return; } + + pingFrequencyMultiplier = 0; // react quickly when size is changing + lastBandwidth = bandwidth; + lastTargetWindow = targetWindow; + int increase = targetWindow - currentWindow; + fc.incrementWindowSize(connection().connectionStream(), increase); + fc.initialWindowSize(targetWindow); + Http2Settings settings = new Http2Settings(); + settings.initialWindowSize(targetWindow); + frameWriter().writeSettings(ctx(), settings, ctx().newPromise()); } private boolean isPinging() { @@ -200,7 +229,7 @@ private void setPinging(boolean pingOut) { private void sendPing(ChannelHandlerContext ctx) { setDataSizeSincePing(0); - lastPingTime = System.nanoTime(); + lastPingTime = ticker.read(); encoder().writePing(ctx, false, BDP_MEASUREMENT_PING, ctx.newPromise()); pingCount++; } @@ -229,16 +258,18 @@ private void setDataSizeSincePing(int dataSize) { dataSizeSincePing = dataSize; } + // Only used in testing @VisibleForTesting void setDataSizeAndSincePing(int dataSize) { setDataSizeSincePing(dataSize); - lastPingTime = System.nanoTime() - TimeUnit.SECONDS.toNanos(1); + pingFrequencyMultiplier = 1; + lastPingTime = ticker.read() ; } } /** Controls whether PINGs like those for BDP are permitted to be sent at the current time. */ public interface PingLimiter { - public boolean isPingAllowed(); + boolean isPingAllowed(); } private static final class AllowPingLimiter implements PingLimiter { diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java index 25f4f9232cf..13f55226483 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java @@ -34,6 +34,9 @@ */ @Internal public abstract class GrpcHttp2ConnectionHandler extends Http2ConnectionHandler { + static final int ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT = 1024; + static final Cumulator ADAPTIVE_CUMULATOR = + new NettyAdaptiveCumulator(ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT); @Nullable protected final ChannelPromise channelUnused; @@ -48,6 +51,7 @@ protected GrpcHttp2ConnectionHandler( super(decoder, encoder, initialSettings); this.channelUnused = channelUnused; this.negotiationLogger = negotiationLogger; + setCumulator(ADAPTIVE_CUMULATOR); } /** diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java index df7875fc7ae..4023fd1218f 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java @@ -340,7 +340,12 @@ public Http2Headers add(CharSequence csName, CharSequence csValue) { AsciiString name = validateName(requireAsciiString(csName)); AsciiString value = requireAsciiString(csValue); if (isPseudoHeader(name)) { - addPseudoHeader(name, value); + AsciiString previous = getPseudoHeader(name); + if (previous != null) { + PlatformDependent.throwException( + connectionError(PROTOCOL_ERROR, "Duplicate %s header", name)); + } + setPseudoHeader(name, value); return this; } if (equals(TE_HEADER, name)) { @@ -353,44 +358,42 @@ public Http2Headers add(CharSequence csName, CharSequence csValue) { @Override public CharSequence get(CharSequence csName) { AsciiString name = requireAsciiString(csName); - checkArgument(!isPseudoHeader(name), "Use direct accessor methods for pseudo headers."); + if (isPseudoHeader(name)) { + return getPseudoHeader(name); + } if (equals(TE_HEADER, name)) { return te; } return get(name); } - private void addPseudoHeader(CharSequence csName, CharSequence csValue) { - AsciiString name = requireAsciiString(csName); - AsciiString value = requireAsciiString(csValue); + private AsciiString getPseudoHeader(AsciiString name) { + if (equals(PATH_HEADER, name)) { + return path; + } else if (equals(AUTHORITY_HEADER, name)) { + return authority; + } else if (equals(METHOD_HEADER, name)) { + return method; + } else if (equals(SCHEME_HEADER, name)) { + return scheme; + } else { + return null; + } + } + private void setPseudoHeader(AsciiString name, AsciiString value) { if (equals(PATH_HEADER, name)) { - if (path != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :path header")); - } path = value; } else if (equals(AUTHORITY_HEADER, name)) { - if (authority != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :authority header")); - } authority = value; } else if (equals(METHOD_HEADER, name)) { - if (method != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :method header")); - } method = value; } else if (equals(SCHEME_HEADER, name)) { - if (scheme != null) { - PlatformDependent.throwException( - connectionError(PROTOCOL_ERROR, "Duplicate :scheme header")); - } scheme = value; } else { PlatformDependent.throwException( connectionError(PROTOCOL_ERROR, "Illegal pseudo-header '%s' in request.", name)); + throw new AssertionError(); // Make flow control obvious to javac } } @@ -418,8 +421,12 @@ public CharSequence scheme() { public List getAll(CharSequence csName) { AsciiString name = requireAsciiString(csName); if (isPseudoHeader(name)) { - // This code should never be reached. - throw new IllegalArgumentException("Use direct accessor methods for pseudo headers."); + AsciiString value = getPseudoHeader(name); + if (value == null) { + return Collections.emptyList(); + } else { + return Collections.singletonList(value); + } } if (equals(TE_HEADER, name)) { return Collections.singletonList((CharSequence) te); @@ -431,8 +438,12 @@ public List getAll(CharSequence csName) { public boolean remove(CharSequence csName) { AsciiString name = requireAsciiString(csName); if (isPseudoHeader(name)) { - // This code should never be reached. - throw new IllegalArgumentException("Use direct accessor methods for pseudo headers."); + if (getPseudoHeader(name) == null) { + return false; + } else { + setPseudoHeader(name, null); + return true; + } } if (equals(TE_HEADER, name)) { boolean wasPresent = te != null; diff --git a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java index d89c1647e5d..04a290165d7 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java +++ b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java @@ -154,8 +154,7 @@ public static SslContextBuilder configure(SslContextBuilder builder) { @CanIgnoreReturnValue public static SslContextBuilder configure(SslContextBuilder builder, SslProvider provider) { switch (provider) { - case JDK: - { + case JDK: { Provider jdkProvider = findJdkProvider(); if (jdkProvider == null) { throw new IllegalArgumentException( @@ -163,8 +162,7 @@ public static SslContextBuilder configure(SslContextBuilder builder, SslProvider } return configure(builder, jdkProvider); } - case OPENSSL: - { + case OPENSSL: { ApplicationProtocolConfig apc; if (OpenSsl.isAlpnSupported()) { apc = NPN_AND_ALPN; diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java index 72cb211ecf3..c5ad99181ef 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java @@ -16,10 +16,12 @@ package io.grpc.netty; +import com.google.common.annotations.VisibleForTesting; import io.grpc.Internal; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TransportTracer; import io.netty.channel.socket.nio.NioSocketChannel; /** @@ -107,5 +109,11 @@ public static ClientTransportFactory buildTransportFactory(NettyChannelBuilder b return builder.buildTransportFactory(); } + @VisibleForTesting + public static void setTransportTracerFactory( + NettyChannelBuilder builder, TransportTracer.Factory factory) { + builder.setTransportTracerFactory(factory); + } + private InternalNettyChannelBuilder() {} } diff --git a/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java new file mode 100644 index 00000000000..b3a28c55c79 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java @@ -0,0 +1,224 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.handler.codec.ByteToMessageDecoder.Cumulator; + +class NettyAdaptiveCumulator implements Cumulator { + private final int composeMinSize; + + /** + * "Adaptive" cumulator: cumulate {@link ByteBuf}s by dynamically switching between merge and + * compose strategies. + * + * @param composeMinSize Determines the minimal size of the buffer that should be composed (added + * as a new component of the {@link CompositeByteBuf}). If the total size + * of the last component (tail) and the incoming buffer is below this value, + * the incoming buffer is appended to the tail, and the new component is not + * added. + */ + NettyAdaptiveCumulator(int composeMinSize) { + Preconditions.checkArgument(composeMinSize >= 0, "composeMinSize must be non-negative"); + this.composeMinSize = composeMinSize; + } + + /** + * "Adaptive" cumulator: cumulate {@link ByteBuf}s by dynamically switching between merge and + * compose strategies. + * + *

This cumulator applies a heuristic to make a decision whether to track a reference to the + * buffer with bytes received from the network stack in an array ("zero-copy"), or to merge into + * the last component (the tail) by performing a memory copy. + * + *

It is necessary as a protection from a potential attack on the {@link + * io.netty.handler.codec.ByteToMessageDecoder#COMPOSITE_CUMULATOR}. Consider a pathological case + * when an attacker sends TCP packages containing a single byte of data, and forcing the cumulator + * to track each one in a separate buffer. The cost is memory overhead for each buffer, and extra + * compute to read the cumulation. + * + *

Implemented heuristic establishes a minimal threshold for the total size of the tail and + * incoming buffer, below which they are merged. The sum of the tail and the incoming buffer is + * used to avoid a case where attacker alternates the size of data packets to trick the cumulator + * into always selecting compose strategy. + * + *

Merging strategy attempts to minimize unnecessary memory writes. When possible, it expands + * the tail capacity and only copies the incoming buffer into available memory. Otherwise, when + * both tail and the buffer must be copied, the tail is reallocated (or fully replaced) with a new + * buffer of exponentially increasing capacity (bounded to {@link #composeMinSize}) to ensure + * runtime {@code O(n^2)} is amortized to {@code O(n)}. + */ + @Override + @SuppressWarnings("ReferenceEquality") + public final ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { + if (!cumulation.isReadable()) { + cumulation.release(); + return in; + } + CompositeByteBuf composite = null; + try { + if (cumulation instanceof CompositeByteBuf && cumulation.refCnt() == 1) { + composite = (CompositeByteBuf) cumulation; + // Writer index must equal capacity if we are going to "write" + // new components to the end + if (composite.writerIndex() != composite.capacity()) { + composite.capacity(composite.writerIndex()); + } + } else { + composite = alloc.compositeBuffer(Integer.MAX_VALUE) + .addFlattenedComponents(true, cumulation); + } + addInput(alloc, composite, in); + in = null; + return composite; + } finally { + if (in != null) { + // We must release if the ownership was not transferred as otherwise it may produce a leak + in.release(); + // Also release any new buffer allocated if we're not returning it + if (composite != null && composite != cumulation) { + composite.release(); + } + } + } + } + + @VisibleForTesting + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + if (shouldCompose(composite, in, composeMinSize)) { + composite.addFlattenedComponents(true, in); + } else { + // The total size of the new data and the last component are below the threshold. Merge them. + mergeWithCompositeTail(alloc, composite, in); + } + } + + @VisibleForTesting + static boolean shouldCompose(CompositeByteBuf composite, ByteBuf in, int composeMinSize) { + int componentCount = composite.numComponents(); + if (composite.numComponents() == 0) { + return true; + } + int inputSize = in.readableBytes(); + int tailStart = composite.toByteIndex(componentCount - 1); + int tailSize = composite.writerIndex() - tailStart; + return tailSize + inputSize >= composeMinSize; + } + + /** + * Append the given {@link ByteBuf} {@code in} to {@link CompositeByteBuf} {@code composite} by + * expanding or replacing the tail component of the {@link CompositeByteBuf}. + * + *

The goal is to prevent {@code O(n^2)} runtime in a pathological case, that forces copying + * the tail component into a new buffer, for each incoming single-byte buffer. We append the new + * bytes to the tail, when a write (or a fast write) is possible. + * + *

Otherwise, the tail is replaced with a new buffer, with the capacity increased enough to + * achieve runtime amortization. + * + *

We assume that implementations of {@link ByteBufAllocator#calculateNewCapacity(int, int)}, + * are similar to {@link io.netty.buffer.AbstractByteBufAllocator#calculateNewCapacity(int, int)}, + * which doubles buffer capacity by normalizing it to the closest power of two. This assumption + * is verified in unit tests for this method. + */ + @VisibleForTesting + static void mergeWithCompositeTail( + ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + int inputSize = in.readableBytes(); + int tailComponentIndex = composite.numComponents() - 1; + int tailStart = composite.toByteIndex(tailComponentIndex); + int tailSize = composite.writerIndex() - tailStart; + int newTailSize = inputSize + tailSize; + ByteBuf tail = composite.component(tailComponentIndex); + ByteBuf newTail = null; + try { + if (tail.refCnt() == 1 && !tail.isReadOnly() && newTailSize <= tail.maxCapacity()) { + // Ideal case: the tail isn't shared, and can be expanded to the required capacity. + // Take ownership of the tail. + newTail = tail.retain(); + + // TODO(https://github.com/netty/netty/issues/12844): remove when we use Netty with + // the issue fixed. + // In certain cases, removing the CompositeByteBuf component, and then adding it back + // isn't idempotent. An example is provided in https://github.com/netty/netty/issues/12844. + // This happens because the buffer returned by composite.component() has out-of-sync + // indexes. Under the hood the CompositeByteBuf returns a duplicate() of the underlying + // buffer, but doesn't set the indexes. + // + // To get the right indexes we use the fact that composite.internalComponent() returns + // the slice() into the readable portion of the underlying buffer. + // We use this implementation detail (internalComponent() returning a *SlicedByteBuf), + // and combine it with the fact that SlicedByteBuf duplicates have their indexes + // adjusted so they correspond to the to the readable portion of the slice. + // + // Hence composite.internalComponent().duplicate() returns a buffer with the + // indexes that should've been on the composite.component() in the first place. + // Until the issue is fixed, we manually adjust the indexes of the removed component. + ByteBuf sliceDuplicate = composite.internalComponent(tailComponentIndex).duplicate(); + newTail.setIndex(sliceDuplicate.readerIndex(), sliceDuplicate.writerIndex()); + + /* + * The tail is a readable non-composite buffer, so writeBytes() handles everything for us. + * + * - ensureWritable() performs a fast resize when possible (f.e. PooledByteBuf simply + * updates its boundary to the end of consecutive memory run assigned to this buffer) + * - when the required size doesn't fit into writableBytes(), a new buffer is + * allocated, and the capacity calculated with alloc.calculateNewCapacity() + * - note that maxFastWritableBytes() would normally allow a fast expansion of PooledByteBuf + * is not called because CompositeByteBuf.component() returns a duplicate, wrapped buffer. + * Unwrapping buffers is unsafe, and potential benefit of fast writes may not be + * as pronounced because the capacity is doubled with each reallocation. + */ + newTail.writeBytes(in); + } else { + // The tail is shared, or not expandable. Replace it with a new buffer of desired capacity. + newTail = alloc.buffer(alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE)); + newTail.setBytes(0, composite, tailStart, tailSize) + .setBytes(tailSize, in, in.readerIndex(), inputSize) + .writerIndex(newTailSize); + in.readerIndex(in.writerIndex()); + } + // Store readerIndex to avoid out of bounds writerIndex during component replacement. + int prevReader = composite.readerIndex(); + // Remove the old tail, reset writer index. + composite.removeComponent(tailComponentIndex).setIndex(0, tailStart); + // Add back the new tail. + composite.addFlattenedComponents(true, newTail); + // New tail's ownership transferred to the composite buf. + newTail = null; + in.release(); + in = null; + // Restore the reader. In case it fails we restore the reader after releasing/forgetting + // the input and the new tail so that finally block can handles them properly. + composite.readerIndex(prevReader); + } finally { + // Input buffer was merged with the tail. + if (in != null) { + in.release(); + } + // If new tail's ownership isn't transferred to the composite buf. + // Release it to prevent a leak. + if (newTail != null) { + newTail.release(); + } + } + } +} diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index ca029fe7cfc..da7fe84d9cb 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -23,7 +23,9 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.InlineMe; import io.grpc.Attributes; import io.grpc.CallCredentials; @@ -62,7 +64,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.ssl.SSLException; @@ -70,7 +71,7 @@ * A builder to help simplify construction of channels using the Netty transport. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1784") -@CanIgnoreReturnValue +@CheckReturnValue public final class NettyChannelBuilder extends AbstractManagedChannelImplBuilder { @@ -122,7 +123,6 @@ public final class NettyChannelBuilder extends * noticing changes to DNS. If an unresolved InetSocketAddress is passed in, then it will remain * unresolved. */ - @CheckReturnValue public static NettyChannelBuilder forAddress(SocketAddress serverAddress) { return new NettyChannelBuilder(serverAddress); } @@ -134,7 +134,6 @@ public static NettyChannelBuilder forAddress(SocketAddress serverAddress) { * method, since that API permits delaying DNS lookups and noticing changes to DNS. If an * unresolved InetSocketAddress is passed in, then it will remain unresolved. */ - @CheckReturnValue public static NettyChannelBuilder forAddress(SocketAddress serverAddress, ChannelCredentials creds) { FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); @@ -147,7 +146,6 @@ public static NettyChannelBuilder forAddress(SocketAddress serverAddress, /** * Creates a new builder with the given host and port. */ - @CheckReturnValue public static NettyChannelBuilder forAddress(String host, int port) { return forTarget(GrpcUtil.authorityFromHostAndPort(host, port)); } @@ -155,7 +153,6 @@ public static NettyChannelBuilder forAddress(String host, int port) { /** * Creates a new builder with the given host and port. */ - @CheckReturnValue public static NettyChannelBuilder forAddress(String host, int port, ChannelCredentials creds) { return forTarget(GrpcUtil.authorityFromHostAndPort(host, port), creds); } @@ -164,7 +161,6 @@ public static NettyChannelBuilder forAddress(String host, int port, ChannelCrede * Creates a new builder with the given target string that will be resolved by * {@link io.grpc.NameResolver}. */ - @CheckReturnValue public static NettyChannelBuilder forTarget(String target) { return new NettyChannelBuilder(target); } @@ -173,7 +169,6 @@ public static NettyChannelBuilder forTarget(String target) { * Creates a new builder with the given target string that will be resolved by * {@link io.grpc.NameResolver}. */ - @CheckReturnValue public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) { FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); if (result.error != null) { @@ -196,7 +191,6 @@ public int getDefaultPort() { } } - @CheckReturnValue NettyChannelBuilder(String target) { managedChannelImplBuilder = new ManagedChannelImplBuilder(target, new NettyChannelTransportFactoryBuilder(), @@ -215,7 +209,6 @@ public int getDefaultPort() { this.freezeProtocolNegotiatorFactory = true; } - @CheckReturnValue NettyChannelBuilder(SocketAddress address) { managedChannelImplBuilder = new ManagedChannelImplBuilder(address, getAuthorityFromAddress(address), @@ -242,7 +235,6 @@ protected ManagedChannelBuilder delegate() { return managedChannelImplBuilder; } - @CheckReturnValue private static String getAuthorityFromAddress(SocketAddress address) { if (address instanceof InetSocketAddress) { InetSocketAddress inetAddress = (InetSocketAddress) address; @@ -266,6 +258,7 @@ private static String getAuthorityFromAddress(SocketAddress address) { * {@link NioSocketChannel} must use {@link io.netty.channel.nio.NioEventLoopGroup}, otherwise * your application won't start. */ + @CanIgnoreReturnValue public NettyChannelBuilder channelType(Class channelType) { checkNotNull(channelType, "channelType"); return channelFactory(new ReflectiveChannelFactory<>(channelType)); @@ -284,6 +277,7 @@ public NettyChannelBuilder channelType(Class channelType) { * {@link NioSocketChannel} based {@link ChannelFactory} must use {@link * io.netty.channel.nio.NioEventLoopGroup}, otherwise your application won't start. */ + @CanIgnoreReturnValue public NettyChannelBuilder channelFactory(ChannelFactory channelFactory) { this.channelFactory = checkNotNull(channelFactory, "channelFactory"); return this; @@ -293,6 +287,7 @@ public NettyChannelBuilder channelFactory(ChannelFactory chan * Specifies a channel option. As the underlying channel as well as network implementation may * ignore this value applications should consider it a hint. */ + @CanIgnoreReturnValue public NettyChannelBuilder withOption(ChannelOption option, T value) { channelOptions.put(option, value); return this; @@ -303,6 +298,7 @@ public NettyChannelBuilder withOption(ChannelOption option, T value) { * *

Default: TLS */ + @CanIgnoreReturnValue public NettyChannelBuilder negotiationType(NegotiationType type) { checkState(!freezeProtocolNegotiatorFactory, "Cannot change security when using ChannelCredentials"); @@ -327,6 +323,7 @@ public NettyChannelBuilder negotiationType(NegotiationType type) { *

The channel won't take ownership of the given EventLoopGroup. It's caller's responsibility * to shut it down when it's desired. */ + @CanIgnoreReturnValue public NettyChannelBuilder eventLoopGroup(@Nullable EventLoopGroup eventLoopGroup) { if (eventLoopGroup != null) { return eventLoopGroupPool(new FixedObjectPool<>(eventLoopGroup)); @@ -334,6 +331,7 @@ public NettyChannelBuilder eventLoopGroup(@Nullable EventLoopGroup eventLoopGrou return eventLoopGroupPool(DEFAULT_EVENT_LOOP_GROUP_POOL); } + @CanIgnoreReturnValue NettyChannelBuilder eventLoopGroupPool(ObjectPool eventLoopGroupPool) { this.eventLoopGroupPool = checkNotNull(eventLoopGroupPool, "eventLoopGroupPool"); return this; @@ -343,6 +341,7 @@ NettyChannelBuilder eventLoopGroupPool(ObjectPool even * SSL/TLS context to use instead of the system default. It must have been configured with {@link * GrpcSslContexts}, but options could have been overridden. */ + @CanIgnoreReturnValue public NettyChannelBuilder sslContext(SslContext sslContext) { checkState(!freezeProtocolNegotiatorFactory, "Cannot change security when using ChannelCredentials"); @@ -365,6 +364,7 @@ public NettyChannelBuilder sslContext(SslContext sslContext) { * tuning, use {@link #flowControlWindow(int)}. By default, auto flow control is enabled with * initial flow control window size of {@link #DEFAULT_FLOW_CONTROL_WINDOW}. */ + @CanIgnoreReturnValue public NettyChannelBuilder initialFlowControlWindow(int initialFlowControlWindow) { checkArgument(initialFlowControlWindow > 0, "initialFlowControlWindow must be positive"); this.flowControlWindow = initialFlowControlWindow; @@ -378,6 +378,7 @@ public NettyChannelBuilder initialFlowControlWindow(int initialFlowControlWindow * called, the default value is {@link #DEFAULT_FLOW_CONTROL_WINDOW}) with auto flow control * tuning. */ + @CanIgnoreReturnValue public NettyChannelBuilder flowControlWindow(int flowControlWindow) { checkArgument(flowControlWindow > 0, "flowControlWindow must be positive"); this.flowControlWindow = flowControlWindow; @@ -393,6 +394,7 @@ public NettyChannelBuilder flowControlWindow(int flowControlWindow) { * * @deprecated Use {@link #maxInboundMetadataSize} instead */ + @CanIgnoreReturnValue @Deprecated @InlineMe(replacement = "this.maxInboundMetadataSize(maxHeaderListSize)") public NettyChannelBuilder maxHeaderListSize(int maxHeaderListSize) { @@ -410,6 +412,7 @@ public NettyChannelBuilder maxHeaderListSize(int maxHeaderListSize) { * @throws IllegalArgumentException if bytes is non-positive * @since 1.17.0 */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder maxInboundMetadataSize(int bytes) { checkArgument(bytes > 0, "maxInboundMetadataSize must be > 0"); @@ -420,6 +423,7 @@ public NettyChannelBuilder maxInboundMetadataSize(int bytes) { /** * Equivalent to using {@link #negotiationType(NegotiationType)} with {@code PLAINTEXT}. */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder usePlaintext() { negotiationType(NegotiationType.PLAINTEXT); @@ -429,6 +433,7 @@ public NettyChannelBuilder usePlaintext() { /** * Equivalent to using {@link #negotiationType(NegotiationType)} with {@code TLS}. */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder useTransportSecurity() { negotiationType(NegotiationType.TLS); @@ -440,6 +445,7 @@ public NettyChannelBuilder useTransportSecurity() { * * @since 1.3.0 */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { checkArgument(keepAliveTime > 0L, "keepalive time must be positive"); @@ -457,6 +463,7 @@ public NettyChannelBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) * * @since 1.3.0 */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { checkArgument(keepAliveTimeout > 0L, "keepalive timeout must be positive"); @@ -470,6 +477,7 @@ public NettyChannelBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit time * * @since 1.3.0 */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder keepAliveWithoutCalls(boolean enable) { keepAliveWithoutCalls = enable; @@ -480,6 +488,7 @@ public NettyChannelBuilder keepAliveWithoutCalls(boolean enable) { /** * If non-{@code null}, attempts to create connections bound to a local port. */ + @CanIgnoreReturnValue public NettyChannelBuilder localSocketPicker(@Nullable LocalSocketPicker localSocketPicker) { this.localSocketPicker = localSocketPicker; return this; @@ -516,6 +525,7 @@ public SocketAddress createSocketAddress( * than this limit is received it will not be processed and the RPC will fail with * RESOURCE_EXHAUSTED. */ + @CanIgnoreReturnValue @Override public NettyChannelBuilder maxInboundMessageSize(int max) { checkArgument(max >= 0, "negative max"); @@ -523,7 +533,6 @@ public NettyChannelBuilder maxInboundMessageSize(int max) { return this; } - @CheckReturnValue ClientTransportFactory buildTransportFactory() { assertEventLoopAndChannelType(); @@ -546,13 +555,11 @@ void assertEventLoopAndChannelType() { "Both EventLoopGroup and ChannelType should be provided or neither should be"); } - @CheckReturnValue int getDefaultPort() { return protocolNegotiatorFactory.getDefaultPort(); } @VisibleForTesting - @CheckReturnValue static ProtocolNegotiator createProtocolNegotiatorByType( NegotiationType negotiationType, SslContext sslContext, @@ -569,11 +576,13 @@ static ProtocolNegotiator createProtocolNegotiatorByType( } } + @CanIgnoreReturnValue NettyChannelBuilder disableCheckAuthority() { this.managedChannelImplBuilder.disableCheckAuthority(); return this; } + @CanIgnoreReturnValue NettyChannelBuilder enableCheckAuthority() { this.managedChannelImplBuilder.enableCheckAuthority(); return this; @@ -610,6 +619,7 @@ void setStatsRecordRetryMetrics(boolean value) { this.managedChannelImplBuilder.setStatsRecordRetryMetrics(value); } + @CanIgnoreReturnValue @VisibleForTesting NettyChannelBuilder setTransportTracerFactory(TransportTracer.Factory transportTracerFactory) { this.transportTracerFactory = transportTracerFactory; @@ -651,7 +661,6 @@ public int getDefaultPort() { /** * Creates Netty transports. Exposed for internal use, as it should be private. */ - @CheckReturnValue private static final class NettyTransportFactory implements ClientTransportFactory { private final ProtocolNegotiator protocolNegotiator; private final ChannelFactory channelFactory; @@ -730,7 +739,7 @@ public void run() { maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), tooManyPingsRunnable, transportTracerFactory.create(), options.getEagAttributes(), - localSocketPicker, channelLogger, useGetForSafeMethods); + localSocketPicker, channelLogger, useGetForSafeMethods, Ticker.systemTicker()); return transport; } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index bf3df4fa6aa..7cc77c150a0 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -19,6 +19,10 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** Provider for {@link NettyChannelBuilder} instances. */ @Internal @@ -52,4 +56,9 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia return NewChannelBuilderResult.channelBuilder( new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 80d11e54859..55337935e3b 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -24,6 +24,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.base.Ticker; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz; @@ -143,7 +144,8 @@ static NettyClientHandler newHandler( TransportTracer transportTracer, Attributes eagAttributes, String authority, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger, + Ticker ticker) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -169,7 +171,8 @@ static NettyClientHandler newHandler( transportTracer, eagAttributes, authority, - negotiationLogger); + negotiationLogger, + ticker); } @VisibleForTesting @@ -187,7 +190,8 @@ static NettyClientHandler newHandler( TransportTracer transportTracer, Attributes eagAttributes, String authority, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger, + Ticker ticker) { Preconditions.checkNotNull(connection, "connection"); Preconditions.checkNotNull(frameReader, "frameReader"); Preconditions.checkNotNull(lifecycleManager, "lifecycleManager"); @@ -237,7 +241,8 @@ static NettyClientHandler newHandler( eagAttributes, authority, autoFlowControl, - pingCounter); + pingCounter, + ticker); } private NettyClientHandler( @@ -253,9 +258,10 @@ private NettyClientHandler( Attributes eagAttributes, String authority, boolean autoFlowControl, - PingLimiter pingLimiter) { + PingLimiter pingLimiter, + Ticker ticker) { super(/* channelUnused= */ null, decoder, encoder, settings, - negotiationLogger, autoFlowControl, pingLimiter); + negotiationLogger, autoFlowControl, pingLimiter, ticker); this.lifecycleManager = lifecycleManager; this.keepAliveManager = keepAliveManager; this.stopwatchFactory = stopwatchFactory; @@ -951,17 +957,16 @@ public void onPingAckRead(ChannelHandlerContext ctx, long ackPayload) throws Htt Http2Ping p = ping; if (ackPayload == flowControlPing().payload()) { flowControlPing().updateWindow(); - if (logger.isLoggable(Level.FINE)) { - logger.log(Level.FINE, String.format("Window: %d", - decoder().flowController().initialWindowSize(connection().connectionStream()))); - } + logger.log(Level.FINE, "Window: {0}", + decoder().flowController().initialWindowSize(connection().connectionStream())); } else if (p != null) { if (p.payload() == ackPayload) { p.complete(); ping = null; } else { - logger.log(Level.WARNING, String.format( - "Received unexpected ping ack. Expecting %d, got %d", p.payload(), ackPayload)); + logger.log(Level.WARNING, + "Received unexpected ping ack. Expecting {0}, got {1}", + new Object[] {p.payload(), ackPayload}); } } else { logger.warning("Received unexpected ping ack. No ping outstanding"); diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index a7a1044059c..689dd847d5e 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -23,6 +23,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; @@ -102,6 +103,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final LocalSocketPicker localSocketPicker; private final ChannelLogger channelLogger; private final boolean useGetForSafeMethods; + private final Ticker ticker; NettyClientTransport( SocketAddress address, ChannelFactory channelFactory, @@ -112,7 +114,8 @@ class NettyClientTransport implements ConnectionClientTransport { boolean keepAliveWithoutCalls, String authority, @Nullable String userAgent, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, LocalSocketPicker localSocketPicker, ChannelLogger channelLogger, - boolean useGetForSafeMethods) { + boolean useGetForSafeMethods, Ticker ticker) { + this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiationScheme = this.negotiator.scheme(); this.remoteAddress = Preconditions.checkNotNull(address, "address"); @@ -137,6 +140,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.logId = InternalLogId.allocate(getClass(), remoteAddress.toString()); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.useGetForSafeMethods = useGetForSafeMethods; + this.ticker = Preconditions.checkNotNull(ticker, "ticker"); } @Override @@ -225,7 +229,8 @@ public Runnable start(Listener transportListener) { transportTracer, eagAttributes, authorityString, - channelLogger); + channelLogger, + ticker); ChannelHandler negotiationHandler = negotiator.newHandler(handler); @@ -251,7 +256,7 @@ public Runnable start(Listener transportListener) { ChannelHandler bufferingHandler = new WriteBufferingAndExceptionHandler(negotiationHandler); - /** + /* * We don't use a ChannelInitializer in the client bootstrap because its "initChannel" method * is executed in the event loop and we need this handler to be in the pipeline immediately so * that it may begin buffering writes. diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index ff5553eb116..3e8674e8405 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -26,6 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.InlineMe; import io.grpc.Attributes; import io.grpc.ExperimentalApi; @@ -59,14 +60,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.net.ssl.SSLException; /** * A builder to help simplify the construction of a Netty-based GRPC server. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1784") -@CanIgnoreReturnValue +@CheckReturnValue public final class NettyServerBuilder extends AbstractServerImplBuilder { // 1MiB @@ -121,7 +121,6 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder delegate() { * addresses are compatible with the Netty channel type, and that they don't conflict with each * other. */ + @CanIgnoreReturnValue public NettyServerBuilder addListenAddress(SocketAddress listenAddress) { this.listenAddresses.add(checkNotNull(listenAddress, "listenAddress")); return this; @@ -219,6 +213,7 @@ public NettyServerBuilder addListenAddress(SocketAddress listenAddress) { * example, {@link NioServerSocketChannel} must use {@link * io.netty.channel.nio.NioEventLoopGroup}, otherwise your server won't start. */ + @CanIgnoreReturnValue public NettyServerBuilder channelType(Class channelType) { checkNotNull(channelType, "channelType"); return channelFactory(new ReflectiveChannelFactory<>(channelType)); @@ -238,6 +233,7 @@ public NettyServerBuilder channelType(Class channelType * example, if the factory creates {@link NioServerSocketChannel} you must use {@link * io.netty.channel.nio.NioEventLoopGroup}, otherwise your server won't start. */ + @CanIgnoreReturnValue public NettyServerBuilder channelFactory(ChannelFactory channelFactory) { this.channelFactory = checkNotNull(channelFactory, "channelFactory"); return this; @@ -249,6 +245,7 @@ public NettyServerBuilder channelFactory(ChannelFactory * * @since 1.30.0 */ + @CanIgnoreReturnValue public NettyServerBuilder withOption(ChannelOption option, T value) { this.channelOptions.put(option, value); return this; @@ -260,6 +257,7 @@ public NettyServerBuilder withOption(ChannelOption option, T value) { * * @since 1.9.0 */ + @CanIgnoreReturnValue public NettyServerBuilder withChildOption(ChannelOption option, T value) { this.childChannelOptions.put(option, value); return this; @@ -288,6 +286,7 @@ public NettyServerBuilder withChildOption(ChannelOption option, T value) * A simple solution to this problem is to call {@link io.grpc.Server#awaitTermination()} to * keep the main thread alive until the server has terminated. */ + @CanIgnoreReturnValue public NettyServerBuilder bossEventLoopGroup(EventLoopGroup group) { if (group != null) { return bossEventLoopGroupPool(new FixedObjectPool<>(group)); @@ -295,6 +294,7 @@ public NettyServerBuilder bossEventLoopGroup(EventLoopGroup group) { return bossEventLoopGroupPool(DEFAULT_BOSS_EVENT_LOOP_GROUP_POOL); } + @CanIgnoreReturnValue NettyServerBuilder bossEventLoopGroupPool( ObjectPool bossEventLoopGroupPool) { this.bossEventLoopGroupPool = checkNotNull(bossEventLoopGroupPool, "bossEventLoopGroupPool"); @@ -324,6 +324,7 @@ NettyServerBuilder bossEventLoopGroupPool( * A simple solution to this problem is to call {@link io.grpc.Server#awaitTermination()} to * keep the main thread alive until the server has terminated. */ + @CanIgnoreReturnValue public NettyServerBuilder workerEventLoopGroup(EventLoopGroup group) { if (group != null) { return workerEventLoopGroupPool(new FixedObjectPool<>(group)); @@ -331,6 +332,7 @@ public NettyServerBuilder workerEventLoopGroup(EventLoopGroup group) { return workerEventLoopGroupPool(DEFAULT_WORKER_EVENT_LOOP_GROUP_POOL); } + @CanIgnoreReturnValue NettyServerBuilder workerEventLoopGroupPool( ObjectPool workerEventLoopGroupPool) { this.workerEventLoopGroupPool = @@ -349,6 +351,7 @@ void setForceHeapBuffer(boolean value) { * Sets the TLS context to use for encryption. Providing a context enables encryption. It must * have been configured with {@link GrpcSslContexts}, but options could have been overridden. */ + @CanIgnoreReturnValue public NettyServerBuilder sslContext(SslContext sslContext) { checkState(!freezeProtocolNegotiatorFactory, "Cannot change security when using ServerCredentials"); @@ -367,6 +370,7 @@ public NettyServerBuilder sslContext(SslContext sslContext) { * Sets the {@link ProtocolNegotiator} to be used. Overrides the value specified in {@link * #sslContext(SslContext)}. */ + @CanIgnoreReturnValue @Internal public final NettyServerBuilder protocolNegotiator(ProtocolNegotiator protocolNegotiator) { checkState(!freezeProtocolNegotiatorFactory, @@ -395,6 +399,7 @@ void setStatsRecordRealTimeMetrics(boolean value) { * The maximum number of concurrent calls permitted for each incoming connection. Defaults to no * limit. */ + @CanIgnoreReturnValue public NettyServerBuilder maxConcurrentCallsPerConnection(int maxCalls) { checkArgument(maxCalls > 0, "max must be positive: %s", maxCalls); this.maxConcurrentCallsPerConnection = maxCalls; @@ -407,6 +412,7 @@ public NettyServerBuilder maxConcurrentCallsPerConnection(int maxCalls) { * tuning, use {@link #flowControlWindow(int)}. By default, auto flow control is enabled with * initial flow control window size of {@link #DEFAULT_FLOW_CONTROL_WINDOW}. */ + @CanIgnoreReturnValue public NettyServerBuilder initialFlowControlWindow(int initialFlowControlWindow) { checkArgument(initialFlowControlWindow > 0, "initialFlowControlWindow must be positive"); this.flowControlWindow = initialFlowControlWindow; @@ -420,6 +426,7 @@ public NettyServerBuilder initialFlowControlWindow(int initialFlowControlWindow) * called, the default value is {@link #DEFAULT_FLOW_CONTROL_WINDOW}) with auto flow control * tuning. */ + @CanIgnoreReturnValue public NettyServerBuilder flowControlWindow(int flowControlWindow) { checkArgument(flowControlWindow > 0, "flowControlWindow must be positive: %s", flowControlWindow); @@ -437,6 +444,7 @@ public NettyServerBuilder flowControlWindow(int flowControlWindow) { * @deprecated Call {@link #maxInboundMessageSize} instead. This method will be removed in a * future release. */ + @CanIgnoreReturnValue @Deprecated @InlineMe(replacement = "this.maxInboundMessageSize(maxMessageSize)") public NettyServerBuilder maxMessageSize(int maxMessageSize) { @@ -444,6 +452,7 @@ public NettyServerBuilder maxMessageSize(int maxMessageSize) { } /** {@inheritDoc} */ + @CanIgnoreReturnValue @Override public NettyServerBuilder maxInboundMessageSize(int bytes) { checkArgument(bytes >= 0, "bytes must be non-negative: %s", bytes); @@ -459,6 +468,7 @@ public NettyServerBuilder maxInboundMessageSize(int bytes) { * * @deprecated Use {@link #maxInboundMetadataSize} instead */ + @CanIgnoreReturnValue @Deprecated @InlineMe(replacement = "this.maxInboundMetadataSize(maxHeaderListSize)") public NettyServerBuilder maxHeaderListSize(int maxHeaderListSize) { @@ -476,6 +486,7 @@ public NettyServerBuilder maxHeaderListSize(int maxHeaderListSize) { * @throws IllegalArgumentException if bytes is non-positive * @since 1.17.0 */ + @CanIgnoreReturnValue @Override public NettyServerBuilder maxInboundMetadataSize(int bytes) { checkArgument(bytes > 0, "maxInboundMetadataSize must be positive: %s", bytes); @@ -490,6 +501,8 @@ public NettyServerBuilder maxInboundMetadataSize(int bytes) { * * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { checkArgument(keepAliveTime > 0L, "keepalive time must be positive:%s", keepAliveTime); keepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime); @@ -511,6 +524,8 @@ public NettyServerBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { * * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { checkArgument(keepAliveTimeout > 0L, "keepalive timeout must be positive: %s", keepAliveTimeout); @@ -533,6 +548,8 @@ public NettyServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeU * * @since 1.4.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder maxConnectionIdle(long maxConnectionIdle, TimeUnit timeUnit) { checkArgument(maxConnectionIdle > 0L, "max connection idle must be positive: %s", maxConnectionIdle); @@ -554,6 +571,8 @@ public NettyServerBuilder maxConnectionIdle(long maxConnectionIdle, TimeUnit tim * * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { checkArgument(maxConnectionAge > 0L, "max connection age must be positive: %s", maxConnectionAge); @@ -576,6 +595,8 @@ public NettyServerBuilder maxConnectionAge(long maxConnectionAge, TimeUnit timeU * @see #maxConnectionAge(long, TimeUnit) * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { checkArgument(maxConnectionAgeGrace >= 0L, "max connection age grace must be non-negative: %s", maxConnectionAgeGrace); @@ -600,6 +621,8 @@ public NettyServerBuilder maxConnectionAgeGrace(long maxConnectionAgeGrace, Time * @see #permitKeepAliveWithoutCalls(boolean) * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s", keepAliveTime); @@ -614,6 +637,8 @@ public NettyServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeU * @see #permitKeepAliveTime(long, TimeUnit) * @since 1.3.0 */ + @CanIgnoreReturnValue + @Override public NettyServerBuilder permitKeepAliveWithoutCalls(boolean permit) { permitKeepAliveWithoutCalls = permit; return this; @@ -624,7 +649,6 @@ void eagAttributes(Attributes eagAttributes) { this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); } - @CheckReturnValue NettyServer buildTransportServers( List streamTracerFactories) { assertEventLoopsAndChannelType(); @@ -657,12 +681,13 @@ void assertEventLoopsAndChannelType() { + "neither should be"); } - NettyServerBuilder setTransportTracerFactory( - TransportTracer.Factory transportTracerFactory) { + @CanIgnoreReturnValue + NettyServerBuilder setTransportTracerFactory(TransportTracer.Factory transportTracerFactory) { this.transportTracerFactory = transportTracerFactory; return this; } + @CanIgnoreReturnValue @Override public NettyServerBuilder useTransportSecurity(File certChain, File privateKey) { checkState(!freezeProtocolNegotiatorFactory, @@ -678,6 +703,7 @@ public NettyServerBuilder useTransportSecurity(File certChain, File privateKey) return this; } + @CanIgnoreReturnValue @Override public NettyServerBuilder useTransportSecurity(InputStream certChain, InputStream privateKey) { checkState(!freezeProtocolNegotiatorFactory, diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index f552b937a05..6382471f46a 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -34,6 +34,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.base.Ticker; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; @@ -44,14 +45,17 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.LogExceptionRunnable; +import io.grpc.internal.MaxConnectionIdleManager; import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; @@ -187,7 +191,8 @@ static NettyServerHandler newHandler( maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, - eagAttributes); + eagAttributes, + Ticker.systemTicker()); } static NettyServerHandler newHandler( @@ -209,7 +214,8 @@ static NettyServerHandler newHandler( long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + Ticker ticker) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive: %s", flowControlWindow); @@ -242,6 +248,10 @@ static NettyServerHandler newHandler( settings.maxConcurrentStreams(maxStreams); settings.maxHeaderListSize(maxHeaderListSize); + if (ticker == null) { + ticker = Ticker.systemTicker(); + } + return new NettyServerHandler( channelUnused, connection, @@ -255,7 +265,7 @@ static NettyServerHandler newHandler( maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, - eagAttributes); + eagAttributes, ticker); } private NettyServerHandler( @@ -275,24 +285,16 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, - Attributes eagAttributes) { + Attributes eagAttributes, + Ticker ticker) { super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), - autoFlowControl, null); + autoFlowControl, null, ticker); final MaxConnectionIdleManager maxConnectionIdleManager; if (maxConnectionIdleInNanos == MAX_CONNECTION_IDLE_NANOS_DISABLED) { maxConnectionIdleManager = null; } else { - maxConnectionIdleManager = new MaxConnectionIdleManager(maxConnectionIdleInNanos) { - @Override - void close(ChannelHandlerContext ctx) { - if (gracefulShutdown == null) { - gracefulShutdown = new GracefulShutdown("max_idle", null); - gracefulShutdown.start(ctx); - ctx.flush(); - } - } - }; + maxConnectionIdleManager = new MaxConnectionIdleManager(maxConnectionIdleInNanos); } connection.addListener(new Http2ConnectionAdapter() { @@ -331,7 +333,6 @@ public void onStreamClosed(Http2Stream stream) { this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); this.transportTracer = checkNotNull(transportTracer, "transportTracer"); - // Set the frame listener on the decoder. decoder().frameListener(new FrameListener()); } @@ -363,7 +364,16 @@ public void run() { } if (maxConnectionIdleManager != null) { - maxConnectionIdleManager.start(ctx); + maxConnectionIdleManager.start(new Runnable() { + @Override + public void run() { + if (gracefulShutdown == null) { + gracefulShutdown = new GracefulShutdown("max_idle", null); + gracefulShutdown.start(ctx); + ctx.flush(); + } + } + }, ctx.executor()); } if (keepAliveTimeInNanos != SERVER_KEEPALIVE_TIME_NANOS_DISABLED) { @@ -854,6 +864,9 @@ public void onHeadersRead(ChannelHandlerContext ctx, keepAliveManager.onDataReceived(); } NettyServerHandler.this.onHeadersRead(ctx, streamId, headers); + if (endStream) { + NettyServerHandler.this.onDataRead(streamId, Unpooled.EMPTY_BUFFER, 0, endStream); + } } @Override @@ -890,10 +903,8 @@ public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exce } if (data == flowControlPing().payload()) { flowControlPing().updateWindow(); - if (logger.isLoggable(Level.FINE)) { - logger.log(Level.FINE, String.format("Window: %d", - decoder().flowController().initialWindowSize(connection().connectionStream()))); - } + logger.log(Level.FINE, "Window: {0}", + decoder().flowController().initialWindowSize(connection().connectionStream())); } else if (data == GRACEFUL_SHUTDOWN_PING) { if (gracefulShutdown == null) { // this should never happen diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index 3850a6a291c..6ab391b260c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -108,10 +108,6 @@ public void writeHeaders(Metadata headers) { private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) { Preconditions.checkArgument(numMessages >= 0); - if (frame == null) { - writeQueue.scheduleFlush(); - return; - } ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); final int numBytes = bytebuf.readableBytes(); // Add the bytes to outbound flow control. diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolver.java b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java new file mode 100644 index 00000000000..8fa8ea06250 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java @@ -0,0 +1,66 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Preconditions; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.netty.channel.unix.DomainSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +final class UdsNameResolver extends NameResolver { + private NameResolver.Listener2 listener; + private final String authority; + + UdsNameResolver(String authority, String targetPath) { + checkArgument(authority == null, "non-null authority not supported"); + this.authority = targetPath; + } + + @Override + public String getServiceAuthority() { + return this.authority; + } + + @Override + public void start(Listener2 listener) { + Preconditions.checkState(this.listener == null, "already started"); + this.listener = checkNotNull(listener, "listener"); + resolve(); + } + + @Override + public void refresh() { + resolve(); + } + + private void resolve() { + ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); + List servers = new ArrayList<>(1); + servers.add(new EquivalentAddressGroup(new DomainSocketAddress(authority))); + resolutionResultBuilder.setAddresses(Collections.unmodifiableList(servers)); + listener.onResult(resolutionResultBuilder.build()); + } + + @Override + public void shutdown() {} +} diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java new file mode 100644 index 00000000000..ffc07ff6ecb --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.base.Preconditions; +import io.grpc.Internal; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.netty.channel.unix.DomainSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; + +@Internal +public final class UdsNameResolverProvider extends NameResolverProvider { + + private static final String SCHEME = "unix"; + + @Override + public UdsNameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + return new UdsNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri)); + } else { + return null; + } + } + + static String getTargetPathFromUri(URI targetUri) { + Preconditions.checkArgument(SCHEME.equals(targetUri.getScheme()), "scheme must be " + SCHEME); + String targetPath = targetUri.getPath(); + if (targetPath == null) { + targetPath = Preconditions.checkNotNull(targetUri.getSchemeSpecificPart(), "targetPath"); + } + return targetPath; + } + + @Override + public String getDefaultScheme() { + return SCHEME; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 3; + } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(DomainSocketAddress.class); + } +} diff --git a/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java new file mode 100644 index 00000000000..59b50657a69 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java @@ -0,0 +1,67 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.ChannelCredentials; +import io.grpc.Internal; +import io.grpc.ManagedChannelProvider; +import io.grpc.internal.SharedResourcePool; +import io.netty.channel.unix.DomainSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; + +/** Provider for {@link NettyChannelBuilder} instances for UDS channels. */ +@Internal +public final class UdsNettyChannelProvider extends ManagedChannelProvider { + + @Override + public boolean isAvailable() { + return (Utils.EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE != null); + } + + @Override + public int priority() { + return 3; + } + + @Override + public NettyChannelBuilder builderForAddress(String name, int port) { + throw new AssertionError("NettyChannelProvider shadows this implementation"); + } + + @Override + public NettyChannelBuilder builderForTarget(String target) { + throw new AssertionError("NettyChannelProvider shadows this implementation"); + } + + @Override + public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) { + NewChannelBuilderResult result = new NettyChannelProvider().newChannelBuilder(target, creds); + if (result.getChannelBuilder() != null) { + ((NettyChannelBuilder) result.getChannelBuilder()) + .eventLoopGroupPool(SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP)) + .channelType(Utils.EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE); + } + return result; + } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(DomainSocketAddress.class); + } +} diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index c2f2fa4a7bf..fdff8567294 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -89,6 +89,7 @@ class Utils { = new DefaultEventLoopGroupResource(1, "grpc-nio-boss-ELG", EventLoopGroupType.NIO); public static final Resource NIO_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource(0, "grpc-nio-worker-ELG", EventLoopGroupType.NIO); + public static final Resource DEFAULT_BOSS_EVENT_LOOP_GROUP; public static final Resource DEFAULT_WORKER_EVENT_LOOP_GROUP; @@ -104,6 +105,7 @@ private static final class ByteBufAllocatorPreferHeapHolder { public static final ChannelFactory DEFAULT_SERVER_CHANNEL_FACTORY; public static final Class DEFAULT_CLIENT_CHANNEL_TYPE; + public static final Class EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE; @Nullable private static final Constructor EPOLL_EVENT_LOOP_GROUP_CONSTRUCTOR; @@ -112,6 +114,7 @@ private static final class ByteBufAllocatorPreferHeapHolder { // Decide default channel types and EventLoopGroup based on Epoll availability if (isEpollAvailable()) { DEFAULT_CLIENT_CHANNEL_TYPE = epollChannelType(); + EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE = epollDomainSocketChannelType(); DEFAULT_SERVER_CHANNEL_FACTORY = new ReflectiveChannelFactory<>(epollServerChannelType()); EPOLL_EVENT_LOOP_GROUP_CONSTRUCTOR = epollEventLoopGroupConstructor(); DEFAULT_BOSS_EVENT_LOOP_GROUP @@ -122,6 +125,7 @@ private static final class ByteBufAllocatorPreferHeapHolder { logger.log(Level.FINE, "Epoll is not available, using Nio.", getEpollUnavailabilityCause()); DEFAULT_SERVER_CHANNEL_FACTORY = nioServerChannelFactory(); DEFAULT_CLIENT_CHANNEL_TYPE = NioSocketChannel.class; + EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE = null; DEFAULT_BOSS_EVENT_LOOP_GROUP = NIO_BOSS_EVENT_LOOP_GROUP; DEFAULT_WORKER_EVENT_LOOP_GROUP = NIO_WORKER_EVENT_LOOP_GROUP; EPOLL_EVENT_LOOP_GROUP_CONSTRUCTOR = null; @@ -326,6 +330,17 @@ private static Class epollChannelType() { } } + // Must call when epoll is available + private static Class epollDomainSocketChannelType() { + try { + Class channelType = Class + .forName("io.netty.channel.epoll.EpollDomainSocketChannel").asSubclass(Channel.class); + return channelType; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Cannot load EpollDomainSocketChannel", e); + } + } + // Must call when epoll is available private static Constructor epollEventLoopGroupConstructor() { try { @@ -451,8 +466,11 @@ static final class FlowControlReader implements TransportTracer.FlowControlReade private final Http2FlowController remote; FlowControlReader(Http2Connection connection) { - local = connection.local().flowController(); - remote = connection.remote().flowController(); + // 'local' in Netty is the _controller_ that controls inbound data. 'local' in Channelz is + // the _present window_ provided by the remote that allows data to be sent. They are + // opposites. + local = connection.remote().flowController(); + remote = connection.local().flowController(); connectionStream = connection.connectionStream(); } diff --git a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java index 100367625fa..2799dfccb61 100644 --- a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java +++ b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java @@ -184,7 +184,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { */ @Override public void flush(ChannelHandlerContext ctx) { - /** + /* * Swallowing any flushes is not only an optimization but also required * for the SslHandler to work correctly. If the SslHandler receives multiple * flushes while the handshake is still ongoing, then the handshake "randomly" diff --git a/netty/src/main/resources/META-INF/services/io.grpc.ManagedChannelProvider b/netty/src/main/resources/META-INF/services/io.grpc.ManagedChannelProvider index ebd1bcdf024..e7b37ea49ac 100644 --- a/netty/src/main/resources/META-INF/services/io.grpc.ManagedChannelProvider +++ b/netty/src/main/resources/META-INF/services/io.grpc.ManagedChannelProvider @@ -1 +1,2 @@ io.grpc.netty.NettyChannelProvider +io.grpc.netty.UdsNettyChannelProvider diff --git a/netty/src/main/resources/META-INF/services/io.grpc.NameResolverProvider b/netty/src/main/resources/META-INF/services/io.grpc.NameResolverProvider new file mode 100644 index 00000000000..ec775013c1e --- /dev/null +++ b/netty/src/main/resources/META-INF/services/io.grpc.NameResolverProvider @@ -0,0 +1 @@ +io.grpc.netty.UdsNameResolverProvider diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index 6b5a96b45ab..fafe33c0ee2 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -16,8 +16,10 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.util.concurrent.MoreExecutors; @@ -40,7 +42,6 @@ import io.grpc.util.AdvancedTlsX509TrustManager.SslSocketAndEnginePeerVerifier; import io.grpc.util.AdvancedTlsX509TrustManager.Verification; import io.grpc.util.CertificateUtils; - import java.io.Closeable; import java.io.File; import java.io.IOException; @@ -57,9 +58,7 @@ import javax.net.ssl.SSLEngine; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -90,9 +89,6 @@ public class AdvancedTlsTest { private PrivateKey serverKeyBad; private X509Certificate[] serverCertBad; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - @Before public void setUp() throws NoSuchAlgorithmException, IOException, CertificateException, InvalidKeySpecException { @@ -428,24 +424,22 @@ public void onFileLoadingKeyManagerTrustManagerTest() throws Exception { @Test public void onFileReloadingKeyManagerBadInitialContentTest() throws Exception { - exceptionRule.expect(GeneralSecurityException.class); AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); // We swap the order of key and certificates to intentionally create an exception. - Closeable keyShutdown = keyManager.updateIdentityCredentialsFromFile(serverCert0File, - serverKey0File, 100, TimeUnit.MILLISECONDS, executor); - keyShutdown.close(); + assertThrows(GeneralSecurityException.class, + () -> keyManager.updateIdentityCredentialsFromFile(serverCert0File, + serverKey0File, 100, TimeUnit.MILLISECONDS, executor)); } @Test public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception { - exceptionRule.expect(GeneralSecurityException.class); AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); // We pass in a key as the trust certificates to intentionally create an exception. - Closeable trustShutdown = trustManager.updateTrustCredentialsFromFile(serverKey0File, - 100, TimeUnit.MILLISECONDS, executor); - trustShutdown.close(); + assertThrows(GeneralSecurityException.class, + () -> trustManager.updateTrustCredentialsFromFile(serverKey0File, + 100, TimeUnit.MILLISECONDS, executor)); } @Test @@ -473,40 +467,38 @@ public void trustManagerCheckTrustedWithSocketTest() throws Exception { @Test public void trustManagerCheckClientTrustedWithoutParameterTest() throws Exception { - exceptionRule.expect(CertificateException.class); - exceptionRule.expectMessage( - "Not enough information to validate peer. SSLEngine or Socket required."); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); - tm.checkClientTrusted(serverCert0, "RSA"); + CertificateException ex = + assertThrows(CertificateException.class, () -> tm.checkClientTrusted(serverCert0, "RSA")); + assertThat(ex).hasMessageThat() + .isEqualTo("Not enough information to validate peer. SSLEngine or Socket required."); } @Test public void trustManagerCheckServerTrustedWithoutParameterTest() throws Exception { - exceptionRule.expect(CertificateException.class); - exceptionRule.expectMessage( - "Not enough information to validate peer. SSLEngine or Socket required."); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); - tm.checkServerTrusted(serverCert0, "RSA"); + CertificateException ex = + assertThrows(CertificateException.class, () -> tm.checkServerTrusted(serverCert0, "RSA")); + assertThat(ex).hasMessageThat() + .isEqualTo("Not enough information to validate peer. SSLEngine or Socket required."); } @Test public void trustManagerEmptyChainTest() throws Exception { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage( - "Want certificate verification but got null or empty certificates"); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); tm.updateTrustCredentials(caCert); - tm.checkClientTrusted(null, "RSA", (SSLEngine) null); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> tm.checkClientTrusted(null, "RSA", (SSLEngine) null)); + assertThat(ex).hasMessageThat() + .isEqualTo("Want certificate verification but got null or empty certificates"); } @Test public void trustManagerBadCustomVerificationTest() throws Exception { - exceptionRule.expect(CertificateException.class); - exceptionRule.expectMessage("Bad Custom Verification"); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( @@ -524,7 +516,10 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy } }).build(); tm.updateTrustCredentials(caCert); - tm.checkClientTrusted(serverCert0, "RSA", new Socket()); + CertificateException ex = assertThrows( + CertificateException.class, + () -> tm.checkClientTrusted(serverCert0, "RSA", new Socket())); + assertThat(ex).hasMessageThat().isEqualTo("Bad Custom Verification"); } private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { diff --git a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java index 11488b752f1..48c6320f4c6 100644 --- a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java +++ b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java @@ -22,6 +22,7 @@ import static io.netty.util.AsciiString.of; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import com.google.common.collect.Iterables; import com.google.common.io.BaseEncoding; @@ -133,6 +134,130 @@ public void decode_emptyHeaders() throws Http2Exception { assertThat(decodedHeaders.toString()).contains("[]"); } + // contains() is used by Netty 4.1.75+. https://github.com/grpc/grpc-java/issues/8981 + // Just implement everything pseudo headers for all methods; too many recent breakages. + @Test + public void grpcHttp2RequestHeaders_pseudoHeaders_notPresent() { + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + assertThat(http2Headers.get(AsciiString.of(":path"))).isNull(); + assertThat(http2Headers.get(AsciiString.of(":authority"))).isNull(); + assertThat(http2Headers.get(AsciiString.of(":method"))).isNull(); + assertThat(http2Headers.get(AsciiString.of(":scheme"))).isNull(); + assertThat(http2Headers.get(AsciiString.of(":status"))).isNull(); + + assertThat(http2Headers.getAll(AsciiString.of(":path"))).isEmpty(); + assertThat(http2Headers.getAll(AsciiString.of(":authority"))).isEmpty(); + assertThat(http2Headers.getAll(AsciiString.of(":method"))).isEmpty(); + assertThat(http2Headers.getAll(AsciiString.of(":scheme"))).isEmpty(); + assertThat(http2Headers.getAll(AsciiString.of(":status"))).isEmpty(); + + assertThat(http2Headers.contains(AsciiString.of(":path"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":authority"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":method"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":scheme"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":status"))).isFalse(); + + assertThat(http2Headers.remove(AsciiString.of(":path"))).isFalse(); + assertThat(http2Headers.remove(AsciiString.of(":authority"))).isFalse(); + assertThat(http2Headers.remove(AsciiString.of(":method"))).isFalse(); + assertThat(http2Headers.remove(AsciiString.of(":scheme"))).isFalse(); + assertThat(http2Headers.remove(AsciiString.of(":status"))).isFalse(); + } + + @Test + public void grpcHttp2RequestHeaders_pseudoHeaders_present() { + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + http2Headers.add(AsciiString.of(":path"), AsciiString.of("mypath")); + http2Headers.add(AsciiString.of(":authority"), AsciiString.of("myauthority")); + http2Headers.add(AsciiString.of(":method"), AsciiString.of("mymethod")); + http2Headers.add(AsciiString.of(":scheme"), AsciiString.of("myscheme")); + + assertThat(http2Headers.get(AsciiString.of(":path"))).isEqualTo(AsciiString.of("mypath")); + assertThat(http2Headers.get(AsciiString.of(":authority"))) + .isEqualTo(AsciiString.of("myauthority")); + assertThat(http2Headers.get(AsciiString.of(":method"))).isEqualTo(AsciiString.of("mymethod")); + assertThat(http2Headers.get(AsciiString.of(":scheme"))).isEqualTo(AsciiString.of("myscheme")); + + assertThat(http2Headers.getAll(AsciiString.of(":path"))) + .containsExactly(AsciiString.of("mypath")); + assertThat(http2Headers.getAll(AsciiString.of(":authority"))) + .containsExactly(AsciiString.of("myauthority")); + assertThat(http2Headers.getAll(AsciiString.of(":method"))) + .containsExactly(AsciiString.of("mymethod")); + assertThat(http2Headers.getAll(AsciiString.of(":scheme"))) + .containsExactly(AsciiString.of("myscheme")); + + assertThat(http2Headers.contains(AsciiString.of(":path"))).isTrue(); + assertThat(http2Headers.contains(AsciiString.of(":authority"))).isTrue(); + assertThat(http2Headers.contains(AsciiString.of(":method"))).isTrue(); + assertThat(http2Headers.contains(AsciiString.of(":scheme"))).isTrue(); + + assertThat(http2Headers.remove(AsciiString.of(":path"))).isTrue(); + assertThat(http2Headers.remove(AsciiString.of(":authority"))).isTrue(); + assertThat(http2Headers.remove(AsciiString.of(":method"))).isTrue(); + assertThat(http2Headers.remove(AsciiString.of(":scheme"))).isTrue(); + + assertThat(http2Headers.contains(AsciiString.of(":path"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":authority"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":method"))).isFalse(); + assertThat(http2Headers.contains(AsciiString.of(":scheme"))).isFalse(); + } + + @Test + public void grpcHttp2RequestHeaders_pseudoHeaders_set() { + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + http2Headers.set(AsciiString.of(":path"), AsciiString.of("mypath")); + http2Headers.set(AsciiString.of(":authority"), AsciiString.of("myauthority")); + http2Headers.set(AsciiString.of(":method"), AsciiString.of("mymethod")); + http2Headers.set(AsciiString.of(":scheme"), AsciiString.of("myscheme")); + + assertThat(http2Headers.getAll(AsciiString.of(":path"))) + .containsExactly(AsciiString.of("mypath")); + assertThat(http2Headers.getAll(AsciiString.of(":authority"))) + .containsExactly(AsciiString.of("myauthority")); + assertThat(http2Headers.getAll(AsciiString.of(":method"))) + .containsExactly(AsciiString.of("mymethod")); + assertThat(http2Headers.getAll(AsciiString.of(":scheme"))) + .containsExactly(AsciiString.of("myscheme")); + + http2Headers.set(AsciiString.of(":path"), AsciiString.of("mypath2")); + http2Headers.set(AsciiString.of(":authority"), AsciiString.of("myauthority2")); + http2Headers.set(AsciiString.of(":method"), AsciiString.of("mymethod2")); + http2Headers.set(AsciiString.of(":scheme"), AsciiString.of("myscheme2")); + + assertThat(http2Headers.getAll(AsciiString.of(":path"))) + .containsExactly(AsciiString.of("mypath2")); + assertThat(http2Headers.getAll(AsciiString.of(":authority"))) + .containsExactly(AsciiString.of("myauthority2")); + assertThat(http2Headers.getAll(AsciiString.of(":method"))) + .containsExactly(AsciiString.of("mymethod2")); + assertThat(http2Headers.getAll(AsciiString.of(":scheme"))) + .containsExactly(AsciiString.of("myscheme2")); + } + + @Test + public void grpcHttp2RequestHeaders_pseudoHeaders_addWhenPresent_throws() { + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + http2Headers.add(AsciiString.of(":path"), AsciiString.of("mypath")); + try { + http2Headers.add(AsciiString.of(":path"), AsciiString.of("mypath2")); + fail("Expected exception"); + } catch (Exception ex) { + // expected + } + } + + @Test + public void grpcHttp2RequestHeaders_pseudoHeaders_addInvalid_throws() { + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + try { + http2Headers.add(AsciiString.of(":status"), AsciiString.of("mystatus")); + fail("Expected exception"); + } catch (Exception ex) { + // expected + } + } + @Test public void dupBinHeadersWithComma() { Key key = Key.of("bytes-bin", BINARY_BYTE_MARSHALLER); diff --git a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java new file mode 100644 index 00000000000..6a0c00bac0e --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java @@ -0,0 +1,641 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static com.google.common.truth.TruthJUnit.assume; +import static io.netty.util.CharsetUtil.US_ASCII; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Enclosed.class) +public class NettyAdaptiveCumulatorTest { + + private static Collection cartesianProductParams(List... lists) { + return Lists.cartesianProduct(lists).stream().map(List::toArray).collect(Collectors.toList()); + } + + @RunWith(JUnit4.class) + public static class CumulateTests { + // Represent data as immutable ASCII Strings for easy and readable ByteBuf equality assertions. + private static final String DATA_INITIAL = "0123"; + private static final String DATA_INCOMING = "456789"; + private static final String DATA_CUMULATED = "0123456789"; + + private static final ByteBufAllocator alloc = new UnpooledByteBufAllocator(false); + private NettyAdaptiveCumulator cumulator; + private NettyAdaptiveCumulator throwingCumulator; + private final UnsupportedOperationException throwingCumulatorError = + new UnsupportedOperationException(); + + // Buffers for testing + private final ByteBuf contiguous = ByteBufUtil.writeAscii(alloc, DATA_INITIAL); + private final ByteBuf in = ByteBufUtil.writeAscii(alloc, DATA_INCOMING); + + @Before + public void setUp() { + cumulator = new NettyAdaptiveCumulator(0) { + @Override + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + // To limit the testing scope to NettyAdaptiveCumulator.cumulate(), always compose + composite.addFlattenedComponents(true, in); + } + }; + + // Throws an error on adding incoming buffer. + throwingCumulator = new NettyAdaptiveCumulator(0) { + @Override + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + throw throwingCumulatorError; + } + }; + } + + @Test + public void cumulate_notReadableCumulation_replacedWithInputAndReleased() { + contiguous.readerIndex(contiguous.writerIndex()); + assertFalse(contiguous.isReadable()); + ByteBuf cumulation = cumulator.cumulate(alloc, contiguous, in); + assertEquals(DATA_INCOMING, cumulation.toString(US_ASCII)); + assertEquals(0, contiguous.refCnt()); + // In retained by cumulation. + assertEquals(1, in.refCnt()); + assertEquals(1, cumulation.refCnt()); + cumulation.release(); + } + + @Test + public void cumulate_contiguousCumulation_newCompositeFromContiguousAndInput() { + CompositeByteBuf cumulation = (CompositeByteBuf) cumulator.cumulate(alloc, contiguous, in); + assertEquals(DATA_INITIAL, cumulation.component(0).toString(US_ASCII)); + assertEquals(DATA_INCOMING, cumulation.component(1).toString(US_ASCII)); + assertEquals(DATA_CUMULATED, cumulation.toString(US_ASCII)); + // Both in and contiguous are retained by cumulation. + assertEquals(1, contiguous.refCnt()); + assertEquals(1, in.refCnt()); + assertEquals(1, cumulation.refCnt()); + cumulation.release(); + } + + @Test + public void cumulate_compositeCumulation_inputAppendedAsANewComponent() { + CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + assertSame(composite, cumulator.cumulate(alloc, composite, in)); + assertEquals(DATA_INITIAL, composite.component(0).toString(US_ASCII)); + assertEquals(DATA_INCOMING, composite.component(1).toString(US_ASCII)); + assertEquals(DATA_CUMULATED, composite.toString(US_ASCII)); + // Both in and contiguous are retained by cumulation. + assertEquals(1, contiguous.refCnt()); + assertEquals(1, in.refCnt()); + assertEquals(1, composite.refCnt()); + composite.release(); + } + + @Test + public void cumulate_compositeCumulation_inputReleasedOnError() { + CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + try { + throwingCumulator.cumulate(alloc, composite, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(throwingCumulatorError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // Initial composite cumulation owned by the caller in this case, so it isn't released. + assertEquals(1, composite.refCnt()); + // Contiguous still managed by the cumulation + assertEquals(1, contiguous.refCnt()); + } finally { + composite.release(); + } + } + + @Test + public void cumulate_contiguousCumulation_inputAndNewCompositeReleasedOnError() { + // Return our instance of new composite to ensure it's released. + CompositeByteBuf newComposite = alloc.compositeBuffer(Integer.MAX_VALUE); + ByteBufAllocator mockAlloc = mock(ByteBufAllocator.class); + when(mockAlloc.compositeBuffer(anyInt())).thenReturn(newComposite); + + try { + // Previous cumulation is non-composite, so cumulator will create anew composite and add + // both buffers to it. + throwingCumulator.cumulate(mockAlloc, contiguous, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(throwingCumulatorError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // New composite cumulation hasn't been returned to the caller, so it must be released. + assertEquals(0, newComposite.refCnt()); + // Previous cumulation released because it was owned by the new composite cumulation. + assertEquals(0, contiguous.refCnt()); + } + } + } + + @RunWith(Parameterized.class) + public static class ShouldComposeTests { + // Represent data as immutable ASCII Strings for easy and readable ByteBuf equality assertions. + private static final String DATA_INITIAL = "0123"; + private static final String DATA_INCOMING = "456789"; + + /** + * Cartesian product of the test values. + */ + @Parameters(name = "composeMinSize={0}, tailData=\"{1}\", inData=\"{2}\"") + public static Collection params() { + List composeMinSize = ImmutableList.of(0, 9, 10, 11, Integer.MAX_VALUE); + List tailData = ImmutableList.of("", DATA_INITIAL); + List inData = ImmutableList.of("", DATA_INCOMING); + return cartesianProductParams(composeMinSize, tailData, inData); + } + + @Parameter public int composeMinSize; + @Parameter(1) public String tailData; + @Parameter(2) public String inData; + + private CompositeByteBuf composite; + private ByteBuf tail; + private ByteBuf in; + + @Before + public void setUp() { + ByteBufAllocator alloc = new UnpooledByteBufAllocator(false); + in = ByteBufUtil.writeAscii(alloc, inData); + tail = ByteBufUtil.writeAscii(alloc, tailData); + composite = alloc.compositeBuffer(Integer.MAX_VALUE); + // Note that addFlattenedComponents() will not add a new component when tail is not readable. + composite.addFlattenedComponents(true, tail); + } + + @After + public void tearDown() { + in.release(); + composite.release(); + } + + @Test + public void shouldCompose_emptyComposite() { + assume().that(composite.numComponents()).isEqualTo(0); + assertTrue(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + + @Test + public void shouldCompose_composeMinSizeReached() { + assume().that(composite.numComponents()).isGreaterThan(0); + assume().that(tail.readableBytes() + in.readableBytes()).isAtLeast(composeMinSize); + assertTrue(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + + @Test + public void shouldCompose_composeMinSizeNotReached() { + assume().that(composite.numComponents()).isGreaterThan(0); + assume().that(tail.readableBytes() + in.readableBytes()).isLessThan(composeMinSize); + assertFalse(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + } + + @RunWith(Parameterized.class) + public static class MergeWithCompositeTailTests { + private static final String INCOMING_DATA_READABLE = "+incoming"; + private static final String INCOMING_DATA_DISCARDABLE = "discard"; + + private static final String TAIL_DATA_DISCARDABLE = "---"; + private static final String TAIL_DATA_READABLE = "tail"; + private static final String TAIL_DATA = TAIL_DATA_DISCARDABLE + TAIL_DATA_READABLE; + private static final int TAIL_READER_INDEX = TAIL_DATA_DISCARDABLE.length(); + private static final int TAIL_MAX_CAPACITY = 128; + + // DRY sacrificed to improve readability. + private static final String EXPECTED_TAIL_DATA = "tail+incoming"; + + /** + * Cartesian product of the test values. + * + *

Test cases when the cumulation contains components, other than tail, and could be + * partially read. This is needed to verify the correctness if reader and writer indexes of the + * composite cumulation after the merge. + */ + @Parameters(name = "compositeHeadData=\"{0}\", compositeReaderIndex={1}") + public static Collection params() { + String headData = "head"; + + List compositeHeadData = ImmutableList.of( + // Test without the "head" component. Empty string is equivalent of fully read buffer, + // so it's not added to the composite byte buf. The tail is added as the first component. + "", + // Test with the "head" component, so the tail is added as the second component. + headData + ); + + // After the tail is added to the composite cumulator, advance the reader index to + // cover different cases. + // The reader index only looks at what's readable in the composite byte buf, so + // discardable bytes of head and tail doesn't count. + List compositeReaderIndex = ImmutableList.of( + // Reader in the beginning + 0, + // Within the head (when present) or the tail + headData.length() - 2, + // Within the tail, even if the head is present + headData.length() + 2 + ); + return cartesianProductParams(compositeHeadData, compositeReaderIndex); + } + + @Parameter public String compositeHeadData; + @Parameter(1) public int compositeReaderIndex; + + // Use pooled allocator to have maxFastWritableBytes() behave differently than writableBytes(). + private final ByteBufAllocator alloc = new PooledByteBufAllocator(); + + // Composite buffer to be used in tests. + private CompositeByteBuf composite; + private ByteBuf tail; + private ByteBuf in; + + @Before + public void setUp() { + composite = alloc.compositeBuffer(); + + // The "head" component. It represents existing data in the cumulator. + // Note that addFlattenedComponents() does not add completely read buffer, which covers + // the case when compositeHeadData parameter is an empty string. + ByteBuf head = alloc.buffer().writeBytes(compositeHeadData.getBytes(US_ASCII)); + composite.addFlattenedComponents(true, head); + + // The "tail" component. It also represents existing data in the cumulator, but it's + // not added to the cumulator during setUp() stage. It is to be manipulated by tests to + // produce different buffer write scenarios based on different tail's capacity. + // After tail is changes for each test scenario, it's added to the composite buffer. + // + // The default state of the tail before each test: tail is full, but expandable (the data uses + // all initial capacity, but not maximum capacity). + // Tail data and indexes: + // ----tail + // r w + tail = alloc.buffer(TAIL_DATA.length(), TAIL_MAX_CAPACITY) + .writeBytes(TAIL_DATA.getBytes(US_ASCII)) + .readerIndex(TAIL_READER_INDEX); + + // Incoming data and indexes: + // discard+incoming + // r w + in = alloc.buffer() + .writeBytes(INCOMING_DATA_DISCARDABLE.getBytes(US_ASCII)) + .writeBytes(INCOMING_DATA_READABLE.getBytes(US_ASCII)) + .readerIndex(INCOMING_DATA_DISCARDABLE.length()); + } + + @After + public void tearDown() { + composite.release(); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_write() { + // Make incoming data fit into tail capacity. + int fitCapacity = tail.capacity() + INCOMING_DATA_READABLE.length(); + tail.capacity(fitCapacity); + // Confirm it fits. + assertThat(in.readableBytes()).isAtMost(tail.writableBytes()); + + // All fits, so tail capacity must stay the same. + composite.addFlattenedComponents(true, tail); + assertTailExpanded(EXPECTED_TAIL_DATA, fitCapacity); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_fastWrite() { + // Confirm that the tail can be expanded fast to fit the incoming data. + assertThat(in.readableBytes()).isAtMost(tail.maxFastWritableBytes()); + + // To avoid undesirable buffer unwrapping, at the moment adaptive cumulator is set not + // apply fastWrite technique. Even when fast write is possible, it will fall back to + // reallocating a larger buffer. + // int tailFastCapacity = tail.writerIndex() + tail.maxFastWritableBytes(); + int tailFastCapacity = + alloc.calculateNewCapacity(EXPECTED_TAIL_DATA.length(), Integer.MAX_VALUE); + + // Tail capacity is extended to its fast capacity. + composite.addFlattenedComponents(true, tail); + assertTailExpanded(EXPECTED_TAIL_DATA, tailFastCapacity); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_reallocateInMemory() { + int tailFastCapacity = tail.writerIndex() + tail.maxFastWritableBytes(); + String inSuffixOverFastBytes = Strings.repeat("a", tailFastCapacity + 1); + int newTailSize = tail.readableBytes() + inSuffixOverFastBytes.length(); + composite.addFlattenedComponents(true, tail); + + // Make input larger than tailFastCapacity + in.writeCharSequence(inSuffixOverFastBytes, US_ASCII); + // Confirm that the tail can only fit incoming data via reallocation. + assertThat(in.readableBytes()).isGreaterThan(tail.maxFastWritableBytes()); + assertThat(in.readableBytes()).isAtMost(tail.maxWritableBytes()); + + // Confirm the assumption that new capacity is produced by alloc.calculateNewCapacity(). + int expectedTailCapacity = alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE); + assertTailExpanded(EXPECTED_TAIL_DATA.concat(inSuffixOverFastBytes), expectedTailCapacity); + } + + private void assertTailExpanded(String expectedTailReadableData, int expectedNewTailCapacity) { + int originalNumComponents = composite.numComponents(); + + // Handle the case when reader index is beyond all readable bytes of the cumulation. + int compositeReaderIndexBounded = Math.min(compositeReaderIndex, composite.writerIndex()); + composite.readerIndex(compositeReaderIndexBounded); + + // Execute the merge logic. + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, composite, in); + + // Composite component count shouldn't change. + assertWithMessage( + "When tail is expanded, the number of components in the cumulation must not change") + .that(composite.numComponents()).isEqualTo(originalNumComponents); + + ByteBuf newTail = composite.component(composite.numComponents() - 1); + + // Verify the readable part of the expanded tail: + // 1. Initial readable bytes of the tail not changed + // 2. Discardable bytes (0 < discardable < readerIndex) of the incoming buffer are discarded. + // 3. Readable bytes of the incoming buffer are fully read and appended to the tail. + assertEquals(expectedTailReadableData, newTail.toString(US_ASCII)); + // Verify expanded capacity. + assertEquals(expectedNewTailCapacity, newTail.capacity()); + + // Discardable bytes (0 < discardable < readerIndex) of the tail are kept as is. + String newTailDataDiscardable = newTail.toString(0, newTail.readerIndex(), US_ASCII); + assertWithMessage("After tail expansion, its discardable bytes should be unchanged") + .that(newTailDataDiscardable).isEqualTo(TAIL_DATA_DISCARDABLE); + + // Reader index must stay where it was + assertEquals(TAIL_READER_INDEX, newTail.readerIndex()); + // Writer index at the end + assertEquals(TAIL_READER_INDEX + expectedTailReadableData.length(), + newTail.writerIndex()); + + // Verify resulting cumulation. + assertExpectedCumulation(newTail, expectedTailReadableData, compositeReaderIndexBounded); + + // Verify incoming buffer. + assertWithMessage("Incoming buffer is fully read").that(in.isReadable()).isFalse(); + assertWithMessage("Incoming buffer is released").that(in.refCnt()).isEqualTo(0); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_maxCapacityReached() { + // Fill in tail to the maxCapacity. + String tailSuffixFullCapacity = Strings.repeat("a", tail.maxWritableBytes()); + tail.writeCharSequence(tailSuffixFullCapacity, US_ASCII); + composite.addFlattenedComponents(true, tail); + assertTailReplaced(); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_shared() { + tail.retain(); + composite.addFlattenedComponents(true, tail); + assertTailReplaced(); + tail.release(); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_readOnly() { + composite.addFlattenedComponents(true, tail.asReadOnly()); + assertTailReplaced(); + } + + private void assertTailReplaced() { + int cumulationOriginalComponentsNum = composite.numComponents(); + int taiOriginalRefCount = tail.refCnt(); + String expectedTailReadable = tail.toString(US_ASCII) + in.toString(US_ASCII); + int expectedReallocatedTailCapacity = alloc + .calculateNewCapacity(expectedTailReadable.length(), Integer.MAX_VALUE); + + int compositeReaderIndexBounded = Math.min(compositeReaderIndex, composite.writerIndex()); + composite.readerIndex(compositeReaderIndexBounded); + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, composite, in); + + // Composite component count shouldn't change. + assertEquals(cumulationOriginalComponentsNum, composite.numComponents()); + ByteBuf replacedTail = composite.component(composite.numComponents() - 1); + + // Verify the readable part of the expanded tail: + // 1. Discardable bytes (0 < discardable < readerIndex) of the tail are discarded. + // 2. Readable bytes of the tail are kept as is + // 3. Discardable bytes (0 < discardable < readerIndex) of the incoming buffer are discarded. + // 4. Readable bytes of the incoming buffer are fully read and appended to the tail. + assertEquals(0, in.readableBytes()); + assertEquals(expectedTailReadable, replacedTail.toString(US_ASCII)); + + // Since tail discardable bytes are discarded, new reader index must be reset to 0. + assertEquals(0, replacedTail.readerIndex()); + // And new writer index at the new data's length. + assertEquals(expectedTailReadable.length(), replacedTail.writerIndex()); + // Verify the capacity of reallocated tail. + assertEquals(expectedReallocatedTailCapacity, replacedTail.capacity()); + + // Verify resulting cumulation. + assertExpectedCumulation(replacedTail, expectedTailReadable, compositeReaderIndexBounded); + + // Verify incoming buffer. + assertWithMessage("Incoming buffer is fully read").that(in.isReadable()).isFalse(); + assertWithMessage("Incoming buffer is released").that(in.refCnt()).isEqualTo(0); + + // The old tail must be released once (have one less reference). + assertWithMessage("Replaced tail released once.") + .that(tail.refCnt()).isEqualTo(taiOriginalRefCount - 1); + } + + private void assertExpectedCumulation( + ByteBuf newTail, String expectedTailReadable, int expectedReaderIndex) { + // Verify the readable part of the cumulation: + // 1. Readable composite head (initial) data + // 2. Readable part of the tail + // 3. Readable part of the incoming data + String expectedCumulationData = + compositeHeadData.concat(expectedTailReadable).substring(expectedReaderIndex); + assertEquals(expectedCumulationData, composite.toString(US_ASCII)); + + // Cumulation capacity includes: + // 1. Full composite head, including discardable bytes + // 2. Expanded tail readable bytes + int expectedCumulationCapacity = compositeHeadData.length() + expectedTailReadable.length(); + assertEquals(expectedCumulationCapacity, composite.capacity()); + + // Composite Reader index must stay where it was. + assertEquals(expectedReaderIndex, composite.readerIndex()); + // Composite writer index must be at the end. + assertEquals(expectedCumulationCapacity, composite.writerIndex()); + + // Composite cumulation is retained and owns the new tail. + assertEquals(1, composite.refCnt()); + assertEquals(1, newTail.refCnt()); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_mergedReleaseOnThrow() { + final UnsupportedOperationException expectedError = new UnsupportedOperationException(); + CompositeByteBuf compositeThrows = new CompositeByteBuf(alloc, false, Integer.MAX_VALUE, + tail) { + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, + ByteBuf buffer) { + throw expectedError; + } + }; + + try { + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, compositeThrows, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(expectedError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // Tail released + assertEquals(0, tail.refCnt()); + // Composite cumulation is retained + assertEquals(1, compositeThrows.refCnt()); + // Composite cumulation loses the tail + assertEquals(0, compositeThrows.numComponents()); + } finally { + compositeThrows.release(); + } + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_mergedReleaseOnThrow() { + final UnsupportedOperationException expectedError = new UnsupportedOperationException(); + CompositeByteBuf compositeRo = new CompositeByteBuf(alloc, false, Integer.MAX_VALUE, + tail.asReadOnly()) { + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, + ByteBuf buffer) { + throw expectedError; + } + }; + + // Return our instance of the new buffer to ensure it's released. + int newTailSize = tail.readableBytes() + in.readableBytes(); + ByteBuf newTail = alloc.buffer(alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE)); + ByteBufAllocator mockAlloc = mock(ByteBufAllocator.class); + when(mockAlloc.buffer(anyInt())).thenReturn(newTail); + + try { + NettyAdaptiveCumulator.mergeWithCompositeTail(mockAlloc, compositeRo, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(expectedError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // New buffer released + assertEquals(0, newTail.refCnt()); + // Composite cumulation is retained + assertEquals(1, compositeRo.refCnt()); + // Composite cumulation loses the tail + assertEquals(0, compositeRo.numComponents()); + } finally { + compositeRo.release(); + } + } + } + + /** + * Miscellaneous tests for {@link NettyAdaptiveCumulator#mergeWithCompositeTail} that don't + * fit into {@link MergeWithCompositeTailTests}, and require custom-crafted scenarios. + */ + @RunWith(JUnit4.class) + public static class MergeWithCompositeTailMiscTests { + private final ByteBufAllocator alloc = new PooledByteBufAllocator(); + + /** + * Test the issue with {@link CompositeByteBuf#component(int)} returning a ByteBuf with + * the indexes out-of-sync with {@code CompositeByteBuf.Component} offsets. + */ + @Test + public void mergeWithCompositeTail_outOfSyncComposite() { + NettyAdaptiveCumulator cumulator = new NettyAdaptiveCumulator(1024); + + // Create underlying buffer spacious enough for the test data. + ByteBuf buf = alloc.buffer(32).writeBytes("---01234".getBytes(US_ASCII)); + + // Start with a regular cumulation and add the buf as the only component. + CompositeByteBuf composite1 = alloc.compositeBuffer(8).addFlattenedComponents(true, buf); + // Read composite1 buf to the beginning of the numbers. + assertThat(composite1.readCharSequence(3, US_ASCII).toString()).isEqualTo("---"); + + // Wrap composite1 into another cumulation. This is similar to + // what NettyAdaptiveCumulator.cumulate() does in the case the cumulation has refCnt != 1. + CompositeByteBuf composite2 = + alloc.compositeBuffer(8).addFlattenedComponents(true, composite1); + assertThat(composite2.toString(US_ASCII)).isEqualTo("01234"); + + // The previous operation does not adjust the read indexes of the underlying buffers, + // only the internal Component offsets. When the cumulator attempts to append the input to + // the tail buffer, it extracts it from the cumulation, writes to it, and then adds it back. + // Because the readerIndex on the tail buffer is not adjusted during the read operation + // on the CompositeByteBuf, adding the tail back results in the discarded bytes of the tail + // to be added back to the cumulator as if they were never read. + // + // If the reader index of the tail is not manually corrected, the resulting + // cumulation will contain the discarded part of the tail: "---". + // If it's corrected, it will only contain the numbers. + CompositeByteBuf cumulation = (CompositeByteBuf) cumulator.cumulate(alloc, composite2, + ByteBufUtil.writeAscii(alloc, "56789")); + assertThat(cumulation.toString(US_ASCII)).isEqualTo("0123456789"); + + // Correctness check: we still have a single component, and this component is still the + // original underlying buffer. + assertThat(cumulation.numComponents()).isEqualTo(1); + // Replace '2' with '*', and '8' with '$'. + buf.setByte(5, '*').setByte(11, '$'); + assertThat(cumulation.toString(US_ASCII)).isEqualTo("01*34567$9"); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index d47942858a3..5ec82446cd6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -54,9 +54,11 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; +import io.grpc.CallOptions; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.internal.AbstractStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport; @@ -68,6 +70,7 @@ import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder; +import io.grpc.testing.TestMethodDescriptors; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; @@ -118,7 +121,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase setKeepaliveManagerFor = ImmutableList.of("cancelShouldSucceed", @@ -136,12 +139,31 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); + private NettyClientStream stream; @Override protected void manualSetUp() throws Exception { setUp(); } + @Override + protected AbstractStream stream() throws Exception { + if (stream == null) { + stream = new NettyClientStream(streamTransportState, + TestMethodDescriptors.voidMethod(), + new Metadata(), + channel(), + AsciiString.of("localhost"), + AsciiString.of("http"), + AsciiString.of("agent"), + StatsTraceContext.NOOP, + transportTracer, + CallOptions.DEFAULT, + false); + } + return stream; + } + /** * Set up for test. */ @@ -201,7 +223,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio // Create a new stream with id 3. ChannelFuture createFuture = enqueue( newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); // Cancel the stream. cancelStream(Status.CANCELLED); @@ -212,7 +234,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio @Test public void createStreamShouldSucceed() throws Exception { createStream(); - verifyWrite().writeHeaders(eq(ctx()), eq(3), eq(grpcHeaders), eq(0), + verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(grpcHeaders), eq(0), eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class)); } @@ -221,7 +243,7 @@ public void cancelShouldSucceed() throws Exception { createStream(); cancelStream(Status.CANCELLED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verify(mockKeepAliveManager, times(1)).onTransportIdle(); // onStreamClosed @@ -233,7 +255,7 @@ public void cancelDeadlineExceededShouldSucceed() throws Exception { createStream(); cancelStream(Status.DEADLINE_EXCEEDED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); } @@ -262,7 +284,7 @@ public void cancelTwiceShouldSucceed() throws Exception { cancelStream(Status.CANCELLED); - verifyWrite().writeRstStream(any(ChannelHandlerContext.class), eq(3), + verifyWrite().writeRstStream(any(ChannelHandlerContext.class), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); ChannelFuture future = cancelStream(Status.CANCELLED); @@ -275,7 +297,7 @@ public void cancelTwiceDifferentReasons() throws Exception { cancelStream(Status.DEADLINE_EXCEEDED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); ChannelFuture future = cancelStream(Status.CANCELLED); @@ -291,7 +313,7 @@ public void sendFrameShouldSucceed() throws Exception { = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(3), eq(content()), eq(0), eq(true), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verifyNoMoreInteractions(mockKeepAliveManager); @@ -313,7 +335,7 @@ public void inboundShouldForwardToStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) .set(as("magic"), as("value")); - ByteBuf headersFrame = headersFrame(3, headers); + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); channelRead(headersFrame); ArgumentCaptor captor = ArgumentCaptor.forClass(Metadata.class); verify(streamListener).headersRead(captor.capture()); @@ -323,7 +345,7 @@ public void inboundShouldForwardToStream() throws Exception { streamTransportState.requestMessagesFromDeframerForTesting(1); // Create a data frame and then trigger the handler to read it. - ByteBuf frame = grpcDataFrame(3, false, contentAsArray()); + ByteBuf frame = grpcDataFrame(STREAM_ID, false, contentAsArray()); channelRead(frame); InputStream message = streamListenerMessageQueue.poll(); assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); @@ -580,7 +602,7 @@ public void close() throws SecurityException { public void cancelStreamShouldCreateAndThenFailBufferedStream() throws Exception { receiveMaxConcurrentStreams(0); enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); cancelStream(Status.CANCELLED); verify(streamListener).closed(eq(Status.CANCELLED), same(PROCESSED), any(Metadata.class)); } @@ -627,7 +649,7 @@ public void connectionWindowShouldBeOverridden() throws Exception { public void createIncrementsIdsForActualAndBufferdStreams() throws Exception { receiveMaxConcurrentStreams(2); enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); streamTransportState = new TransportStateImpl( handler(), @@ -766,7 +788,7 @@ public void oustandingUserPingShouldNotInteractWithDataPing() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(long.class); verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class)); long payload = captor.getValue(); - channelRead(grpcDataFrame(3, false, contentAsArray())); + channelRead(grpcDataFrame(STREAM_ID, false, contentAsArray())); long pingData = handler().flowControlPing().payload(); channelRead(pingFrame(true, pingData)); @@ -789,18 +811,18 @@ public void bdpPingAvoidsTooManyPingsOnSpecialServers() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(3, headers)); - channelRead(dataFrame(3, false, content())); + channelRead(headersFrame(STREAM_ID, headers)); + channelRead(dataFrame(STREAM_ID, false, content())); verifyWrite().writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); + channelRead(dataFrame(STREAM_ID, false, content())); // No ping was sent - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); } @Test @@ -820,26 +842,26 @@ public void bdpPingAllowedAfterSendingData() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(3, headers)); - channelRead(dataFrame(3, false, content())); + channelRead(headersFrame(STREAM_ID, headers)); + channelRead(dataFrame(STREAM_ID, false, content())); verifyWrite().writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); + channelRead(dataFrame(STREAM_ID, false, content())); // No ping was sent - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(windowUpdate(0, 2024)); - channelRead(windowUpdate(3, 2024)); + channelRead(windowUpdate(STREAM_ID, 2024)); assertTrue(future.isDone()); assertTrue(future.isSuccess()); // But now one is sent - channelRead(dataFrame(3, false, content())); - verifyWrite(times(3)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); } @Override @@ -869,7 +891,7 @@ protected void makeStream() throws Exception { // both client- and server-side. Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - ByteBuf headersFrame = headersFrame(3, headers); + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); channelRead(headersFrame); } @@ -928,7 +950,8 @@ public Stopwatch get() { transportTracer, Attributes.EMPTY, "someauthority", - null); + null, + fakeClock().getTicker()); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 018ca9b6594..5f47c7b14c5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -36,6 +36,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Ticker; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; @@ -196,7 +197,7 @@ public void setSoLingerChannelOption() throws IOException { newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, - new SocketPicker(), new FakeChannelLogger(), false); + new SocketPicker(), new FakeChannelLogger(), false, Ticker.systemTicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -448,7 +449,7 @@ public void failingToConstructChannelShouldFailGracefully() throws Exception { newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker(), - new FakeChannelLogger(), false); + new FakeChannelLogger(), false, Ticker.systemTicker()); transports.add(transport); // Should not throw @@ -763,7 +764,8 @@ private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int max negotiator, false, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, keepAliveTimeNano, keepAliveTimeoutNano, false, authority, userAgent, tooManyPingsRunnable, - new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false); + new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false, + Ticker.systemTicker()); transports.add(transport); return transport; } diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index 04f65eed145..fbab1ca5fae 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -30,6 +30,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.InternalChannelz.TransportStats; +import io.grpc.internal.AbstractStream; import io.grpc.internal.FakeClock; import io.grpc.internal.MessageFramer; import io.grpc.internal.StatsTraceContext; @@ -64,8 +65,10 @@ import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import java.io.ByteArrayInputStream; +import java.nio.ByteBuffer; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -80,6 +83,7 @@ @RunWith(JUnit4.class) public abstract class NettyHandlerTestBase { + protected static final int STREAM_ID = 3; private ByteBuf content; private EmbeddedChannel channel; @@ -328,6 +332,8 @@ protected final Http2Connection connection() { return handler().connection(); } + protected abstract AbstractStream stream() throws Exception; + @CanIgnoreReturnValue protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) { ChannelFuture future = writeQueue.enqueue(command, true); @@ -415,18 +421,15 @@ public void windowUpdateMatchesTarget() throws Exception { AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); - ByteBuf data = ctx().alloc().buffer(1024); - while (data.isWritable()) { - data.writeLong(1111); - } - int length = data.readableBytes(); - ByteBuf frame = dataFrame(3, false, data.copy()); + byte[] data = initXkbBuffer(1); + int wireSize = data.length + 5; // 5 is the size of the header + ByteBuf frame = grpcDataFrame(3, false, data); channelRead(frame); - int accumulator = length; + int accumulator = wireSize; // 40 is arbitrary, any number large enough to trigger a window update would work for (int i = 0; i < 40; i++) { - channelRead(dataFrame(3, false, data.copy())); - accumulator += length; + channelRead(grpcDataFrame(3, false, data)); + accumulator += wireSize; } long pingData = handler.flowControlPing().payload(); channelRead(pingFrame(true, pingData)); @@ -444,8 +447,10 @@ public void windowShouldNotExceedMaxWindowSize() throws Exception { Http2Stream connectionStream = connection().connectionStream(); Http2LocalFlowController localFlowController = connection().local().flowController(); int maxWindow = handler.flowControlPing().maxWindow(); + fakeClock.forwardTime(10, TimeUnit.SECONDS); handler.flowControlPing().setDataSizeAndSincePing(maxWindow); + fakeClock.forwardTime(1, TimeUnit.SECONDS); long payload = handler.flowControlPing().payload(); channelRead(pingFrame(true, payload)); @@ -456,8 +461,8 @@ public void windowShouldNotExceedMaxWindowSize() throws Exception { public void transportTracer_windowSizeDefault() throws Exception { manualSetUp(); TransportStats transportStats = transportTracer.getStats(); - assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); - assertEquals(flowControlWindow, transportStats.localFlowControlWindow); + assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.localFlowControlWindow); + assertEquals(flowControlWindow, transportStats.remoteFlowControlWindow); } @Test @@ -465,31 +470,31 @@ public void transportTracer_windowSize() throws Exception { flowControlWindow = 1024 * 1024; manualSetUp(); TransportStats transportStats = transportTracer.getStats(); - assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); - assertEquals(flowControlWindow, transportStats.localFlowControlWindow); + assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.localFlowControlWindow); + assertEquals(flowControlWindow, transportStats.remoteFlowControlWindow); } @Test public void transportTracer_windowUpdate_remote() throws Exception { manualSetUp(); TransportStats before = transportTracer.getStats(); - assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow); + assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); ByteBuf serializedSettings = windowUpdate(0, 1000); channelRead(serializedSettings); TransportStats after = transportTracer.getStats(); assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE + 1000, - after.remoteFlowControlWindow); - assertEquals(flowControlWindow, after.localFlowControlWindow); + after.localFlowControlWindow); + assertEquals(flowControlWindow, after.remoteFlowControlWindow); } @Test public void transportTracer_windowUpdate_local() throws Exception { manualSetUp(); TransportStats before = transportTracer.getStats(); - assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); - assertEquals(flowControlWindow, before.localFlowControlWindow); + assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow); + assertEquals(flowControlWindow, before.remoteFlowControlWindow); // If the window size is below a certain threshold, netty will wait to apply the update. // Use a large increment to be sure that it exceeds the threshold. @@ -497,8 +502,128 @@ public void transportTracer_windowUpdate_local() throws Exception { connection().connectionStream(), 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE); TransportStats after = transportTracer.getStats(); - assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.remoteFlowControlWindow); + assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.localFlowControlWindow); assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE, connection().local().flowController().windowSize(connection().connectionStream())); } + + private AbstractNettyHandler setupPingTest() throws Exception { + this.flowControlWindow = 1024 * 64; + manualSetUp(); + makeStream(); + + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + return handler; + } + + @Test + public void bdpPingLimitOutstanding() throws Exception { + AbstractNettyHandler handler = setupPingTest(); + long pingData = handler.flowControlPing().payload(); + + byte[] data1KbBuf = initXkbBuffer(1); + byte[] data40KbBuf = initXkbBuffer(40); + + readXCopies(1, data1KbBuf); // should initiate a ping + + readXCopies(1, data40KbBuf); // no ping, already active + fakeClock().forwardTime(20, TimeUnit.MILLISECONDS); + readPingAck(pingData); + assertEquals(1, handler.flowControlPing().getPingCount()); + assertEquals(1, handler.flowControlPing().getPingReturn()); + + readXCopies(4, data40KbBuf); // initiate ping + assertEquals(2, handler.flowControlPing().getPingCount()); + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readPingAck(pingData); + + readXCopies(1, data1KbBuf); // ping again since had 160K data since last ping started + assertEquals(3, handler.flowControlPing().getPingCount()); + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readPingAck(pingData); + + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readXCopies(1, data1KbBuf); // no ping, too little data + assertEquals(3, handler.flowControlPing().getPingCount()); + } + + @Test + public void testPingBackoff() throws Exception { + AbstractNettyHandler handler = setupPingTest(); + long pingData = handler.flowControlPing().payload(); + byte[] data40KbBuf = initXkbBuffer(40); + + handler.flowControlPing().setDataSizeAndSincePing(200000); + + for (int i = 0; i <= 10; i++) { + int beforeCount = handler.flowControlPing().getPingCount(); + // should resize on 0 + readXCopies(6, data40KbBuf); // initiate ping on i= {0, 1, 3, 6, 10} + int afterCount = handler.flowControlPing().getPingCount(); + fakeClock().forwardNanos(200); + if (afterCount > beforeCount) { + readPingAck(pingData); // should increase backoff multiplier + } + } + assertEquals(6, handler.flowControlPing().getPingCount()); + } + + @Test + public void bdpPingWindowResizing() throws Exception { + this.flowControlWindow = 1024 * 8; + manualSetUp(); + makeStream(); + + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + Http2LocalFlowController localFlowController = connection().local().flowController(); + long pingData = handler.flowControlPing().payload(); + int initialWindowSize = localFlowController.initialWindowSize(); + byte[] data1Kb = initXkbBuffer(1); + byte[] data10Kb = initXkbBuffer(10); + + readXCopies(1, data1Kb); // initiate ping + fakeClock().forwardNanos(2); + readPingAck(pingData); // should not resize window because of small target window + assertEquals(initialWindowSize, localFlowController.initialWindowSize()); + + readXCopies(2, data10Kb); // initiate ping on first + fakeClock().forwardNanos(200); + readPingAck(pingData); // should resize window + int windowSizeA = localFlowController.initialWindowSize(); + Assert.assertNotEquals(initialWindowSize, windowSizeA); + + readXCopies(3, data10Kb); // initiate ping w/ first 10K packet + fakeClock().forwardNanos(5000); + readPingAck(pingData); // should not resize window as bandwidth didn't increase + Assert.assertEquals(windowSizeA, localFlowController.initialWindowSize()); + + readXCopies(6, data10Kb); // initiate ping with fist packet + fakeClock().forwardNanos(100); + readPingAck(pingData); // should resize window + int windowSizeB = localFlowController.initialWindowSize(); + Assert.assertNotEquals(windowSizeA, windowSizeB); + } + + private void readPingAck(long pingData) throws Exception { + channelRead(pingFrame(true, pingData)); + } + + private void readXCopies(int copies, byte[] data) throws Exception { + for (int i = 0; i < copies; i++) { + channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it + stream().request(1); // consume it + } + } + + private byte[] initXkbBuffer(int multiple) { + ByteBuffer data = ByteBuffer.allocate(1024 * multiple); + + for (int i = 0; i < multiple * 1024 / 4; i++) { + data.putInt(4 * i, 1111); + } + return data.array(); + } + } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 170273e2c60..926ce8261a4 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -60,7 +60,9 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; +import io.grpc.internal.AbstractStream; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; @@ -111,8 +113,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase resultCaptor; + + UdsNameResolverProvider udsNameResolverProvider = new UdsNameResolverProvider(); + + + @Test + public void testUnixRelativePath() { + UdsNameResolver udsNameResolver = + udsNameResolverProvider.newNameResolver(URI.create("unix:sock.sock"), null); + assertThat(udsNameResolver).isNotNull(); + udsNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); + DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); + } + + @Test + public void testUnixAbsolutePath() { + UdsNameResolver udsNameResolver = + udsNameResolverProvider.newNameResolver(URI.create("unix:/sock.sock"), null); + assertThat(udsNameResolver).isNotNull(); + udsNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); + DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixAbsoluteAlternatePath() { + UdsNameResolver udsNameResolver = + udsNameResolverProvider.newNameResolver(URI.create("unix:///sock.sock"), null); + assertThat(udsNameResolver).isNotNull(); + udsNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); + DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixPathWithAuthority() { + try { + udsNameResolverProvider.newNameResolver(URI.create("unix://localhost/sock.sock"), null); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java new file mode 100644 index 00000000000..8eb010e23e5 --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.verify; + +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.netty.channel.unix.DomainSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link UdsNameResolver}. */ +@RunWith(JUnit4.class) +public class UdsNameResolverTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private NameResolver.Listener2 mockListener; + + @Captor + private ArgumentCaptor resultCaptor; + + private UdsNameResolver udsNameResolver; + + @Test + public void testValidTargetPath() { + udsNameResolver = new UdsNameResolver(null, "sock.sock"); + udsNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); + DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); + assertThat(udsNameResolver.getServiceAuthority()).isEqualTo("sock.sock"); + } + + @Test + public void testNonNullAuthority() { + try { + udsNameResolver = new UdsNameResolver("authority", "sock.sock"); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java b/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java new file mode 100644 index 00000000000..e0c3d5a8525 --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java @@ -0,0 +1,169 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InternalServiceProviders; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ManagedChannelProvider; +import io.grpc.ManagedChannelProvider.NewChannelBuilderResult; +import io.grpc.ManagedChannelRegistryAccessor; +import io.grpc.TlsChannelCredentials; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerDomainSocketChannel; +import io.netty.channel.unix.DomainSocketAddress; +import java.io.IOException; +import org.junit.After; +import org.junit.Assume; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link UdsNettyChannelProvider}. */ +@RunWith(JUnit4.class) +public class UdsNettyChannelProviderTest { + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + + private UdsNettyChannelProvider provider = new UdsNettyChannelProvider(); + + private EventLoopGroup elg; + private EventLoopGroup boss; + + @After + public void tearDown() { + if (elg != null) { + elg.shutdownGracefully(); + } + if (boss != null) { + boss.shutdownGracefully(); + } + } + + @Test + public void provided() { + for (ManagedChannelProvider current + : InternalServiceProviders.getCandidatesViaServiceLoader( + ManagedChannelProvider.class, getClass().getClassLoader())) { + if (current instanceof UdsNettyChannelProvider) { + return; + } + } + fail("ServiceLoader unable to load UdsNettyChannelProvider"); + } + + @Test + public void providedHardCoded() { + for (Class current : ManagedChannelRegistryAccessor.getHardCodedClasses()) { + if (current == UdsNettyChannelProvider.class) { + return; + } + } + fail("Hard coded unable to load UdsNettyChannelProvider"); + } + + @Test + public void basicMethods() { + Assume.assumeTrue(provider.isAvailable()); + assertEquals(3, provider.priority()); + } + + @Test + public void newChannelBuilder_success() { + Assume.assumeTrue(Utils.isEpollAvailable()); + NewChannelBuilderResult result = + provider.newChannelBuilder("unix:sock.sock", TlsChannelCredentials.create()); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + } + + @Test + public void managedChannelRegistry_newChannelBuilder() { + Assume.assumeTrue(Utils.isEpollAvailable()); + ManagedChannelBuilder managedChannelBuilder + = Grpc.newChannelBuilder("unix:///sock.sock", InsecureChannelCredentials.create()); + assertThat(managedChannelBuilder).isNotNull(); + ManagedChannel channel = managedChannelBuilder.build(); + assertThat(channel).isNotNull(); + assertThat(channel.authority()).isEqualTo("/sock.sock"); + channel.shutdownNow(); + } + + @Test + public void udsClientServerTestUsingProvider() throws IOException { + Assume.assumeTrue(Utils.isEpollAvailable()); + String socketPath = tempFolder.getRoot().getAbsolutePath() + "/test.socket"; + createUdsServer(socketPath); + ManagedChannelBuilder channelBuilder = + Grpc.newChannelBuilder("unix://" + socketPath, InsecureChannelCredentials.create()); + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); + assertThat(unaryRpc("buddy", stub)).isEqualTo("Hello buddy"); + } + + /** Say hello to server. */ + private static String unaryRpc( + String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { + SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); + SimpleResponse response = blockingStub.unaryRpc(request); + return response.getResponseMessage(); + } + + private void createUdsServer(String name) throws IOException { + elg = new EpollEventLoopGroup(); + boss = new EpollEventLoopGroup(1); + cleanupRule.register( + NettyServerBuilder.forAddress(new DomainSocketAddress(name)) + .bossEventLoopGroup(boss) + .workerEventLoopGroup(elg) + .channelType(EpollServerDomainSocketChannel.class) + .addService(new SimpleServiceImpl()) + .directExecutor() + .build() + .start()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + + @Override + public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder() + .setResponseMessage("Hello " + req.getRequestMessage()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } +} diff --git a/observability/build.gradle b/observability/build.gradle deleted file mode 100644 index 137e5f6978a..00000000000 --- a/observability/build.gradle +++ /dev/null @@ -1,40 +0,0 @@ -plugins { - id "java-library" - id "maven-publish" - - id "com.google.protobuf" - id "ru.vyarus.animalsniffer" -} - -description = "gRPC: Observability" -dependencies { - def cloudLoggingVersion = '3.6.1' - - api project(':grpc-api'), - project(':grpc-alts') - - implementation project(':grpc-protobuf'), - project(':grpc-stub'), - ('com.google.guava:guava:31.0.1-jre'), - ('com.google.errorprone:error_prone_annotations:2.11.0'), - ('com.google.auth:google-auth-library-credentials:1.4.0'), - ('org.checkerframework:checker-qual:3.20.0'), - ('com.google.auto.value:auto-value-annotations:1.9'), - ('com.google.http-client:google-http-client:1.41.0'), - ('com.google.http-client:google-http-client-gson:1.41.0'), - ('com.google.api.grpc:proto-google-common-protos:2.7.1'), - ("com.google.cloud:google-cloud-logging:${cloudLoggingVersion}") - - testImplementation project(':grpc-testing'), - project(':grpc-testing-proto'), - project(':grpc-netty-shaded') - testImplementation (libraries.guava_testlib) { - exclude group: 'junit', module: 'junit' - } - - signature "org.codehaus.mojo.signature:java18:1.0@signature" -} - -configureProtoCompilation() - -[publishMavenPublicationToMavenRepository]*.onlyIf { false } diff --git a/observability/src/main/java/io/grpc/observability/LoggingChannelProvider.java b/observability/src/main/java/io/grpc/observability/LoggingChannelProvider.java deleted file mode 100644 index 5011068c176..00000000000 --- a/observability/src/main/java/io/grpc/observability/LoggingChannelProvider.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import static com.google.common.base.Preconditions.checkNotNull; - -import io.grpc.ChannelCredentials; -import io.grpc.InternalManagedChannelProvider; -import io.grpc.ManagedChannelBuilder; -import io.grpc.ManagedChannelProvider; -import io.grpc.ManagedChannelRegistry; -import io.grpc.observability.interceptors.InternalLoggingChannelInterceptor; - -/** A channel provider that injects logging interceptor. */ -final class LoggingChannelProvider extends ManagedChannelProvider { - private final ManagedChannelProvider prevProvider; - private final InternalLoggingChannelInterceptor.Factory clientInterceptorFactory; - - private static LoggingChannelProvider instance; - - private LoggingChannelProvider(InternalLoggingChannelInterceptor.Factory factory) { - prevProvider = ManagedChannelProvider.provider(); - clientInterceptorFactory = factory; - } - - static synchronized void init(InternalLoggingChannelInterceptor.Factory factory) { - if (instance != null) { - throw new IllegalStateException("LoggingChannelProvider already initialized!"); - } - instance = new LoggingChannelProvider(factory); - ManagedChannelRegistry.getDefaultRegistry().register(instance); - } - - static synchronized void finish() { - if (instance == null) { - throw new IllegalStateException("LoggingChannelProvider not initialized!"); - } - ManagedChannelRegistry.getDefaultRegistry().deregister(instance); - instance = null; - } - - @Override - protected boolean isAvailable() { - return true; - } - - @Override - protected int priority() { - return 6; - } - - private ManagedChannelBuilder addInterceptor(ManagedChannelBuilder builder) { - return builder.intercept(clientInterceptorFactory.create()); - } - - @Override - protected ManagedChannelBuilder builderForAddress(String name, int port) { - return addInterceptor( - InternalManagedChannelProvider.builderForAddress(prevProvider, name, port)); - } - - @Override - protected ManagedChannelBuilder builderForTarget(String target) { - return addInterceptor(InternalManagedChannelProvider.builderForTarget(prevProvider, target)); - } - - @Override - protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) { - NewChannelBuilderResult result = InternalManagedChannelProvider.newChannelBuilder(prevProvider, - target, creds); - ManagedChannelBuilder builder = result.getChannelBuilder(); - if (builder != null) { - return NewChannelBuilderResult.channelBuilder( - addInterceptor(builder)); - } - checkNotNull(result.getError(), "Expected error to be set!"); - return result; - } -} diff --git a/observability/src/main/java/io/grpc/observability/LoggingServerProvider.java b/observability/src/main/java/io/grpc/observability/LoggingServerProvider.java deleted file mode 100644 index 5277bcf572b..00000000000 --- a/observability/src/main/java/io/grpc/observability/LoggingServerProvider.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import static com.google.common.base.Preconditions.checkNotNull; - -import io.grpc.InternalServerProvider; -import io.grpc.ServerBuilder; -import io.grpc.ServerCredentials; -import io.grpc.ServerProvider; -import io.grpc.ServerRegistry; -import io.grpc.observability.interceptors.InternalLoggingServerInterceptor; - -/** A server provider that injects the logging interceptor. */ -final class LoggingServerProvider extends ServerProvider { - private final ServerProvider prevProvider; - private final InternalLoggingServerInterceptor.Factory serverInterceptorFactory; - - private static LoggingServerProvider instance; - - private LoggingServerProvider(InternalLoggingServerInterceptor.Factory factory) { - prevProvider = ServerProvider.provider(); - serverInterceptorFactory = factory; - } - - static synchronized void init(InternalLoggingServerInterceptor.Factory factory) { - if (instance != null) { - throw new IllegalStateException("LoggingServerProvider already initialized!"); - } - instance = new LoggingServerProvider(factory); - ServerRegistry.getDefaultRegistry().register(instance); - } - - static synchronized void finish() { - if (instance == null) { - throw new IllegalStateException("LoggingServerProvider not initialized!"); - } - ServerRegistry.getDefaultRegistry().deregister(instance); - instance = null; - } - - @Override - protected boolean isAvailable() { - return true; - } - - @Override - protected int priority() { - return 6; - } - - private ServerBuilder addInterceptor(ServerBuilder builder) { - return builder.intercept(serverInterceptorFactory.create()); - } - - @Override - protected ServerBuilder builderForPort(int port) { - return addInterceptor(InternalServerProvider.builderForPort(prevProvider, port)); - } - - @Override - protected NewServerBuilderResult newServerBuilderForPort(int port, ServerCredentials creds) { - ServerProvider.NewServerBuilderResult result = InternalServerProvider.newServerBuilderForPort( - prevProvider, port, - creds); - ServerBuilder builder = result.getServerBuilder(); - if (builder != null) { - return ServerProvider.NewServerBuilderResult.serverBuilder( - addInterceptor(builder)); - } - checkNotNull(result.getError(), "Expected error to be set!"); - return result; - } -} diff --git a/observability/src/main/java/io/grpc/observability/Observability.java b/observability/src/main/java/io/grpc/observability/Observability.java deleted file mode 100644 index a7a6b04ad07..00000000000 --- a/observability/src/main/java/io/grpc/observability/Observability.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import io.grpc.ExperimentalApi; -import io.grpc.ManagedChannelProvider.ProviderNotFoundException; -import io.grpc.observability.interceptors.InternalLoggingChannelInterceptor; -import io.grpc.observability.interceptors.InternalLoggingServerInterceptor; - -/** The main class for gRPC Observability features. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8869") -public final class Observability { - private static boolean initialized = false; - - /** - * Initialize grpc-observability. - * - * @throws ProviderNotFoundException if no underlying channel/server provider is available. - */ - public static synchronized void grpcInit() { - if (initialized) { - throw new IllegalStateException("Observability already initialized!"); - } - LoggingChannelProvider.init(new InternalLoggingChannelInterceptor.FactoryImpl()); - LoggingServerProvider.init(new InternalLoggingServerInterceptor.FactoryImpl()); - // TODO(sanjaypujare): initialize customTags map - initialized = true; - } - - /** Un-initialize or finish grpc-observability. */ - public static synchronized void grpcFinish() { - if (!initialized) { - throw new IllegalStateException("Observability not initialized!"); - } - LoggingChannelProvider.finish(); - LoggingServerProvider.finish(); - // TODO(sanjaypujare): finish customTags map - initialized = false; - } - - private Observability() { - } -} diff --git a/observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingServerInterceptor.java b/observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingServerInterceptor.java deleted file mode 100644 index 17385653121..00000000000 --- a/observability/src/main/java/io/grpc/observability/interceptors/InternalLoggingServerInterceptor.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability.interceptors; - -import io.grpc.Internal; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; - -/** A logging interceptor for {@code LoggingServerProvider}. */ -@Internal -public final class InternalLoggingServerInterceptor implements ServerInterceptor { - - public interface Factory { - ServerInterceptor create(); - } - - public static class FactoryImpl implements Factory { - - @Override - public ServerInterceptor create() { - return new InternalLoggingServerInterceptor(); - } - } - - @Override - public ServerCall.Listener interceptCall(ServerCall call, - Metadata headers, ServerCallHandler next) { - // TODO(dnvindhya) implement the interceptor - return null; - } -} diff --git a/observability/src/main/java/io/grpc/observability/logging/CloudLoggingHandler.java b/observability/src/main/java/io/grpc/observability/logging/CloudLoggingHandler.java deleted file mode 100644 index f325b9d8383..00000000000 --- a/observability/src/main/java/io/grpc/observability/logging/CloudLoggingHandler.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability.logging; - -import com.google.cloud.MonitoredResource; -import com.google.cloud.logging.LogEntry; -import com.google.cloud.logging.Logging; -import com.google.cloud.logging.LoggingOptions; -import com.google.cloud.logging.Payload.JsonPayload; -import com.google.cloud.logging.Severity; -import com.google.common.base.Strings; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.util.JsonFormat; -import io.grpc.Internal; -import io.grpc.internal.JsonParser; -import io.grpc.observabilitylog.v1.GrpcLogRecord; -import java.io.IOException; -import java.util.Collections; -import java.util.Map; -import java.util.logging.Handler; -import java.util.logging.Level; -import java.util.logging.LogRecord; - -/** - * Custom logging handler that outputs logs generated using {@link java.util.logging.Logger} to - * Cloud Logging. - */ -// TODO(vindhyan): replace custom JUL handler with internal sink implementation to eliminate -// JUL dependency -@Internal -public class CloudLoggingHandler extends Handler { - - private static final String DEFAULT_LOG_NAME = "grpc-observability"; - private static final Level DEFAULT_LOG_LEVEL = Level.ALL; - - private final LoggingOptions loggingOptions; - private final Logging loggingClient; - private final Level baseLevel; - private final String cloudLogName; - - /** - * Creates a custom logging handler that publishes message to Cloud logging. Default log level is - * set to Level.FINEST if level is not passed. - */ - public CloudLoggingHandler() { - this(DEFAULT_LOG_LEVEL, null, null); - } - - /** - * Creates a custom logging handler that publishes message to Cloud logging. - * - * @param level set the level for which message levels will be logged by the custom logger - */ - public CloudLoggingHandler(Level level) { - this(level, null, null); - } - - /** - * Creates a custom logging handler that publishes message to Cloud logging. - * - * @param level set the level for which message levels will be logged by the custom logger - * @param logName the name of the log to which log entries are written - */ - public CloudLoggingHandler(Level level, String logName) { - this(level, logName, null); - } - - /** - * Creates a custom logging handler that publishes message to Cloud logging. - * - * @param level set the level for which message levels will be logged by the custom logger - * @param logName the name of the log to which log entries are written - * @param destinationProjectId the value of cloud project id to which logs are sent to by the - * custom logger - */ - public CloudLoggingHandler(Level level, String logName, String destinationProjectId) { - baseLevel = - (level != null) ? (level.equals(DEFAULT_LOG_LEVEL) ? Level.FINEST : level) : Level.FINEST; - setLevel(baseLevel); - cloudLogName = logName != null ? logName : DEFAULT_LOG_NAME; - - // TODO(dnvindhya) read the value from config instead of taking it as an argument - if (Strings.isNullOrEmpty(destinationProjectId)) { - loggingOptions = LoggingOptions.getDefaultInstance(); - } else { - loggingOptions = LoggingOptions.newBuilder().setProjectId(destinationProjectId).build(); - } - loggingClient = loggingOptions.getService(); - } - - @Override - public void publish(LogRecord record) { - if (!(record instanceof LogRecordExtension)) { - throw new IllegalArgumentException("Expected record of type LogRecordExtension"); - } - Level logLevel = record.getLevel(); - GrpcLogRecord protoRecord = ((LogRecordExtension) record).getGrpcLogRecord(); - writeLog(protoRecord, logLevel); - } - - private void writeLog(GrpcLogRecord logProto, Level logLevel) { - if (loggingClient == null) { - throw new IllegalStateException("Logging client not initialized"); - } - try { - Severity cloudLogLevel = getCloudLoggingLevel(logLevel); - Map mapPayload = protoToMapConverter(logProto); - - // TODO(vindhyan): make sure all (int, long) values are not displayed as double - LogEntry grpcLogEntry = - LogEntry.newBuilder(JsonPayload.of(mapPayload)) - .setSeverity(cloudLogLevel) - .setLogName(cloudLogName) - .setResource(MonitoredResource.newBuilder("global").build()) - .build(); - loggingClient.write(Collections.singleton(grpcLogEntry)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @SuppressWarnings("unchecked") - private Map protoToMapConverter(GrpcLogRecord logProto) - throws InvalidProtocolBufferException, IOException { - JsonFormat.Printer printer = JsonFormat.printer().preservingProtoFieldNames(); - String recordJson = printer.print(logProto); - return (Map) JsonParser.parse(recordJson); - } - - @Override - public void flush() { - if (loggingClient == null) { - throw new IllegalStateException("Logging client not initialized"); - } - loggingClient.flush(); - } - - @Override - public synchronized void close() throws SecurityException { - if (loggingClient == null) { - throw new IllegalStateException("Logging client not initialized"); - } - try { - loggingClient.close(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private Severity getCloudLoggingLevel(Level recordLevel) { - switch (recordLevel.intValue()) { - case 300: // FINEST - case 400: // FINER - case 500: // FINE - return Severity.DEBUG; - case 700: // CONFIG - case 800: // INFO - return Severity.INFO; - case 900: // WARNING - return Severity.WARNING; - case 1000: // SEVERE - return Severity.ERROR; - default: - return Severity.DEFAULT; - } - } -} diff --git a/observability/src/test/java/io/grpc/observability/LoggingChannelProviderTest.java b/observability/src/test/java/io/grpc/observability/LoggingChannelProviderTest.java deleted file mode 100644 index 639bcbc6d0d..00000000000 --- a/observability/src/test/java/io/grpc/observability/LoggingChannelProviderTest.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.Grpc; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.ManagedChannelProvider; -import io.grpc.MethodDescriptor; -import io.grpc.TlsChannelCredentials; -import io.grpc.observability.interceptors.InternalLoggingChannelInterceptor; -import io.grpc.testing.TestMethodDescriptors; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentMatchers; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public class LoggingChannelProviderTest { - @Rule - public final MockitoRule mocks = MockitoJUnit.rule(); - - private final MethodDescriptor method = TestMethodDescriptors.voidMethod(); - - @Test - public void initTwiceCausesException() { - ManagedChannelProvider prevProvider = ManagedChannelProvider.provider(); - assertThat(prevProvider).isNotInstanceOf(LoggingChannelProvider.class); - LoggingChannelProvider.init(new InternalLoggingChannelInterceptor.FactoryImpl()); - assertThat(ManagedChannelProvider.provider()).isInstanceOf(LoggingChannelProvider.class); - try { - LoggingChannelProvider.init(new InternalLoggingChannelInterceptor.FactoryImpl()); - fail("should have failed for calling init() again"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("LoggingChannelProvider already initialized!"); - } - LoggingChannelProvider.finish(); - assertThat(ManagedChannelProvider.provider()).isSameInstanceAs(prevProvider); - } - - @Test - public void forTarget_interceptorCalled() { - ClientInterceptor interceptor = mock(ClientInterceptor.class, - delegatesTo(new NoopInterceptor())); - InternalLoggingChannelInterceptor.Factory factory = mock( - InternalLoggingChannelInterceptor.Factory.class); - when(factory.create()).thenReturn(interceptor); - LoggingChannelProvider.init(factory); - ManagedChannelBuilder builder = ManagedChannelBuilder.forTarget("localhost"); - ManagedChannel channel = builder.build(); - CallOptions callOptions = CallOptions.DEFAULT; - - ClientCall unused = channel.newCall(method, callOptions); - verify(interceptor) - .interceptCall(same(method), same(callOptions), ArgumentMatchers.any()); - channel.shutdownNow(); - LoggingChannelProvider.finish(); - } - - @Test - public void forAddress_interceptorCalled() { - ClientInterceptor interceptor = mock(ClientInterceptor.class, - delegatesTo(new NoopInterceptor())); - InternalLoggingChannelInterceptor.Factory factory = mock( - InternalLoggingChannelInterceptor.Factory.class); - when(factory.create()).thenReturn(interceptor); - LoggingChannelProvider.init(factory); - ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress("localhost", 80); - ManagedChannel channel = builder.build(); - CallOptions callOptions = CallOptions.DEFAULT; - - ClientCall unused = channel.newCall(method, callOptions); - verify(interceptor) - .interceptCall(same(method), same(callOptions), ArgumentMatchers.any()); - channel.shutdownNow(); - LoggingChannelProvider.finish(); - } - - @Test - public void newChannelBuilder_interceptorCalled() { - ClientInterceptor interceptor = mock(ClientInterceptor.class, - delegatesTo(new NoopInterceptor())); - InternalLoggingChannelInterceptor.Factory factory = mock( - InternalLoggingChannelInterceptor.Factory.class); - when(factory.create()).thenReturn(interceptor); - LoggingChannelProvider.init(factory); - ManagedChannelBuilder builder = Grpc.newChannelBuilder("localhost", - TlsChannelCredentials.create()); - ManagedChannel channel = builder.build(); - CallOptions callOptions = CallOptions.DEFAULT; - - ClientCall unused = channel.newCall(method, callOptions); - verify(interceptor) - .interceptCall(same(method), same(callOptions), ArgumentMatchers.any()); - channel.shutdownNow(); - LoggingChannelProvider.finish(); - } - - private static class NoopInterceptor implements ClientInterceptor { - @Override - public ClientCall interceptCall(MethodDescriptor method, - CallOptions callOptions, Channel next) { - return next.newCall(method, callOptions); - } - } -} diff --git a/observability/src/test/java/io/grpc/observability/LoggingServerProviderTest.java b/observability/src/test/java/io/grpc/observability/LoggingServerProviderTest.java deleted file mode 100644 index fd6b60b4738..00000000000 --- a/observability/src/test/java/io/grpc/observability/LoggingServerProviderTest.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.Grpc; -import io.grpc.InsecureServerCredentials; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.Metadata; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.ServerProvider; -import io.grpc.observability.interceptors.InternalLoggingServerInterceptor; -import io.grpc.stub.StreamObserver; -import io.grpc.testing.GrpcCleanupRule; -import io.grpc.testing.protobuf.SimpleRequest; -import io.grpc.testing.protobuf.SimpleResponse; -import io.grpc.testing.protobuf.SimpleServiceGrpc; -import java.io.IOException; -import java.util.function.Supplier; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentMatchers; - -@RunWith(JUnit4.class) -public class LoggingServerProviderTest { - @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - - @Test - public void initTwiceCausesException() { - ServerProvider prevProvider = ServerProvider.provider(); - assertThat(prevProvider).isNotInstanceOf(LoggingServerProvider.class); - LoggingServerProvider.init(new InternalLoggingServerInterceptor.FactoryImpl()); - assertThat(ServerProvider.provider()).isInstanceOf(ServerProvider.class); - try { - LoggingServerProvider.init(new InternalLoggingServerInterceptor.FactoryImpl()); - fail("should have failed for calling init() again"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("LoggingServerProvider already initialized!"); - } - LoggingServerProvider.finish(); - assertThat(ServerProvider.provider()).isSameInstanceAs(prevProvider); - } - - @Test - public void forPort_interceptorCalled() throws IOException { - serverBuilder_interceptorCalled(() -> ServerBuilder.forPort(0)); - } - - @Test - public void newServerBuilder_interceptorCalled() throws IOException { - serverBuilder_interceptorCalled( - () -> Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())); - } - - @SuppressWarnings("unchecked") - private void serverBuilder_interceptorCalled(Supplier> serverBuilderSupplier) - throws IOException { - ServerInterceptor interceptor = - mock(ServerInterceptor.class, delegatesTo(new NoopInterceptor())); - InternalLoggingServerInterceptor.Factory factory = mock( - InternalLoggingServerInterceptor.Factory.class); - when(factory.create()).thenReturn(interceptor); - LoggingServerProvider.init(factory); - Server server = serverBuilderSupplier.get().addService(new SimpleServiceImpl()).build().start(); - int port = cleanupRule.register(server).getPort(); - ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext() - .build(); - SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub( - cleanupRule.register(channel)); - assertThat(unaryRpc("buddy", stub)).isEqualTo("Hello buddy"); - verify(interceptor).interceptCall(any(ServerCall.class), any(Metadata.class), anyCallHandler()); - LoggingServerProvider.finish(); - } - - private ServerCallHandler anyCallHandler() { - return ArgumentMatchers.any(); - } - - private static String unaryRpc( - String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { - SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build(); - SimpleResponse response = blockingStub.unaryRpc(request); - return response.getResponseMessage(); - } - - private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { - - @Override - public void unaryRpc(SimpleRequest req, StreamObserver responseObserver) { - SimpleResponse response = - SimpleResponse.newBuilder() - .setResponseMessage("Hello " + req.getRequestMessage()) - .build(); - responseObserver.onNext(response); - responseObserver.onCompleted(); - } - } - - private static class NoopInterceptor implements ServerInterceptor { - @Override - public ServerCall.Listener interceptCall( - ServerCall call, - Metadata headers, - ServerCallHandler next) { - return next.startCall(call, headers); - } - } -} diff --git a/observability/src/test/java/io/grpc/observability/ObservabilityTest.java b/observability/src/test/java/io/grpc/observability/ObservabilityTest.java deleted file mode 100644 index 6c71a0e2640..00000000000 --- a/observability/src/test/java/io/grpc/observability/ObservabilityTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.observability; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class ObservabilityTest { - - @Test - public void initFinish() { - Observability.grpcInit(); - try { - Observability.grpcInit(); - fail("should have failed for calling grpcInit() again"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Observability already initialized!"); - } - Observability.grpcFinish(); - try { - Observability.grpcFinish(); - fail("should have failed for calling grpcFinit() on uninitialized"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Observability not initialized!"); - } - } -} diff --git a/okhttp/BUILD.bazel b/okhttp/BUILD.bazel index d690086df8f..e550634aca0 100644 --- a/okhttp/BUILD.bazel +++ b/okhttp/BUILD.bazel @@ -11,6 +11,7 @@ java_library( deps = [ "//api", "//core:internal", + "//core:util", "@com_google_code_findbugs_jsr305//jar", "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_guava//jar", diff --git a/okhttp/build.gradle b/okhttp/build.gradle index 999f21e7c10..439abaa3373 100644 --- a/okhttp/build.gradle +++ b/okhttp/build.gradle @@ -11,18 +11,21 @@ description = "gRPC: OkHttp" evaluationDependsOn(project(':grpc-core').path) dependencies { - api project(':grpc-core'), - libraries.okhttp + api project(':grpc-core') implementation libraries.okio, libraries.guava, - libraries.perfmark + libraries.perfmark.api + // Make okhttp dependencies compile only + compileOnly libraries.okhttp // Tests depend on base class defined by core module. testImplementation project(':grpc-core').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-testing'), - project(':grpc-netty') - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + project(':grpc-testing-proto'), + libraries.netty.codec.http2, + libraries.okhttp + signature libraries.signature.java + signature libraries.signature.android } project.sourceSets { @@ -30,15 +33,17 @@ project.sourceSets { test { java { srcDir "${projectDir}/third_party/okhttp/test/java" } } } -checkstyleMain.exclude '**/io/grpc/okhttp/internal/**' +tasks.named("checkstyleMain").configure { + exclude '**/io/grpc/okhttp/internal/**' +} -javadoc { +tasks.named("javadoc").configure { options.links 'http://square.github.io/okhttp/2.x/okhttp/' exclude 'io/grpc/okhttp/Internal*' exclude 'io/grpc/okhttp/internal/**' } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { classDirectories.from = sourceSets.main.output.collect { fileTree(dir: it, exclude: [ diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java index a8cfbcaad2f..faf1b7e3012 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -21,6 +21,9 @@ import io.grpc.internal.SerializingExecutor; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Settings; import io.perfmark.Link; import io.perfmark.PerfMark; import java.io.IOException; @@ -33,7 +36,8 @@ /** * A sink that asynchronously write / flushes a buffer internally. AsyncSink provides flush - * coalescing to minimize network packing transmit. + * coalescing to minimize network packing transmit. Because I/O is handled asynchronously, most I/O + * exceptions will be delivered via a callback. */ final class AsyncSink implements Sink { @@ -41,6 +45,7 @@ final class AsyncSink implements Sink { private final Buffer buffer = new Buffer(); private final SerializingExecutor serializingExecutor; private final TransportExceptionHandler transportExceptionHandler; + private final int maxQueuedControlFrames; @GuardedBy("lock") private boolean writeEnqueued = false; @@ -51,15 +56,26 @@ final class AsyncSink implements Sink { private Sink sink; @Nullable private Socket socket; + private boolean controlFramesExceeded; + private int controlFramesInWrite; + @GuardedBy("lock") + private int queuedControlFrames; - private AsyncSink(SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { + private AsyncSink(SerializingExecutor executor, TransportExceptionHandler exceptionHandler, + int maxQueuedControlFrames) { this.serializingExecutor = checkNotNull(executor, "executor"); this.transportExceptionHandler = checkNotNull(exceptionHandler, "exceptionHandler"); + this.maxQueuedControlFrames = maxQueuedControlFrames; } + /** + * {@code maxQueuedControlFrames} is only effective for frames written with + * {@link #limitControlFramesWriter(FrameWriter)}. + */ static AsyncSink sink( - SerializingExecutor executor, TransportExceptionHandler exceptionHandler) { - return new AsyncSink(executor, exceptionHandler); + SerializingExecutor executor, TransportExceptionHandler exceptionHandler, + int maxQueuedControlFrames) { + return new AsyncSink(executor, exceptionHandler, maxQueuedControlFrames); } /** @@ -74,6 +90,10 @@ void becomeConnected(Sink sink, Socket socket) { this.socket = checkNotNull(socket, "socket"); } + FrameWriter limitControlFramesWriter(FrameWriter delegate) { + return new LimitControlFramesWriter(delegate); + } + @Override public void write(Buffer source, long byteCount) throws IOException { checkNotNull(source, "source"); @@ -82,12 +102,29 @@ public void write(Buffer source, long byteCount) throws IOException { } PerfMark.startTask("AsyncSink.write"); try { + boolean closeSocket = false; synchronized (lock) { buffer.write(source, byteCount); - if (writeEnqueued || flushEnqueued || buffer.completeSegmentByteCount() <= 0) { - return; + + queuedControlFrames += controlFramesInWrite; + controlFramesInWrite = 0; + if (!controlFramesExceeded && queuedControlFrames > maxQueuedControlFrames) { + controlFramesExceeded = true; + closeSocket = true; + } else { + if (writeEnqueued || flushEnqueued || buffer.completeSegmentByteCount() <= 0) { + return; + } + writeEnqueued = true; } - writeEnqueued = true; + } + if (closeSocket) { + try { + socket.close(); + } catch (IOException e) { + transportExceptionHandler.onException(e); + } + return; } serializingExecutor.execute(new WriteRunnable() { final Link link = PerfMark.linkOut(); @@ -97,11 +134,18 @@ public void doRun() throws IOException { PerfMark.linkIn(link); Buffer buf = new Buffer(); try { + int writingControlFrames; synchronized (lock) { buf.write(buffer, buffer.completeSegmentByteCount()); writeEnqueued = false; + // Imprecise because we only tranfer complete segments, but not by much and error + // won't accumulate over time + writingControlFrames = queuedControlFrames; } sink.write(buf, buf.size()); + synchronized (lock) { + queuedControlFrames -= writingControlFrames; + } } finally { PerfMark.stopTask("WriteRunnable.runWrite"); } @@ -163,6 +207,13 @@ public void close() { serializingExecutor.execute(new Runnable() { @Override public void run() { + try { + if (sink != null && buffer.size() > 0) { + sink.write(buffer, buffer.size()); + } + } catch (IOException e) { + transportExceptionHandler.onException(e); + } buffer.close(); try { if (sink != null) { @@ -197,4 +248,30 @@ public final void run() { public abstract void doRun() throws IOException; } -} \ No newline at end of file + + private class LimitControlFramesWriter extends ForwardingFrameWriter { + public LimitControlFramesWriter(FrameWriter delegate) { + super(delegate); + } + + @Override + public void ackSettings(Settings peerSettings) throws IOException { + controlFramesInWrite++; + super.ackSettings(peerSettings); + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) throws IOException { + controlFramesInWrite++; + super.rstStream(streamId, errorCode); + } + + @Override + public void ping(boolean ack, int payload1, int payload2) throws IOException { + if (ack) { + controlFramesInWrite++; + } + super.ping(ack, payload1, payload2); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java index 9f7074121fc..2e21b1547d8 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java +++ b/okhttp/src/main/java/io/grpc/okhttp/ExceptionHandlingFrameWriter.java @@ -31,6 +31,12 @@ import okio.Buffer; import okio.ByteString; +/** + * FrameWriter that propagates IOExceptions via callback instead of throwing. This allows + * centralized handling of errors. Exceptions only impact the single call that throws them; callers + * should be sure to kill the connection after an exception (potentially after sending a GOAWAY) as + * otherwise additional frames after the failed/omitted one could cause HTTP/2 confusion. + */ final class ExceptionHandlingFrameWriter implements FrameWriter { private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); @@ -39,23 +45,14 @@ final class ExceptionHandlingFrameWriter implements FrameWriter { private final FrameWriter frameWriter; - private final OkHttpFrameLogger frameLogger; + private final OkHttpFrameLogger frameLogger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); ExceptionHandlingFrameWriter( TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) { - this(transportExceptionHandler, frameWriter, - new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ExceptionHandlingFrameWriter( - TransportExceptionHandler transportExceptionHandler, - FrameWriter frameWriter, - OkHttpFrameLogger frameLogger) { this.transportExceptionHandler = checkNotNull(transportExceptionHandler, "transportExceptionHandler"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); - this.frameLogger = Preconditions.checkNotNull(frameLogger, "frameLogger"); } @Override diff --git a/okhttp/src/main/java/io/grpc/okhttp/ForwardingFrameWriter.java b/okhttp/src/main/java/io/grpc/okhttp/ForwardingFrameWriter.java new file mode 100644 index 00000000000..ae173376e8d --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/ForwardingFrameWriter.java @@ -0,0 +1,116 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.Settings; +import java.io.IOException; +import java.util.List; +import okio.Buffer; + + +/** FrameWriter that forwards all calls to a delegate. */ +abstract class ForwardingFrameWriter implements FrameWriter { + private final FrameWriter delegate; + + public ForwardingFrameWriter(FrameWriter delegate) { + this.delegate = Preconditions.checkNotNull(delegate, "delegate"); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public void connectionPreface() throws IOException { + delegate.connectionPreface(); + } + + @Override + public void ackSettings(Settings peerSettings) throws IOException { + delegate.ackSettings(peerSettings); + } + + @Override + public void pushPromise(int streamId, int promisedStreamId, List

requestHeaders) + throws IOException { + delegate.pushPromise(streamId, promisedStreamId, requestHeaders); + } + + @Override + public void flush() throws IOException { + delegate.flush(); + } + + @Override + public void synStream(boolean outFinished, boolean inFinished, int streamId, + int associatedStreamId, List
headerBlock) throws IOException { + delegate.synStream(outFinished, inFinished, streamId, associatedStreamId, headerBlock); + } + + @Override + public void synReply(boolean outFinished, int streamId, List
headerBlock) + throws IOException { + delegate.synReply(outFinished, streamId, headerBlock); + } + + @Override + public void headers(int streamId, List
headerBlock) throws IOException { + delegate.headers(streamId, headerBlock); + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) throws IOException { + delegate.rstStream(streamId, errorCode); + } + + @Override + public int maxDataLength() { + return delegate.maxDataLength(); + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) + throws IOException { + delegate.data(outFinished, streamId, source, byteCount); + } + + @Override + public void settings(Settings okHttpSettings) throws IOException { + delegate.settings(okHttpSettings); + } + + @Override + public void ping(boolean ack, int payload1, int payload2) throws IOException { + delegate.ping(ack, payload1, payload2); + } + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, byte[] debugData) + throws IOException { + delegate.goAway(lastGoodStreamId, errorCode, debugData); + } + + @Override + public void windowUpdate(int streamId, long windowSizeIncrement) throws IOException { + delegate.windowUpdate(streamId, windowSizeIncrement); + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java new file mode 100644 index 00000000000..a6cf8db9b4f --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.Attributes; +import io.grpc.InternalChannelz; +import java.io.IOException; +import java.net.Socket; + +/** Handshakes new connections. */ +interface HandshakerSocketFactory { + /** When the returned socket is closed, {@code socket} must be closed. */ + HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException; + + static final class HandshakeResult { + public final Socket socket; + public final Attributes attributes; + public final InternalChannelz.Security securityInfo; + + public HandshakeResult( + Socket socket, Attributes attributes, InternalChannelz.Security securityInfo) { + this.socket = Preconditions.checkNotNull(socket, "socket"); + this.attributes = Preconditions.checkNotNull(attributes, "attributes"); + this.securityInfo = securityInfo; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/Headers.java b/okhttp/src/main/java/io/grpc/okhttp/Headers.java index 15008f8040f..ff2033b35be 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Headers.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Headers.java @@ -16,9 +16,6 @@ package io.grpc.okhttp; -import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_KEY; -import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY; - import com.google.common.base.Preconditions; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -39,7 +36,7 @@ class Headers { public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, GrpcUtil.HTTP_METHOD); public static final Header METHOD_GET_HEADER = new Header(Header.TARGET_METHOD, "GET"); public static final Header CONTENT_TYPE_HEADER = - new Header(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); + new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); public static final Header TE_HEADER = new Header("te", GrpcUtil.TE_TRAILERS); /** @@ -58,10 +55,7 @@ public static List
createRequestHeaders( Preconditions.checkNotNull(defaultPath, "defaultPath"); Preconditions.checkNotNull(authority, "authority"); - // Discard any application supplied duplicates of the reserved headers - headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY); - headers.discardAll(GrpcUtil.TE_HEADER); - headers.discardAll(GrpcUtil.USER_AGENT_KEY); + stripNonApplicationHeaders(headers); // 7 is the number of explicit add calls below. List
okhttpHeaders = new ArrayList<>(7 + InternalMetadata.headerCount(headers)); @@ -89,27 +83,72 @@ public static List
createRequestHeaders( okhttpHeaders.add(TE_HEADER); // Now add any application-provided headers. - byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers); + return addMetadata(okhttpHeaders, headers); + } + + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * starting a response. Since this serializes the headers, this method should be called in the + * application thread context. + */ + public static List
createResponseHeaders(Metadata headers) { + stripNonApplicationHeaders(headers); + + // 2 is the number of explicit add calls below. + List
okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers)); + okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "200")); + // All non-pseudo headers must come after pseudo headers. + okhttpHeaders.add(CONTENT_TYPE_HEADER); + return addMetadata(okhttpHeaders, headers); + } + + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * finishing a response. Since this serializes the headers, this method should be called in the + * application thread context. + */ + public static List
createResponseTrailers(Metadata trailers, boolean headersSent) { + if (!headersSent) { + return createResponseHeaders(trailers); + } + stripNonApplicationHeaders(trailers); + + List
okhttpTrailers = new ArrayList<>(InternalMetadata.headerCount(trailers)); + return addMetadata(okhttpTrailers, trailers); + } + + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * failing with an HTTP response. + */ + public static List
createHttpResponseHeaders( + int httpCode, String contentType, Metadata headers) { + // 2 is the number of explicit add calls below. + List
okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers)); + okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "" + httpCode)); + // All non-pseudo headers must come after pseudo headers. + okhttpHeaders.add(new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), contentType)); + return addMetadata(okhttpHeaders, headers); + } + + private static List
addMetadata(List
okhttpHeaders, Metadata toAdd) { + byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(toAdd); for (int i = 0; i < serializedHeaders.length; i += 2) { ByteString key = ByteString.of(serializedHeaders[i]); - String keyString = key.utf8(); - if (isApplicationHeader(keyString)) { - ByteString value = ByteString.of(serializedHeaders[i + 1]); - okhttpHeaders.add(new Header(key, value)); + // Don't allow HTTP/2 pseudo headers to be added by the application. + if (key.size() == 0 || key.getByte(0) == ':') { + continue; } + ByteString value = ByteString.of(serializedHeaders[i + 1]); + okhttpHeaders.add(new Header(key, value)); } - return okhttpHeaders; } - /** - * Returns {@code true} if the given header is an application-provided header. Otherwise, returns - * {@code false} if the header is reserved by GRPC. - */ - private static boolean isApplicationHeader(String key) { - // Don't allow HTTP/2 pseudo headers or content-type to be added by the application. - return (!key.startsWith(":") - && !CONTENT_TYPE_KEY.name().equalsIgnoreCase(key)) - && !USER_AGENT_KEY.name().equalsIgnoreCase(key); + /** Strips all non-pseudo headers reserved by gRPC, to avoid duplicates and misinterpretation. */ + private static void stripNonApplicationHeaders(Metadata headers) { + headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY); + headers.discardAll(GrpcUtil.TE_HEADER); + headers.discardAll(GrpcUtil.USER_AGENT_KEY); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/InternalOkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/InternalOkHttpServerBuilder.java new file mode 100644 index 00000000000..78a409a3f85 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/InternalOkHttpServerBuilder.java @@ -0,0 +1,46 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import io.grpc.Internal; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.InternalServer; +import io.grpc.internal.TransportTracer; +import java.util.List; + +/** + * Internal {@link OkHttpServerBuilder} accessor. This is intended for usage internal to + * the gRPC team. If you *really* think you need to use this, contact the gRPC team first. + */ +@Internal +public final class InternalOkHttpServerBuilder { + public static InternalServer buildTransportServers(OkHttpServerBuilder builder, + List streamTracerFactories) { + return builder.buildTransportServers(streamTracerFactories); + } + + public static void setTransportTracerFactory(OkHttpServerBuilder builder, + TransportTracer.Factory transportTracerFactory) { + builder.setTransportTracerFactory(transportTracerFactory); + } + + public static void setStatsEnabled(OkHttpServerBuilder builder, boolean value) { + builder.setStatsEnabled(value); + } + + private InternalOkHttpServerBuilder() {} +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 68ff6dbd787..a3f99b67cfa 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -37,25 +37,28 @@ import io.grpc.internal.AtomicBackoff; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcUtil; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.ManagedChannelImplBuilder; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; -import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.internal.SharedResourcePool; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.internal.CipherSuite; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.TlsVersion; +import io.grpc.util.CertificateUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.GeneralSecurityException; import java.security.KeyStore; -import java.security.cert.CertificateFactory; +import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.EnumSet; import java.util.Set; @@ -71,6 +74,7 @@ import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; @@ -100,7 +104,7 @@ private enum NegotiationType { PLAINTEXT } - @VisibleForTesting + // @VisibleForTesting static final ConnectionSpec INTERNAL_DEFAULT_CONNECTION_SPEC = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) .cipherSuites( @@ -137,6 +141,8 @@ public void close(Executor executor) { ((ExecutorService) executor).shutdown(); } }; + static final ObjectPool DEFAULT_TRANSPORT_EXECUTOR_POOL = + SharedResourcePool.forResource(SHARED_EXECUTOR); /** Creates a new builder for the given server host and port. */ public static OkHttpChannelBuilder forAddress(String host, int port) { @@ -168,8 +174,9 @@ public static OkHttpChannelBuilder forTarget(String target, ChannelCredentials c return new OkHttpChannelBuilder(target, creds, result.callCredentials, result.factory); } - private Executor transportExecutor; - private ScheduledExecutorService scheduledExecutorService; + private ObjectPool transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + private ObjectPool scheduledExecutorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); private SocketFactory socketFactory; private SSLSocketFactory sslSocketFactory; @@ -247,7 +254,11 @@ OkHttpChannelBuilder setTransportTracerFactory(TransportTracer.Factory transport * to shutdown the executor when appropriate. */ public OkHttpChannelBuilder transportExecutor(@Nullable Executor transportExecutor) { - this.transportExecutor = transportExecutor; + if (transportExecutor == null) { + this.transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + } else { + this.transportExecutorPool = new FixedObjectPool<>(transportExecutor); + } return this; } @@ -468,8 +479,8 @@ public OkHttpChannelBuilder useTransportSecurity() { */ public OkHttpChannelBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { - this.scheduledExecutorService = - checkNotNull(scheduledExecutorService, "scheduledExecutorService"); + this.scheduledExecutorServicePool = + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); return this; } @@ -505,11 +516,11 @@ public OkHttpChannelBuilder maxInboundMessageSize(int max) { return this; } - ClientTransportFactory buildTransportFactory() { + OkHttpTransportFactory buildTransportFactory() { boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED; return new OkHttpTransportFactory( - transportExecutor, - scheduledExecutorService, + transportExecutorPool, + scheduledExecutorServicePool, socketFactory, createSslSocketFactory(), hostnameVerifier, @@ -588,7 +599,16 @@ static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) { if (tlsCreds.getKeyManagers() != null) { km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]); } else if (tlsCreds.getPrivateKey() != null) { - return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager"); + if (tlsCreds.getPrivateKeyPassword() != null) { + return SslSocketFactoryResult.error("byte[]-based private key with password unsupported. " + + "Use unencrypted file or KeyManager"); + } + try { + km = createKeyManager(tlsCreds.getCertificateChain(), tlsCreds.getPrivateKey()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading private key from credential", gse); + return SslSocketFactoryResult.error("Unable to load private key: " + gse.getMessage()); + } } // else don't have a client cert TrustManager[] tm = null; if (tlsCreds.getTrustManagers() != null) { @@ -643,6 +663,39 @@ static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) { } } + static KeyManager[] createKeyManager(byte[] certChain, byte[] privateKey) + throws GeneralSecurityException { + X509Certificate[] chain; + ByteArrayInputStream inCertChain = new ByteArrayInputStream(certChain); + try { + chain = CertificateUtils.getX509Certificates(inCertChain); + } finally { + GrpcUtil.closeQuietly(inCertChain); + } + PrivateKey key; + ByteArrayInputStream inPrivateKey = new ByteArrayInputStream(privateKey); + try { + key = CertificateUtils.getPrivateKey(inPrivateKey); + } catch (IOException uee) { + throw new GeneralSecurityException("Unable to decode private key", uee); + } finally { + GrpcUtil.closeQuietly(inPrivateKey); + } + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + ks.setKeyEntry("key", key, new char[0], chain); + + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(ks, new char[0]); + return keyManagerFactory.getKeyManagers(); + } + static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); try { @@ -651,15 +704,17 @@ static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurit // Shouldn't really happen, as we're not loading any data. throw new GeneralSecurityException(ex); } - CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate[] certs; ByteArrayInputStream in = new ByteArrayInputStream(rootCerts); try { - X509Certificate cert = (X509Certificate) cf.generateCertificate(in); - X500Principal principal = cert.getSubjectX500Principal(); - ks.setCertificateEntry(principal.getName("RFC2253"), cert); + certs = CertificateUtils.getX509Certificates(in); } finally { GrpcUtil.closeQuietly(in); } + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); @@ -712,30 +767,30 @@ public SslSocketFactoryResult withCallCredentials(CallCredentials callCreds) { */ @Internal static final class OkHttpTransportFactory implements ClientTransportFactory { - private final Executor executor; - private final boolean usingSharedExecutor; - private final boolean usingSharedScheduler; - private final TransportTracer.Factory transportTracerFactory; - private final SocketFactory socketFactory; - @Nullable private final SSLSocketFactory sslSocketFactory; + private final ObjectPool executorPool; + final Executor executor; + private final ObjectPool scheduledExecutorServicePool; + final ScheduledExecutorService scheduledExecutorService; + final TransportTracer.Factory transportTracerFactory; + final SocketFactory socketFactory; + @Nullable final SSLSocketFactory sslSocketFactory; @Nullable - private final HostnameVerifier hostnameVerifier; - private final ConnectionSpec connectionSpec; - private final int maxMessageSize; + final HostnameVerifier hostnameVerifier; + final ConnectionSpec connectionSpec; + final int maxMessageSize; private final boolean enableKeepAlive; private final long keepAliveTimeNanos; private final AtomicBackoff keepAliveBackoff; private final long keepAliveTimeoutNanos; - private final int flowControlWindow; + final int flowControlWindow; private final boolean keepAliveWithoutCalls; - private final int maxInboundMetadataSize; - private final ScheduledExecutorService timeoutService; - private final boolean useGetForSafeMethods; + final int maxInboundMetadataSize; + final boolean useGetForSafeMethods; private boolean closed; private OkHttpTransportFactory( - Executor executor, - @Nullable ScheduledExecutorService timeoutService, + ObjectPool executorPool, + ObjectPool scheduledExecutorServicePool, @Nullable SocketFactory socketFactory, @Nullable SSLSocketFactory sslSocketFactory, @Nullable HostnameVerifier hostnameVerifier, @@ -749,9 +804,10 @@ private OkHttpTransportFactory( int maxInboundMetadataSize, TransportTracer.Factory transportTracerFactory, boolean useGetForSafeMethods) { - usingSharedScheduler = timeoutService == null; - this.timeoutService = usingSharedScheduler - ? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService; + this.executorPool = executorPool; + this.executor = executorPool.getObject(); + this.scheduledExecutorServicePool = scheduledExecutorServicePool; + this.scheduledExecutorService = scheduledExecutorServicePool.getObject(); this.socketFactory = socketFactory; this.sslSocketFactory = sslSocketFactory; this.hostnameVerifier = hostnameVerifier; @@ -766,15 +822,8 @@ private OkHttpTransportFactory( this.maxInboundMetadataSize = maxInboundMetadataSize; this.useGetForSafeMethods = useGetForSafeMethods; - usingSharedExecutor = executor == null; this.transportTracerFactory = Preconditions.checkNotNull(transportTracerFactory, "transportTracerFactory"); - if (usingSharedExecutor) { - // The executor was unspecified, using the shared executor. - this.executor = SharedResourceHolder.get(SHARED_EXECUTOR); - } else { - this.executor = executor; - } } @Override @@ -793,22 +842,13 @@ public void run() { InetSocketAddress inetSocketAddr = (InetSocketAddress) addr; // TODO(carl-mastrangelo): Pass channelLogger in. OkHttpClientTransport transport = new OkHttpClientTransport( + this, inetSocketAddr, options.getAuthority(), options.getUserAgent(), options.getEagAttributes(), - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - connectionSpec, - maxMessageSize, - flowControlWindow, options.getHttpConnectProxiedSocketAddress(), - tooManyPingsRunnable, - maxInboundMetadataSize, - transportTracerFactory.create(), - useGetForSafeMethods); + tooManyPingsRunnable); if (enableKeepAlive) { transport.enableKeepAlive( true, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls); @@ -818,7 +858,7 @@ public void run() { @Override public ScheduledExecutorService getScheduledExecutorService() { - return timeoutService; + return scheduledExecutorService; } @Nullable @@ -830,8 +870,8 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch return null; } ClientTransportFactory factory = new OkHttpTransportFactory( - executor, - timeoutService, + executorPool, + scheduledExecutorServicePool, socketFactory, result.factory, hostnameVerifier, @@ -855,13 +895,8 @@ public void close() { } closed = true; - if (usingSharedScheduler) { - SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timeoutService); - } - - if (usingSharedExecutor) { - SharedResourceHolder.release(SHARED_EXECUTOR, executor); - } + executorPool.returnObject(executor); + scheduledExecutorServicePool.returnObject(scheduledExecutorService); } } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index 19f99d05029..17a2512a66a 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -20,6 +20,10 @@ import io.grpc.Internal; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider for {@link OkHttpChannelBuilder} instances. @@ -57,4 +61,9 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( target, creds, result.callCredentials, result.factory)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index baf659a6278..46396b2a41f 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -53,8 +53,6 @@ class OkHttpClientStream extends AbstractClientStream { private final String userAgent; private final StatsTraceContext statsTraceCtx; private String authority; - private Object outboundFlowState; - private volatile int id = ABSENT_ID; private final TransportState state; private final Sink sink = new Sink(); private final Attributes attributes; @@ -120,10 +118,6 @@ public MethodDescriptor.MethodType getType() { return method.getType(); } - public int id() { - return id; - } - /** * Returns whether the stream uses GET. This is not known until after {@link Sink#writeHeaders} is * invoked. @@ -198,7 +192,8 @@ public void cancel(Status reason) { } } - class TransportState extends Http2ClientStreamTransportState { + class TransportState extends Http2ClientStreamTransportState + implements OutboundFlowController.Stream { private final int initialWindowSize; private final Object lock; @GuardedBy("lock") @@ -223,6 +218,9 @@ class TransportState extends Http2ClientStreamTransportState { @GuardedBy("lock") private boolean canStart = true; private final Tag tag; + @GuardedBy("lock") + private OutboundFlowController.StreamState outboundFlowState; + private int id = ABSENT_ID; public TransportState( int maxMessageSize, @@ -249,6 +247,7 @@ public TransportState( public void start(int streamId) { checkState(id == ABSENT_ID, "the stream has been started with id %s", streamId); id = streamId; + outboundFlowState = outboundFlow.createState(this, streamId); // TODO(b/145386688): This access should be guarded by 'OkHttpClientStream.this.state.lock'; // instead found: 'this.lock' state.onStreamAllocated(); @@ -260,7 +259,9 @@ public void start(int streamId) { requestHeaders = null; if (pendingData.size() > 0) { - outboundFlow.data(pendingDataHasEndOfStream, id, pendingData, flushPendingData); + outboundFlow.data( + pendingDataHasEndOfStream, outboundFlowState, pendingData, flushPendingData); + } canStart = false; } @@ -396,7 +397,7 @@ private void sendBuffer(Buffer buffer, boolean endOfStream, boolean flush) { checkState(id() != ABSENT_ID, "streamId should be set"); // If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is // properly chunked. - outboundFlow.data(endOfStream, id(), buffer, flush); + outboundFlow.data(endOfStream, outboundFlowState, buffer, flush); } } @@ -419,13 +420,15 @@ private void streamReady(Metadata metadata, String path) { Tag tag() { return tag; } - } - void setOutboundFlowState(Object outboundFlowState) { - this.outboundFlowState = outboundFlowState; - } + int id() { + return id; + } - Object getOutboundFlowState() { - return outboundFlowState; + OutboundFlowController.StreamState getOutboundFlowState() { + synchronized (lock) { + return outboundFlowState; + } + } } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index cf7cae19d8a..6eaaf832a6b 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -17,7 +17,6 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkState; -import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static io.grpc.okhttp.Utils.DEFAULT_WINDOW_SIZE; import static io.grpc.okhttp.Utils.DEFAULT_WINDOW_UPDATE_RATIO; @@ -28,10 +27,6 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import com.squareup.okhttp.Credentials; -import com.squareup.okhttp.HttpUrl; -import com.squareup.okhttp.Request; -import com.squareup.okhttp.internal.http.StatusLine; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -56,11 +51,12 @@ import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; import io.grpc.internal.SerializingExecutor; -import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Credentials; +import io.grpc.okhttp.internal.StatusLine; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameWriter; @@ -69,6 +65,8 @@ import io.grpc.okhttp.internal.framed.Http2; import io.grpc.okhttp.internal.framed.Settings; import io.grpc.okhttp.internal.framed.Variant; +import io.grpc.okhttp.internal.proxy.HttpUrl; +import io.grpc.okhttp.internal.proxy.Request; import io.perfmark.PerfMark; import java.io.EOFException; import java.io.IOException; @@ -82,6 +80,7 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; @@ -107,10 +106,10 @@ /** * A okhttp-based {@link ConnectionClientTransport} implementation. */ -class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler { +class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler, + OutboundFlowController.Transport { private static final Map ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); - private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -148,9 +147,8 @@ private static Map buildErrorCodeToStatusMap() { // Returns new unstarted stopwatches private final Supplier stopwatchFactory; private final int initialWindowSize; + private final Variant variant; private Listener listener; - private FrameReader testFrameReader; - private OkHttpFrameLogger testFrameLogger; @GuardedBy("lock") private ExceptionHandlingFrameWriter frameWriter; private OutboundFlowController outboundFlow; @@ -163,6 +161,7 @@ private static Map buildErrorCodeToStatusMap() { private final Executor executor; // Wrap on executor, to guarantee some operations be executed serially. private final SerializingExecutor serializingExecutor; + private final ScheduledExecutorService scheduler; private final int maxMessageSize; private int connectionUnacknowledgedBytesRead; private ClientFrameHandler clientFrameHandler; @@ -192,8 +191,6 @@ private static Map buildErrorCodeToStatusMap() { @GuardedBy("lock") private final Deque pendingStreams = new LinkedList<>(); private final ConnectionSpec connectionSpec; - private FrameWriter testFrameWriter; - private ScheduledExecutorService scheduler; private KeepAliveManager keepAliveManager; private boolean enableKeepAlive; private long keepAliveTimeNanos; @@ -224,51 +221,72 @@ protected void handleNotInUse() { @Nullable final HttpConnectProxiedSocketAddress proxiedAddr; + @VisibleForTesting + int proxySocketTimeout = 30000; + // The following fields should only be used for test. Runnable connectingCallback; SettableFuture connectedFuture; - OkHttpClientTransport( + public OkHttpClientTransport( + OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, InetSocketAddress address, String authority, @Nullable String userAgent, Attributes eagAttrs, - Executor executor, - @Nullable SocketFactory socketFactory, - @Nullable SSLSocketFactory sslSocketFactory, - @Nullable HostnameVerifier hostnameVerifier, - ConnectionSpec connectionSpec, - int maxMessageSize, - int initialWindowSize, @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable, - int maxInboundMetadataSize, - TransportTracer transportTracer, - boolean useGetForSafeMethods) { + Runnable tooManyPingsRunnable) { + this( + transportFactory, + address, + authority, + userAgent, + eagAttrs, + GrpcUtil.STOPWATCH_SUPPLIER, + new Http2(), + proxiedAddr, + tooManyPingsRunnable); + } + + private OkHttpClientTransport( + OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Supplier stopwatchFactory, + Variant variant, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; - this.maxMessageSize = maxMessageSize; - this.initialWindowSize = initialWindowSize; - this.executor = Preconditions.checkNotNull(executor, "executor"); - serializingExecutor = new SerializingExecutor(executor); + this.maxMessageSize = transportFactory.maxMessageSize; + this.initialWindowSize = transportFactory.flowControlWindow; + this.executor = Preconditions.checkNotNull(transportFactory.executor, "executor"); + serializingExecutor = new SerializingExecutor(transportFactory.executor); + this.scheduler = Preconditions.checkNotNull( + transportFactory.scheduledExecutorService, "scheduledExecutorService"); // Client initiated streams are odd, server initiated ones are even. Server should not need to // use it. We start clients at 3 to avoid conflicting with HTTP negotiation. nextStreamId = 3; - this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory; - this.sslSocketFactory = sslSocketFactory; - this.hostnameVerifier = hostnameVerifier; - this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); - this.stopwatchFactory = GrpcUtil.STOPWATCH_SUPPLIER; + this.socketFactory = transportFactory.socketFactory == null + ? SocketFactory.getDefault() : transportFactory.socketFactory; + this.sslSocketFactory = transportFactory.sslSocketFactory; + this.hostnameVerifier = transportFactory.hostnameVerifier; + this.connectionSpec = Preconditions.checkNotNull( + transportFactory.connectionSpec, "connectionSpec"); + this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); + this.variant = Preconditions.checkNotNull(variant, "variant"); this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); this.proxiedAddr = proxiedAddr; this.tooManyPingsRunnable = Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); - this.maxInboundMetadataSize = maxInboundMetadataSize; - this.transportTracer = Preconditions.checkNotNull(transportTracer); + this.maxInboundMetadataSize = transportFactory.maxInboundMetadataSize; + this.transportTracer = transportFactory.transportTracerFactory.create(); this.logId = InternalLogId.allocate(getClass(), address.toString()); this.attributes = Attributes.newBuilder() .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); - this.useGetForSafeMethods = useGetForSafeMethods; + this.useGetForSafeMethods = transportFactory.useGetForSafeMethods; initTransportTracer(); } @@ -277,45 +295,25 @@ protected void handleNotInUse() { */ @VisibleForTesting OkHttpClientTransport( + OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, String userAgent, - Executor executor, - FrameReader frameReader, - FrameWriter testFrameWriter, - OkHttpFrameLogger testFrameLogger, - int nextStreamId, - Socket socket, Supplier stopwatchFactory, + Variant variant, @Nullable Runnable connectingCallback, SettableFuture connectedFuture, - int maxMessageSize, - int initialWindowSize, - Runnable tooManyPingsRunnable, - TransportTracer transportTracer) { - useGetForSafeMethods = false; - address = null; - this.maxMessageSize = maxMessageSize; - this.initialWindowSize = initialWindowSize; - defaultAuthority = "notarealauthority:80"; - this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent); - this.executor = Preconditions.checkNotNull(executor, "executor"); - serializingExecutor = new SerializingExecutor(executor); - this.socketFactory = SocketFactory.getDefault(); - this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader"); - this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter"); - this.testFrameLogger = Preconditions.checkNotNull(testFrameLogger, "testFrameLogger"); - this.socket = Preconditions.checkNotNull(socket, "socket"); - this.nextStreamId = nextStreamId; - this.stopwatchFactory = stopwatchFactory; - this.connectionSpec = null; + Runnable tooManyPingsRunnable) { + this( + transportFactory, + new InetSocketAddress("127.0.0.1", 80), + "notarealauthority:80", + userAgent, + Attributes.EMPTY, + stopwatchFactory, + variant, + null, + tooManyPingsRunnable); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); - this.proxiedAddr = null; - this.tooManyPingsRunnable = - Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); - this.maxInboundMetadataSize = Integer.MAX_VALUE; - this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); - this.logId = InternalLogId.allocate(getClass(), String.valueOf(socket.getInetAddress())); - initTransportTracer(); } // sslSocketFactory is set to null when use plaintext. @@ -329,8 +327,10 @@ private void initTransportTracer() { @Override public TransportTracer.FlowControlWindows read() { synchronized (lock) { - long local = -1; // okhttp does not track the local window size - long remote = outboundFlow == null ? -1 : outboundFlow.windowUpdate(null, 0); + long local = outboundFlow == null ? -1 : outboundFlow.windowUpdate(null, 0); + // connectionUnacknowledgedBytesRead is only readable by ClientFrameHandler, so we + // provide a lower bound. + long remote = (long) (initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO); return new TransportTracer.FlowControlWindows(local, remote); } } @@ -349,10 +349,6 @@ void enableKeepAlive(boolean enable, long keepAliveTimeNanos, this.keepAliveWithoutCalls = keepAliveWithoutCalls; } - private boolean isForTest() { - return address == null; - } - @Override public void ping(final PingCallback callback, Executor executor) { long data = 0; @@ -432,7 +428,7 @@ void streamReadyToStart(OkHttpClientStream clientStream) { @GuardedBy("lock") private void startStream(OkHttpClientStream stream) { Preconditions.checkState( - stream.id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); + stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); streams.put(nextStreamId, stream); setInUse(stream); // TODO(b/145386688): This access should be guarded by 'stream.transportState().lock'; instead @@ -482,41 +478,22 @@ public Runnable start(Listener listener) { this.listener = Preconditions.checkNotNull(listener, "listener"); if (enableKeepAlive) { - scheduler = SharedResourceHolder.get(TIMER_SERVICE); keepAliveManager = new KeepAliveManager( new ClientKeepAlivePinger(this), scheduler, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls); keepAliveManager.onTransportStarted(); } - if (isForTest()) { - synchronized (lock) { - frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter, - testFrameLogger); - outboundFlow = new OutboundFlowController(OkHttpClientTransport.this, frameWriter); - } - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (connectingCallback != null) { - connectingCallback.run(); - } - clientFrameHandler = new ClientFrameHandler(testFrameReader, testFrameLogger); - executor.execute(clientFrameHandler); - synchronized (lock) { - maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); - } - connectedFuture.set(null); - } - }); - return null; - } - final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); - final Variant variant = new Http2(); - FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true); + int maxQueuedControlFrames = 10000; + final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this, maxQueuedControlFrames); + FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter( + variant.newWriter(Okio.buffer(asyncSink), true)); synchronized (lock) { + // Handle FrameWriter exceptions centrally, since there are many callers. Note that errors + // coming from rawFrameWriter are generally broken invariants/bugs, as AsyncSink does not + // propagate syscall errors through the FrameWriter. But we handle the AsyncSink failures with + // the same TransportExceptionHandler instance so it is all mixed back together. frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter); outboundFlow = new OutboundFlowController(this, frameWriter); } @@ -616,6 +593,9 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort() serializingExecutor.execute(new Runnable() { @Override public void run() { + if (connectingCallback != null) { + connectingCallback.run(); + } // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it // may send goAway immediately. executor.execute(clientFrameHandler); @@ -623,6 +603,9 @@ public void run() { maxConcurrentStreams = Integer.MAX_VALUE; startPendingStreams(); } + if (connectedFuture != null) { + connectedFuture.set(null); + } } }); return null; @@ -631,8 +614,7 @@ public void run() { /** * Should only be called once when the transport is first established. */ - @VisibleForTesting - void sendConnectionPrefaceAndSettings() { + private void sendConnectionPrefaceAndSettings() { synchronized (lock) { frameWriter.connectionPreface(); Settings settings = new Settings(); @@ -647,8 +629,8 @@ void sendConnectionPrefaceAndSettings() { private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, String proxyUsername, String proxyPassword) throws StatusException { + Socket sock = null; try { - Socket sock; // The proxy address may not be resolved if (proxyAddress.getAddress() != null) { sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort()); @@ -657,6 +639,9 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort()); } sock.setTcpNoDelay(true); + // A socket timeout is needed because lost network connectivity while reading from the proxy, + // can cause reading from the socket to hang. + sock.setSoTimeout(proxySocketTimeout); Source source = Okio.source(sock); BufferedSink sink = Okio.buffer(Okio.sink(sock)); @@ -664,7 +649,8 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres // Prepare headers and request method line Request proxyRequest = createHttpProxyRequest(address, proxyUsername, proxyPassword); HttpUrl url = proxyRequest.httpUrl(); - String requestLine = String.format("CONNECT %s:%d HTTP/1.1", url.host(), url.port()); + String requestLine = + String.format(Locale.US, "CONNECT %s:%d HTTP/1.1", url.host(), url.port()); // Write request to socket sink.writeUtf8(requestLine).writeUtf8("\r\n"); @@ -696,25 +682,32 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres // ignored } String message = String.format( + Locale.US, "Response returned from proxy was not successful (expected 2xx, got %d %s). " + "Response body:\n%s", statusLine.code, statusLine.message, body.readUtf8()); throw Status.UNAVAILABLE.withDescription(message).asException(); } + // As the socket will be used for RPCs from here on, we want the socket timeout back to zero. + sock.setSoTimeout(0); return sock; } catch (IOException e) { + if (sock != null) { + GrpcUtil.closeQuietly(sock); + } throw Status.UNAVAILABLE.withDescription("Failed trying to connect with proxy").withCause(e) .asException(); } } private Request createHttpProxyRequest(InetSocketAddress address, String proxyUsername, - String proxyPassword) { + String proxyPassword) { HttpUrl tunnelUrl = new HttpUrl.Builder() .scheme("https") .host(address.getHostName()) .port(address.getPort()) .build(); + Request.Builder request = new Request.Builder() .url(tunnelUrl) .header("Host", tunnelUrl.host() + ":" + tunnelUrl.port()) @@ -831,9 +824,16 @@ public Attributes getAttributes() { /** * Gets all active streams as an array. */ - OkHttpClientStream[] getActiveStreams() { + @Override + public OutboundFlowController.StreamState[] getActiveStreams() { synchronized (lock) { - return streams.values().toArray(EMPTY_STREAM_ARRAY); + OutboundFlowController.StreamState[] flowStreams = + new OutboundFlowController.StreamState[streams.size()]; + int i = 0; + for (OkHttpClientStream stream : streams.values()) { + flowStreams[i++] = stream.transportState().getOutboundFlowState(); + } + return flowStreams; } } @@ -854,6 +854,13 @@ int getPendingStreamSize() { } } + @VisibleForTesting + void setNextStreamId(int nextStreamId) { + synchronized (lock) { + this.nextStreamId = nextStreamId; + } + } + /** * Finish all active streams due to an IOException, then close the transport. */ @@ -907,7 +914,7 @@ private void startGoAway(int lastKnownStreamId, ErrorCode errorCode, Status stat } /** - * Called when a stream is closed, we do things like: + * Called when a stream is closed. We do things like: *
    *
  • Removing the stream from the map. *
  • Optionally reporting the status. @@ -966,8 +973,6 @@ private void stopIfNecessary() { if (keepAliveManager != null) { keepAliveManager.onTransportTermination(); - // KeepAliveManager should stop using the scheduler after onTransportTermination gets called. - scheduler = SharedResourceHolder.release(TIMER_SERVICE, scheduler); } if (ping != null) { @@ -1080,21 +1085,15 @@ public ListenableFuture getStats() { /** * Runnable which reads frames and dispatches them to in flight calls. */ - @VisibleForTesting class ClientFrameHandler implements FrameReader.Handler, Runnable { - private final OkHttpFrameLogger logger; + private final OkHttpFrameLogger logger = + new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class); FrameReader frameReader; boolean firstSettings = true; ClientFrameHandler(FrameReader frameReader) { - this(frameReader, new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class)); - } - - @VisibleForTesting - ClientFrameHandler(FrameReader frameReader, OkHttpFrameLogger frameLogger) { this.frameReader = frameReader; - logger = frameLogger; } @Override @@ -1149,7 +1148,7 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length if (stream == null) { if (mayHaveCreatedStream(streamId)) { synchronized (lock) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED); } in.skip(length); } else { @@ -1200,6 +1199,7 @@ public void headers(boolean outFinished, if (metadataSize > maxInboundMetadataSize) { failedStatus = Status.RESOURCE_EXHAUSTED.withDescription( String.format( + Locale.US, "Response %s metadata larger than %d: %d", inFinished ? "trailer" : "header", maxInboundMetadataSize, @@ -1210,7 +1210,7 @@ public void headers(boolean outFinished, OkHttpClientStream stream = streams.get(streamId); if (stream == null) { if (mayHaveCreatedStream(streamId)) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED); } else { unknownStream = true; } @@ -1315,8 +1315,9 @@ public void ping(boolean ack, int payload1, int payload2) { p = ping; ping = null; } else { - log.log(Level.WARNING, String.format("Received unexpected ping ack. " - + "Expecting %d, got %d", ping.payload(), ackPayload)); + log.log(Level.WARNING, String.format( + Locale.US, "Received unexpected ping ack. Expecting %d, got %d", + ping.payload(), ackPayload)); } } else { log.warning("Received unexpected ping ack. No ping outstanding"); @@ -1389,7 +1390,7 @@ public void windowUpdate(int streamId, long delta) { OkHttpClientStream stream = streams.get(streamId); if (stream != null) { - outboundFlow.windowUpdate(stream, (int) delta); + outboundFlow.windowUpdate(stream.transportState().getOutboundFlowState(), (int) delta); } else if (!mayHaveCreatedStream(streamId)) { unknownStream = true; } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java new file mode 100644 index 00000000000..f63950e4a03 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java @@ -0,0 +1,189 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.InternalChannelz; +import io.grpc.InternalInstrumented; +import io.grpc.InternalLogId; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerListener; +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ServerSocketFactory; + +final class OkHttpServer implements InternalServer { + private static final Logger log = Logger.getLogger(OkHttpServer.class.getName()); + + private final SocketAddress originalListenAddress; + private final ServerSocketFactory socketFactory; + private final ObjectPool transportExecutorPool; + private final ObjectPool scheduledExecutorServicePool; + private final OkHttpServerTransport.Config transportConfig; + private final InternalChannelz channelz; + private ServerSocket serverSocket; + private SocketAddress actualListenAddress; + private InternalInstrumented listenInstrumented; + private Executor transportExecutor; + private ScheduledExecutorService scheduledExecutorService; + private ServerListener listener; + private boolean shutdown; + + public OkHttpServer( + OkHttpServerBuilder builder, + List streamTracerFactories, + InternalChannelz channelz) { + this.originalListenAddress = Preconditions.checkNotNull(builder.listenAddress, "listenAddress"); + this.socketFactory = Preconditions.checkNotNull(builder.socketFactory, "socketFactory"); + this.transportExecutorPool = + Preconditions.checkNotNull(builder.transportExecutorPool, "transportExecutorPool"); + this.scheduledExecutorServicePool = + Preconditions.checkNotNull( + builder.scheduledExecutorServicePool, "scheduledExecutorServicePool"); + this.transportConfig = new OkHttpServerTransport.Config(builder, streamTracerFactories); + this.channelz = Preconditions.checkNotNull(channelz, "channelz"); + } + + @Override + public void start(ServerListener listener) throws IOException { + this.listener = Preconditions.checkNotNull(listener, "listener"); + ServerSocket serverSocket = socketFactory.createServerSocket(); + try { + serverSocket.bind(originalListenAddress); + } catch (IOException t) { + serverSocket.close(); + throw t; + } + + this.serverSocket = serverSocket; + this.actualListenAddress = serverSocket.getLocalSocketAddress(); + this.listenInstrumented = new ListenSocket(serverSocket); + this.transportExecutor = transportExecutorPool.getObject(); + // Keep reference alive to avoid frequent re-creation by server transports + this.scheduledExecutorService = scheduledExecutorServicePool.getObject(); + channelz.addListenSocket(this.listenInstrumented); + transportExecutor.execute(this::acceptConnections); + } + + private void acceptConnections() { + try { + while (true) { + Socket socket; + try { + socket = serverSocket.accept(); + } catch (IOException ex) { + if (shutdown) { + break; + } + throw ex; + } + OkHttpServerTransport transport = new OkHttpServerTransport(transportConfig, socket); + transport.start(listener.transportCreated(transport)); + } + } catch (Throwable t) { + log.log(Level.SEVERE, "Accept loop failed", t); + } + listener.serverShutdown(); + } + + @Override + public void shutdown() { + if (shutdown) { + return; + } + shutdown = true; + + if (serverSocket == null) { + return; + } + channelz.removeListenSocket(this.listenInstrumented); + try { + serverSocket.close(); + } catch (IOException ex) { + log.log(Level.WARNING, "Failed closing server socket", serverSocket); + } + transportExecutor = transportExecutorPool.returnObject(transportExecutor); + scheduledExecutorService = scheduledExecutorServicePool.returnObject(scheduledExecutorService); + } + + @Override + public SocketAddress getListenSocketAddress() { + return actualListenAddress; + } + + @Override + public InternalInstrumented getListenSocketStats() { + return listenInstrumented; + } + + @Override + public List getListenSocketAddresses() { + return Collections.singletonList(getListenSocketAddress()); + } + + @Override + public List> getListenSocketStatsList() { + return Collections.singletonList(getListenSocketStats()); + } + + private static final class ListenSocket + implements InternalInstrumented { + private final InternalLogId id; + private final ServerSocket socket; + + public ListenSocket(ServerSocket socket) { + this.socket = socket; + this.id = InternalLogId.allocate(getClass(), String.valueOf(socket.getLocalSocketAddress())); + } + + @Override + public ListenableFuture getStats() { + return Futures.immediateFuture(new InternalChannelz.SocketStats( + /*data=*/ null, + socket.getLocalSocketAddress(), + /*remote=*/ null, + new InternalChannelz.SocketOptions.Builder().build(), + /*security=*/ null)); + } + + @Override + public InternalLogId getLogId() { + return id; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("logId", id.getId()) + .add("socket", socket) + .toString(); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java new file mode 100644 index 00000000000..45d6b9efc54 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -0,0 +1,553 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.DoNotCall; +import io.grpc.ChoiceServerCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerBuilder; +import io.grpc.InsecureServerCredentials; +import io.grpc.Internal; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.ServerStreamTracer; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerImplBuilder; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TransportTracer; +import io.grpc.okhttp.internal.Platform; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.security.GeneralSecurityException; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ServerSocketFactory; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; + +/** + * Build servers with the OkHttp transport. + * + * @since 1.49.0 + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") +public final class OkHttpServerBuilder extends ForwardingServerBuilder { + private static final Logger log = Logger.getLogger(OkHttpServerBuilder.class.getName()); + private static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535; + + static final long MAX_CONNECTION_IDLE_NANOS_DISABLED = Long.MAX_VALUE; + private static final long MIN_MAX_CONNECTION_IDLE_NANO = TimeUnit.SECONDS.toNanos(1L); + static final long MAX_CONNECTION_AGE_NANOS_DISABLED = Long.MAX_VALUE; + static final long MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE = Long.MAX_VALUE; + private static final long MIN_MAX_CONNECTION_AGE_NANO = TimeUnit.SECONDS.toNanos(1L); + + private static final long AS_LARGE_AS_INFINITE = TimeUnit.DAYS.toNanos(1000L); + private static final ObjectPool DEFAULT_TRANSPORT_EXECUTOR_POOL = + OkHttpChannelBuilder.DEFAULT_TRANSPORT_EXECUTOR_POOL; + + /** + * Always throws, to shadow {@code ServerBuilder.forPort()}. + * + * @deprecated Use {@link #forPort(int, ServerCredentials)} instead + */ + @DoNotCall("Always throws. Use forPort(int, ServerCredentials) instead") + @Deprecated + public static OkHttpServerBuilder forPort(int port) { + throw new UnsupportedOperationException(); + } + + /** + * Creates a builder for a server listening on {@code port}. + */ + public static OkHttpServerBuilder forPort(int port, ServerCredentials creds) { + return forPort(new InetSocketAddress(port), creds); + } + + /** + * Creates a builder for a server listening on {@code address}. + */ + public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentials creds) { + HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(creds); + if (result.error != null) { + throw new IllegalArgumentException(result.error); + } + return new OkHttpServerBuilder(address, result.factory); + } + + final ServerImplBuilder serverImplBuilder = new ServerImplBuilder(this::buildTransportServers); + final SocketAddress listenAddress; + final HandshakerSocketFactory handshakerSocketFactory; + TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); + + ObjectPool transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + ObjectPool scheduledExecutorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + + ServerSocketFactory socketFactory = ServerSocketFactory.getDefault(); + long keepAliveTimeNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; + long keepAliveTimeoutNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; + int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; + int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE; + int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; + long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; + boolean permitKeepAliveWithoutCalls; + long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5); + long maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; + long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; + + @VisibleForTesting + OkHttpServerBuilder( + SocketAddress address, HandshakerSocketFactory handshakerSocketFactory) { + this.listenAddress = Preconditions.checkNotNull(address, "address"); + this.handshakerSocketFactory = + Preconditions.checkNotNull(handshakerSocketFactory, "handshakerSocketFactory"); + } + + @Internal + @Override + protected ServerBuilder delegate() { + return serverImplBuilder; + } + + // @VisibleForTesting + OkHttpServerBuilder setTransportTracerFactory(TransportTracer.Factory transportTracerFactory) { + this.transportTracerFactory = transportTracerFactory; + return this; + } + + /** + * Override the default executor necessary for internal transport use. + * + *

    The channel does not take ownership of the given executor. It is the caller' responsibility + * to shutdown the executor when appropriate. + */ + public OkHttpServerBuilder transportExecutor(Executor transportExecutor) { + if (transportExecutor == null) { + this.transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + } else { + this.transportExecutorPool = new FixedObjectPool<>(transportExecutor); + } + return this; + } + + /** + * Override the default {@link ServerSocketFactory} used to listen. If the socket factory is not + * set or set to null, a default one will be used. + */ + public OkHttpServerBuilder socketFactory(ServerSocketFactory socketFactory) { + if (socketFactory == null) { + this.socketFactory = ServerSocketFactory.getDefault(); + } else { + this.socketFactory = socketFactory; + } + return this; + } + + /** + * Sets the time without read activity before sending a keepalive ping. An unreasonably small + * value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large + * value will disable keepalive. Defaults to two hours. + * + * @throws IllegalArgumentException if time is not positive + */ + @Override + public OkHttpServerBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + Preconditions.checkArgument(keepAliveTime > 0L, "keepalive time must be positive"); + keepAliveTimeNanos = timeUnit.toNanos(keepAliveTime); + keepAliveTimeNanos = KeepAliveManager.clampKeepAliveTimeInNanos(keepAliveTimeNanos); + if (keepAliveTimeNanos >= AS_LARGE_AS_INFINITE) { + // Bump keepalive time to infinite. This disables keepalive. + keepAliveTimeNanos = GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; + } + return this; + } + + /** + * Sets a custom max connection idle time, connection being idle for longer than which will be + * gracefully terminated. Idleness duration is defined since the most recent time the number of + * outstanding RPCs became zero or the connection establishment. An unreasonably small value might + * be increased. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value will disable + * max connection idle. + */ + @Override + public OkHttpServerBuilder maxConnectionIdle(long maxConnectionIdle, TimeUnit timeUnit) { + checkArgument(maxConnectionIdle > 0L, "max connection idle must be positive: %s", + maxConnectionIdle); + maxConnectionIdleInNanos = timeUnit.toNanos(maxConnectionIdle); + if (maxConnectionIdleInNanos >= AS_LARGE_AS_INFINITE) { + maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; + } + if (maxConnectionIdleInNanos < MIN_MAX_CONNECTION_IDLE_NANO) { + maxConnectionIdleInNanos = MIN_MAX_CONNECTION_IDLE_NANO; + } + return this; + } + + /** + * Sets a custom max connection age, connection lasting longer than which will be gracefully + * terminated. An unreasonably small value might be increased. A random jitter of +/-10% will be + * added to it. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value will disable + * max connection age. + */ + @Override + public OkHttpServerBuilder maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { + checkArgument(maxConnectionAge > 0L, "max connection age must be positive: %s", + maxConnectionAge); + maxConnectionAgeInNanos = timeUnit.toNanos(maxConnectionAge); + if (maxConnectionAgeInNanos >= AS_LARGE_AS_INFINITE) { + maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; + } + if (maxConnectionAgeInNanos < MIN_MAX_CONNECTION_AGE_NANO) { + maxConnectionAgeInNanos = MIN_MAX_CONNECTION_AGE_NANO; + } + return this; + } + + /** + * Sets a custom grace time for the graceful connection termination. Once the max connection age + * is reached, RPCs have the grace time to complete. RPCs that do not complete in time will be + * cancelled, allowing the connection to terminate. {@code Long.MAX_VALUE} nano seconds or an + * unreasonably large value are considered infinite. + * + * @see #maxConnectionAge(long, TimeUnit) + */ + @Override + public OkHttpServerBuilder maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { + checkArgument(maxConnectionAgeGrace >= 0L, "max connection age grace must be non-negative: %s", + maxConnectionAgeGrace); + maxConnectionAgeGraceInNanos = timeUnit.toNanos(maxConnectionAgeGrace); + if (maxConnectionAgeGraceInNanos >= AS_LARGE_AS_INFINITE) { + maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; + } + return this; + } + + /** + * Sets a time waiting for read activity after sending a keepalive ping. If the time expires + * without any read activity on the connection, the connection is considered dead. An unreasonably + * small value might be increased. Defaults to 20 seconds. + * + *

    This value should be at least multiple times the RTT to allow for lost packets. + * + * @throws IllegalArgumentException if timeout is not positive + */ + @Override + public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { + Preconditions.checkArgument(keepAliveTimeout > 0L, "keepalive timeout must be positive"); + keepAliveTimeoutNanos = timeUnit.toNanos(keepAliveTimeout); + keepAliveTimeoutNanos = KeepAliveManager.clampKeepAliveTimeoutInNanos(keepAliveTimeoutNanos); + return this; + } + + /** + * Specify the most aggressive keep-alive time clients are permitted to configure. The server will + * try to detect clients exceeding this rate and when detected will forcefully close the + * connection. The default is 5 minutes. + * + *

    Even though a default is defined that allows some keep-alives, clients must not use + * keep-alive without approval from the service owner. Otherwise, they may experience failures in + * the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a + * significant amount of traffic and CPU usage, so clients and servers should be conservative in + * what they use and accept. + * + * @see #permitKeepAliveWithoutCalls(boolean) + */ + @CanIgnoreReturnValue + @Override + public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s", + keepAliveTime); + permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime); + return this; + } + + /** + * Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding + * RPCs on the connection. Defaults to {@code false}. + * + * @see #permitKeepAliveTime(long, TimeUnit) + */ + @CanIgnoreReturnValue + @Override + public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) { + permitKeepAliveWithoutCalls = permit; + return this; + } + + /** + * Sets the flow control window in bytes. If not called, the default value is 64 KiB. + */ + public OkHttpServerBuilder flowControlWindow(int flowControlWindow) { + Preconditions.checkState(flowControlWindow > 0, "flowControlWindow must be positive"); + this.flowControlWindow = flowControlWindow; + return this; + } + + /** + * Provides a custom scheduled executor service. + * + *

    It's an optional parameter. If the user has not provided a scheduled executor service when + * the channel is built, the builder will use a static thread pool. + * + * @return this + */ + public OkHttpServerBuilder scheduledExecutorService( + ScheduledExecutorService scheduledExecutorService) { + this.scheduledExecutorServicePool = new FixedObjectPool<>( + Preconditions.checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + return this; + } + + /** + * Sets the maximum size of metadata allowed to be received. Defaults to 8 KiB. + * + *

    The implementation does not currently limit memory usage; this value is checked only after + * the metadata is decoded from the wire. It does prevent large metadata from being passed to the + * application. + * + * @param bytes the maximum size of received metadata + * @return this + * @throws IllegalArgumentException if bytes is non-positive + */ + @Override + public OkHttpServerBuilder maxInboundMetadataSize(int bytes) { + Preconditions.checkArgument(bytes > 0, "maxInboundMetadataSize must be > 0"); + this.maxInboundMetadataSize = bytes; + return this; + } + + /** + * Sets the maximum message size allowed to be received on the server. If not called, defaults to + * defaults to 4 MiB. The default provides protection to servers who haven't considered the + * possibility of receiving large messages while trying to be large enough to not be hit in normal + * usage. + * + * @param bytes the maximum number of bytes a single message can be. + * @return this + * @throws IllegalArgumentException if bytes is negative. + */ + @Override + public OkHttpServerBuilder maxInboundMessageSize(int bytes) { + Preconditions.checkArgument(bytes >= 0, "negative max bytes"); + maxInboundMessageSize = bytes; + return this; + } + + void setStatsEnabled(boolean value) { + this.serverImplBuilder.setStatsEnabled(value); + } + + InternalServer buildTransportServers( + List streamTracerFactories) { + return new OkHttpServer(this, streamTracerFactories, serverImplBuilder.getChannelz()); + } + + private static final EnumSet understoodTlsFeatures = + EnumSet.of( + TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); + + static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentials creds) { + if (creds instanceof TlsServerCredentials) { + TlsServerCredentials tlsCreds = (TlsServerCredentials) creds; + Set incomprehensible = + tlsCreds.incomprehensible(understoodTlsFeatures); + if (!incomprehensible.isEmpty()) { + return HandshakerSocketFactoryResult.error( + "TLS features not understood: " + incomprehensible); + } + KeyManager[] km = null; + if (tlsCreds.getKeyManagers() != null) { + km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]); + } else if (tlsCreds.getPrivateKey() != null) { + if (tlsCreds.getPrivateKeyPassword() != null) { + return HandshakerSocketFactoryResult.error("byte[]-based private key with password " + + "unsupported. Use unencrypted file or KeyManager"); + } + try { + km = OkHttpChannelBuilder.createKeyManager( + tlsCreds.getCertificateChain(), tlsCreds.getPrivateKey()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading private key from credential", gse); + return HandshakerSocketFactoryResult.error( + "Unable to load private key: " + gse.getMessage()); + } + } // else don't have a client cert + TrustManager[] tm = null; + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + try { + tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading root certificates from credential", gse); + return HandshakerSocketFactoryResult.error( + "Unable to load root certificates: " + gse.getMessage()); + } + } // else use system default + SSLContext sslContext; + try { + sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider()); + sslContext.init(km, tm, null); + } catch (GeneralSecurityException gse) { + throw new RuntimeException("TLS Provider failure", gse); + } + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + switch (tlsCreds.getClientAuth()) { + case OPTIONAL: + sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, false); + break; + + case REQUIRE: + sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, true); + break; + + case NONE: + // NOOP; this is the SSLContext default + break; + + default: + return HandshakerSocketFactoryResult.error( + "Unknown TlsServerCredentials.ClientAuth value: " + tlsCreds.getClientAuth()); + } + return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory( + new SslSocketFactoryServerCredentials.ServerCredentials(sslSocketFactory))); + + } else if (creds instanceof InsecureServerCredentials) { + return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory()); + + } else if (creds instanceof SslSocketFactoryServerCredentials.ServerCredentials) { + SslSocketFactoryServerCredentials.ServerCredentials factoryCreds = + (SslSocketFactoryServerCredentials.ServerCredentials) creds; + return HandshakerSocketFactoryResult.factory( + new TlsServerHandshakerSocketFactory(factoryCreds)); + + } else if (creds instanceof ChoiceServerCredentials) { + ChoiceServerCredentials choiceCreds = (ChoiceServerCredentials) creds; + StringBuilder error = new StringBuilder(); + for (ServerCredentials innerCreds : choiceCreds.getCredentialsList()) { + HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(innerCreds); + if (result.error == null) { + return result; + } + error.append(", "); + error.append(result.error); + } + return HandshakerSocketFactoryResult.error(error.substring(2)); + + } else { + return HandshakerSocketFactoryResult.error( + "Unsupported credential type: " + creds.getClass().getName()); + } + } + + static final class HandshakerSocketFactoryResult { + public final HandshakerSocketFactory factory; + public final String error; + + private HandshakerSocketFactoryResult(HandshakerSocketFactory factory, String error) { + this.factory = factory; + this.error = error; + } + + public static HandshakerSocketFactoryResult error(String error) { + return new HandshakerSocketFactoryResult( + null, Preconditions.checkNotNull(error, "error")); + } + + public static HandshakerSocketFactoryResult factory(HandshakerSocketFactory factory) { + return new HandshakerSocketFactoryResult( + Preconditions.checkNotNull(factory, "factory"), null); + } + } + + static final class ClientCertRequestingSocketFactory extends SSLSocketFactory { + private final SSLSocketFactory socketFactory; + private final boolean required; + + public ClientCertRequestingSocketFactory(SSLSocketFactory socketFactory, boolean required) { + this.socketFactory = Preconditions.checkNotNull(socketFactory, "socketFactory"); + this.required = required; + } + + private Socket apply(Socket s) throws IOException { + if (!(s instanceof SSLSocket)) { + throw new IOException( + "SocketFactory " + socketFactory + " did not produce an SSLSocket: " + s.getClass()); + } + SSLSocket sslSocket = (SSLSocket) s; + if (required) { + sslSocket.setNeedClientAuth(true); + } else { + sslSocket.setWantClientAuth(true); + } + return sslSocket; + } + + @Override public Socket createSocket(Socket s, String host, int port, boolean autoClose) + throws IOException { + return apply(socketFactory.createSocket(s, host, port, autoClose)); + } + + @Override public Socket createSocket(String host, int port) throws IOException { + return apply(socketFactory.createSocket(host, port)); + } + + @Override public Socket createSocket( + String host, int port, InetAddress localHost, int localPort) throws IOException { + return apply(socketFactory.createSocket(host, port, localHost, localPort)); + } + + @Override public Socket createSocket(InetAddress host, int port) throws IOException { + return apply(socketFactory.createSocket(host, port)); + } + + @Override public Socket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException { + return apply(socketFactory.createSocket(host, port, localAddress, localPort)); + } + + @Override public String[] getDefaultCipherSuites() { + return socketFactory.getDefaultCipherSuites(); + } + + @Override public String[] getSupportedCipherSuites() { + return socketFactory.getSupportedCipherSuites(); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java new file mode 100644 index 00000000000..1def5c17e04 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java @@ -0,0 +1,302 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.AbstractServerStream; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TransportTracer; +import io.grpc.internal.WritableBuffer; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.Header; +import io.perfmark.PerfMark; +import io.perfmark.Tag; +import java.util.List; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; + +/** + * Server stream for the okhttp transport. + */ +class OkHttpServerStream extends AbstractServerStream { + private final String authority; + private final TransportState state; + private final Sink sink = new Sink(); + private final TransportTracer transportTracer; + private final Attributes attributes; + + public OkHttpServerStream( + TransportState state, + Attributes transportAttrs, + String authority, + StatsTraceContext statsTraceCtx, + TransportTracer transportTracer) { + super(new OkHttpWritableBufferAllocator(), statsTraceCtx); + this.state = Preconditions.checkNotNull(state, "state"); + this.attributes = Preconditions.checkNotNull(transportAttrs, "transportAttrs"); + this.authority = authority; + this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); + } + + @Override + protected TransportState transportState() { + return state; + } + + @Override + protected Sink abstractServerStreamSink() { + return sink; + } + + @Override + public int streamId() { + return state.streamId; + } + + @Override + public String getAuthority() { + return authority; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + class Sink implements AbstractServerStream.Sink { + @Override + public void writeHeaders(Metadata metadata) { + PerfMark.startTask("OkHttpServerStream$Sink.writeHeaders"); + try { + List

    responseHeaders = Headers.createResponseHeaders(metadata); + synchronized (state.lock) { + state.sendHeaders(responseHeaders); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeHeaders"); + } + } + + @Override + public void writeFrame(WritableBuffer frame, boolean flush, int numMessages) { + PerfMark.startTask("OkHttpServerStream$Sink.writeFrame"); + Buffer buffer = ((OkHttpWritableBuffer) frame).buffer(); + int size = (int) buffer.size(); + if (size > 0) { + onSendingBytes(size); + } + + try { + synchronized (state.lock) { + state.sendBuffer(buffer, flush); + transportTracer.reportMessageSent(numMessages); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeFrame"); + } + } + + @Override + public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { + PerfMark.startTask("OkHttpServerStream$Sink.writeTrailers"); + try { + List
    responseTrailers = Headers.createResponseTrailers(trailers, headersSent); + synchronized (state.lock) { + state.sendTrailers(responseTrailers); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeTrailers"); + } + } + + @Override + public void cancel(Status reason) { + PerfMark.startTask("OkHttpServerStream$Sink.cancel"); + try { + synchronized (state.lock) { + state.cancel(ErrorCode.CANCEL, reason); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.cancel"); + } + } + } + + static class TransportState extends AbstractServerStream.TransportState + implements OutboundFlowController.Stream, OkHttpServerTransport.StreamState { + @GuardedBy("lock") + private final OkHttpServerTransport transport; + private final int streamId; + private final int initialWindowSize; + private final Object lock; + @GuardedBy("lock") + private boolean cancelSent = false; + @GuardedBy("lock") + private int window; + @GuardedBy("lock") + private int processedWindow; + @GuardedBy("lock") + private final ExceptionHandlingFrameWriter frameWriter; + @GuardedBy("lock") + private final OutboundFlowController outboundFlow; + @GuardedBy("lock") + private boolean receivedEndOfStream; + private final Tag tag; + private final OutboundFlowController.StreamState outboundFlowState; + + public TransportState( + OkHttpServerTransport transport, + int streamId, + int maxMessageSize, + StatsTraceContext statsTraceCtx, + Object lock, + ExceptionHandlingFrameWriter frameWriter, + OutboundFlowController outboundFlow, + int initialWindowSize, + TransportTracer transportTracer, + String methodName) { + super(maxMessageSize, statsTraceCtx, transportTracer); + this.transport = Preconditions.checkNotNull(transport, "transport"); + this.streamId = streamId; + this.lock = Preconditions.checkNotNull(lock, "lock"); + this.frameWriter = frameWriter; + this.outboundFlow = outboundFlow; + this.window = initialWindowSize; + this.processedWindow = initialWindowSize; + this.initialWindowSize = initialWindowSize; + tag = PerfMark.createTag(methodName); + outboundFlowState = outboundFlow.createState(this, streamId); + } + + @Override + @GuardedBy("lock") + public void deframeFailed(Throwable cause) { + cancel(ErrorCode.INTERNAL_ERROR, Status.fromThrowable(cause)); + } + + @Override + @GuardedBy("lock") + public void bytesRead(int processedBytes) { + processedWindow -= processedBytes; + if (processedWindow <= initialWindowSize * Utils.DEFAULT_WINDOW_UPDATE_RATIO) { + int delta = initialWindowSize - processedWindow; + window += delta; + processedWindow += delta; + frameWriter.windowUpdate(streamId, delta); + frameWriter.flush(); + } + } + + @Override + @GuardedBy("lock") + public void runOnTransportThread(final Runnable r) { + synchronized (lock) { + r.run(); + } + } + + /** + * Must be called with holding the transport lock. + */ + @Override + public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) { + synchronized (lock) { + PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag); + if (endOfStream) { + this.receivedEndOfStream = true; + } + window -= windowConsumed; + super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream); + } + } + + /** Must be called with holding the transport lock. */ + @Override + public void inboundRstReceived(Status status) { + PerfMark.event("OkHttpServerTransport$FrameHandler.rstStream", tag); + transportReportStatus(status); + } + + /** Must be called with holding the transport lock. */ + @Override + public boolean hasReceivedEndOfStream() { + synchronized (lock) { + return receivedEndOfStream; + } + } + + /** Must be called with holding the transport lock. */ + @Override + public int inboundWindowAvailable() { + synchronized (lock) { + return window; + } + } + + @GuardedBy("lock") + private void sendBuffer(Buffer buffer, boolean flush) { + if (cancelSent) { + return; + } + // If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is + // properly chunked. + outboundFlow.data(false, outboundFlowState, buffer, flush); + } + + @GuardedBy("lock") + private void sendHeaders(List
    responseHeaders) { + frameWriter.synReply(false, streamId, responseHeaders); + frameWriter.flush(); + } + + @GuardedBy("lock") + private void sendTrailers(List
    responseTrailers) { + outboundFlow.notifyWhenNoPendingData( + outboundFlowState, () -> sendTrailersAfterFlowControlled(responseTrailers)); + } + + private void sendTrailersAfterFlowControlled(List
    responseTrailers) { + synchronized (lock) { + frameWriter.synReply(true, streamId, responseTrailers); + if (!receivedEndOfStream) { + frameWriter.rstStream(streamId, ErrorCode.NO_ERROR); + } + transport.streamClosed(streamId, /*flush=*/ true); + complete(); + } + } + + @GuardedBy("lock") + private void cancel(ErrorCode http2Error, Status reason) { + if (cancelSent) { + return; + } + cancelSent = true; + frameWriter.rstStream(streamId, http2Error); + transportReportStatus(reason); + transport.streamClosed(streamId, /*flush=*/ true); + } + + @Override + public OutboundFlowController.StreamState getOutboundFlowState() { + return outboundFlowState; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java new file mode 100644 index 00000000000..1fd98079ede --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -0,0 +1,1179 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; +import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.Attributes; +import io.grpc.InternalChannelz; +import io.grpc.InternalLogId; +import io.grpc.InternalStatus; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; +import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.LogExceptionRunnable; +import io.grpc.internal.MaxConnectionIdleManager; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SerializingExecutor; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TransportTracer; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameReader; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.HeadersMode; +import io.grpc.okhttp.internal.framed.Http2; +import io.grpc.okhttp.internal.framed.Settings; +import io.grpc.okhttp.internal.framed.Variant; +import java.io.IOException; +import java.net.Socket; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; +import okio.BufferedSource; +import okio.ByteString; +import okio.Okio; + +/** + * OkHttp-based server transport. + */ +final class OkHttpServerTransport implements ServerTransport, + ExceptionHandlingFrameWriter.TransportExceptionHandler, OutboundFlowController.Transport { + private static final Logger log = Logger.getLogger(OkHttpServerTransport.class.getName()); + private static final int GRACEFUL_SHUTDOWN_PING = 0x1111; + private static final int KEEPALIVE_PING = 0xDEAD; + private static final ByteString HTTP_METHOD = ByteString.encodeUtf8(":method"); + private static final ByteString CONNECT_METHOD = ByteString.encodeUtf8("CONNECT"); + private static final ByteString POST_METHOD = ByteString.encodeUtf8("POST"); + private static final ByteString SCHEME = ByteString.encodeUtf8(":scheme"); + private static final ByteString PATH = ByteString.encodeUtf8(":path"); + private static final ByteString AUTHORITY = ByteString.encodeUtf8(":authority"); + private static final ByteString CONNECTION = ByteString.encodeUtf8("connection"); + private static final ByteString HOST = ByteString.encodeUtf8("host"); + private static final ByteString TE = ByteString.encodeUtf8("te"); + private static final ByteString TE_TRAILERS = ByteString.encodeUtf8("trailers"); + private static final ByteString CONTENT_TYPE = ByteString.encodeUtf8("content-type"); + private static final ByteString CONTENT_LENGTH = ByteString.encodeUtf8("content-length"); + + private final Config config; + private final Socket bareSocket; + private final Variant variant = new Http2(); + private final TransportTracer tracer; + private final InternalLogId logId; + private ServerTransportListener listener; + private Executor transportExecutor; + private ScheduledExecutorService scheduledExecutorService; + private Attributes attributes; + private KeepAliveManager keepAliveManager; + private MaxConnectionIdleManager maxConnectionIdleManager; + private ScheduledFuture maxConnectionAgeMonitor; + private final KeepAliveEnforcer keepAliveEnforcer; + + private final Object lock = new Object(); + @GuardedBy("lock") + private boolean abruptShutdown; + @GuardedBy("lock") + private boolean gracefulShutdown; + @GuardedBy("lock") + private boolean handshakeShutdown; + @GuardedBy("lock") + private InternalChannelz.Security securityInfo; + @GuardedBy("lock") + private ExceptionHandlingFrameWriter frameWriter; + @GuardedBy("lock") + private OutboundFlowController outboundFlow; + @GuardedBy("lock") + private final Map streams = new TreeMap<>(); + @GuardedBy("lock") + private int lastStreamId; + @GuardedBy("lock") + private int goAwayStreamId = Integer.MAX_VALUE; + /** + * Indicates the transport is in go-away state: no new streams will be processed, but existing + * streams may continue. + */ + @GuardedBy("lock") + private Status goAwayStatus; + /** Non-{@code null} when gracefully shutting down and have not yet sent second GOAWAY. */ + @GuardedBy("lock") + private ScheduledFuture secondGoawayTimer; + /** Non-{@code null} when waiting for forceful close GOAWAY to be sent. */ + @GuardedBy("lock") + private ScheduledFuture forcefulCloseTimer; + + public OkHttpServerTransport(Config config, Socket bareSocket) { + this.config = Preconditions.checkNotNull(config, "config"); + this.bareSocket = Preconditions.checkNotNull(bareSocket, "bareSocket"); + + tracer = config.transportTracerFactory.create(); + tracer.setFlowControlWindowReader(this::readFlowControlWindow); + logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString()); + transportExecutor = config.transportExecutorPool.getObject(); + scheduledExecutorService = config.scheduledExecutorServicePool.getObject(); + keepAliveEnforcer = new KeepAliveEnforcer(config.permitKeepAliveWithoutCalls, + config.permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS); + } + + public void start(ServerTransportListener listener) { + this.listener = Preconditions.checkNotNull(listener, "listener"); + + SerializingExecutor serializingExecutor = new SerializingExecutor(transportExecutor); + serializingExecutor.execute(() -> startIo(serializingExecutor)); + } + + private void startIo(SerializingExecutor serializingExecutor) { + try { + bareSocket.setTcpNoDelay(true); + HandshakerSocketFactory.HandshakeResult result = + config.handshakerSocketFactory.handshake(bareSocket, Attributes.EMPTY); + Socket socket = result.socket; + this.attributes = result.attributes; + + int maxQueuedControlFrames = 10000; + AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this, maxQueuedControlFrames); + asyncSink.becomeConnected(Okio.sink(socket), socket); + FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter( + variant.newWriter(Okio.buffer(asyncSink), false)); + FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) { + @Override + public void synReply(boolean outFinished, int streamId, List
    headerBlock) + throws IOException { + keepAliveEnforcer.resetCounters(); + super.synReply(outFinished, streamId, headerBlock); + } + + @Override + public void headers(int streamId, List
    headerBlock) throws IOException { + keepAliveEnforcer.resetCounters(); + super.headers(streamId, headerBlock); + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) + throws IOException { + keepAliveEnforcer.resetCounters(); + super.data(outFinished, streamId, source, byteCount); + } + }; + synchronized (lock) { + this.securityInfo = result.securityInfo; + + // Handle FrameWriter exceptions centrally, since there are many callers. Note that + // errors coming from rawFrameWriter are generally broken invariants/bugs, as AsyncSink + // does not propagate syscall errors through the FrameWriter. But we handle the + // AsyncSink failures with the same TransportExceptionHandler instance so it is all + // mixed back together. + frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter); + outboundFlow = new OutboundFlowController(this, frameWriter); + + // These writes will be queued in the serializingExecutor waiting for this function to + // return. + frameWriter.connectionPreface(); + Settings settings = new Settings(); + OkHttpSettingsUtil.set(settings, + OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, config.flowControlWindow); + OkHttpSettingsUtil.set(settings, + OkHttpSettingsUtil.MAX_HEADER_LIST_SIZE, config.maxInboundMetadataSize); + frameWriter.settings(settings); + if (config.flowControlWindow > Utils.DEFAULT_WINDOW_SIZE) { + frameWriter.windowUpdate( + Utils.CONNECTION_STREAM_ID, config.flowControlWindow - Utils.DEFAULT_WINDOW_SIZE); + } + frameWriter.flush(); + } + + if (config.keepAliveTimeNanos != GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED) { + keepAliveManager = new KeepAliveManager( + new KeepAlivePinger(), scheduledExecutorService, config.keepAliveTimeNanos, + config.keepAliveTimeoutNanos, true); + keepAliveManager.onTransportStarted(); + } + + if (config.maxConnectionIdleNanos != MAX_CONNECTION_IDLE_NANOS_DISABLED) { + maxConnectionIdleManager = new MaxConnectionIdleManager(config.maxConnectionIdleNanos); + maxConnectionIdleManager.start(this::shutdown, scheduledExecutorService); + } + + if (config.maxConnectionAgeInNanos != MAX_CONNECTION_AGE_NANOS_DISABLED) { + long maxConnectionAgeInNanos = + (long) ((.9D + Math.random() * .2D) * config.maxConnectionAgeInNanos); + maxConnectionAgeMonitor = scheduledExecutorService.schedule( + new LogExceptionRunnable(() -> shutdown(config.maxConnectionAgeGraceInNanos)), + maxConnectionAgeInNanos, + TimeUnit.NANOSECONDS); + } + + transportExecutor.execute( + new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false))); + } catch (Error | IOException | RuntimeException ex) { + synchronized (lock) { + if (!handshakeShutdown) { + log.log(Level.INFO, "Socket failed to handshake", ex); + } + } + GrpcUtil.closeQuietly(bareSocket); + terminated(); + } + } + + @Override + public void shutdown() { + shutdown(TimeUnit.SECONDS.toNanos(1L)); + } + + private void shutdown(Long graceTimeInNanos) { + synchronized (lock) { + if (gracefulShutdown || abruptShutdown) { + return; + } + gracefulShutdown = true; + if (frameWriter == null) { + handshakeShutdown = true; + GrpcUtil.closeQuietly(bareSocket); + } else { + // RFC7540 §6.8. Begin double-GOAWAY graceful shutdown. To wait one RTT we use a PING, but + // we also set a timer to limit the upper bound in case the PING is excessively stalled or + // the client is malicious. + secondGoawayTimer = scheduledExecutorService.schedule( + this::triggerGracefulSecondGoaway, graceTimeInNanos, TimeUnit.NANOSECONDS); + frameWriter.goAway(Integer.MAX_VALUE, ErrorCode.NO_ERROR, new byte[0]); + frameWriter.ping(false, 0, GRACEFUL_SHUTDOWN_PING); + frameWriter.flush(); + } + } + } + + private void triggerGracefulSecondGoaway() { + synchronized (lock) { + if (secondGoawayTimer == null) { + return; + } + secondGoawayTimer.cancel(false); + secondGoawayTimer = null; + frameWriter.goAway(lastStreamId, ErrorCode.NO_ERROR, new byte[0]); + goAwayStreamId = lastStreamId; + if (streams.isEmpty()) { + frameWriter.close(); + } else { + frameWriter.flush(); + } + } + } + + @Override + public void shutdownNow(Status reason) { + synchronized (lock) { + if (frameWriter == null) { + handshakeShutdown = true; + GrpcUtil.closeQuietly(bareSocket); + return; + } + } + abruptShutdown(ErrorCode.NO_ERROR, "", reason, true); + } + + /** + * Finish all active streams due to an IOException, then close the transport. + */ + @Override + public void onException(Throwable failureCause) { + Preconditions.checkNotNull(failureCause, "failureCause"); + Status status = Status.UNAVAILABLE.withCause(failureCause); + abruptShutdown(ErrorCode.INTERNAL_ERROR, "I/O failure", status, false); + } + + private void abruptShutdown( + ErrorCode errorCode, String moreDetail, Status reason, boolean rstStreams) { + synchronized (lock) { + if (abruptShutdown) { + return; + } + abruptShutdown = true; + goAwayStatus = reason; + + if (secondGoawayTimer != null) { + secondGoawayTimer.cancel(false); + secondGoawayTimer = null; + } + for (Map.Entry entry : streams.entrySet()) { + if (rstStreams) { + frameWriter.rstStream(entry.getKey(), ErrorCode.CANCEL); + } + entry.getValue().transportReportStatus(reason); + } + streams.clear(); + + // RFC7540 §5.4.1. Attempt to inform the client what went wrong. We try to write the GOAWAY + // _and then_ close our side of the connection. But place an upper-bound for how long we wait + // for I/O with a timer, which forcefully closes the socket. + frameWriter.goAway(lastStreamId, errorCode, moreDetail.getBytes(GrpcUtil.US_ASCII)); + goAwayStreamId = lastStreamId; + frameWriter.close(); + forcefulCloseTimer = scheduledExecutorService.schedule( + this::triggerForcefulClose, 1, TimeUnit.SECONDS); + } + } + + private void triggerForcefulClose() { + // Safe to do unconditionally; no need to check if timer cancellation raced + GrpcUtil.closeQuietly(bareSocket); + } + + private void terminated() { + synchronized (lock) { + if (forcefulCloseTimer != null) { + forcefulCloseTimer.cancel(false); + forcefulCloseTimer = null; + } + } + if (keepAliveManager != null) { + keepAliveManager.onTransportTermination(); + } + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportTermination(); + } + + if (maxConnectionAgeMonitor != null) { + maxConnectionAgeMonitor.cancel(false); + } + transportExecutor = config.transportExecutorPool.returnObject(transportExecutor); + scheduledExecutorService = + config.scheduledExecutorServicePool.returnObject(scheduledExecutorService); + listener.transportTerminated(); + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return scheduledExecutorService; + } + + @Override + public ListenableFuture getStats() { + synchronized (lock) { + return Futures.immediateFuture(new InternalChannelz.SocketStats( + tracer.getStats(), + bareSocket.getLocalSocketAddress(), + bareSocket.getRemoteSocketAddress(), + Utils.getSocketOptions(bareSocket), + securityInfo)); + } + } + + private TransportTracer.FlowControlWindows readFlowControlWindow() { + synchronized (lock) { + long local = outboundFlow == null ? -1 : outboundFlow.windowUpdate(null, 0); + // connectionUnacknowledgedBytesRead is only readable by FrameHandler, so we provide a lower + // bound. + long remote = (long) (config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO); + return new TransportTracer.FlowControlWindows(local, remote); + } + } + + @Override + public InternalLogId getLogId() { + return logId; + } + + @Override + public OutboundFlowController.StreamState[] getActiveStreams() { + synchronized (lock) { + OutboundFlowController.StreamState[] flowStreams = + new OutboundFlowController.StreamState[streams.size()]; + int i = 0; + for (StreamState stream : streams.values()) { + flowStreams[i++] = stream.getOutboundFlowState(); + } + return flowStreams; + } + } + + /** + * Notify the transport that the stream was closed. Any frames for the stream must be enqueued + * before calling. + */ + void streamClosed(int streamId, boolean flush) { + synchronized (lock) { + streams.remove(streamId); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportIdle(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportIdle(); + } + } + if (gracefulShutdown && streams.isEmpty()) { + frameWriter.close(); + } else { + if (flush) { + frameWriter.flush(); + } + } + } + } + + private static String asciiString(ByteString value) { + // utf8() string is cached in ByteString, so we prefer it when the contents are ASCII. This + // provides benefit if the header was reused via HPACK. + for (int i = 0; i < value.size(); i++) { + if (value.getByte(i) >= 0x80) { + return value.string(GrpcUtil.US_ASCII); + } + } + return value.utf8(); + } + + private static int headerFind(List
    header, ByteString key, int startIndex) { + for (int i = startIndex; i < header.size(); i++) { + if (header.get(i).name.equals(key)) { + return i; + } + } + return -1; + } + + private static boolean headerContains(List
    header, ByteString key) { + return headerFind(header, key, 0) != -1; + } + + private static void headerRemove(List
    header, ByteString key) { + int i = 0; + while ((i = headerFind(header, key, i)) != -1) { + header.remove(i); + } + } + + /** Assumes that caller requires this field, so duplicates are treated as missing. */ + private static ByteString headerGetRequiredSingle(List
    header, ByteString key) { + int i = headerFind(header, key, 0); + if (i == -1) { + return null; + } + if (headerFind(header, key, i + 1) != -1) { + return null; + } + return header.get(i).value; + } + + static final class Config { + final List streamTracerFactories; + final ObjectPool transportExecutorPool; + final ObjectPool scheduledExecutorServicePool; + final TransportTracer.Factory transportTracerFactory; + final HandshakerSocketFactory handshakerSocketFactory; + final long keepAliveTimeNanos; + final long keepAliveTimeoutNanos; + final int flowControlWindow; + final int maxInboundMessageSize; + final int maxInboundMetadataSize; + final long maxConnectionIdleNanos; + final boolean permitKeepAliveWithoutCalls; + final long permitKeepAliveTimeInNanos; + final long maxConnectionAgeInNanos; + final long maxConnectionAgeGraceInNanos; + + public Config( + OkHttpServerBuilder builder, + List streamTracerFactories) { + this.streamTracerFactories = Preconditions.checkNotNull( + streamTracerFactories, "streamTracerFactories"); + transportExecutorPool = Preconditions.checkNotNull( + builder.transportExecutorPool, "transportExecutorPool"); + scheduledExecutorServicePool = Preconditions.checkNotNull( + builder.scheduledExecutorServicePool, "scheduledExecutorServicePool"); + transportTracerFactory = Preconditions.checkNotNull( + builder.transportTracerFactory, "transportTracerFactory"); + handshakerSocketFactory = Preconditions.checkNotNull( + builder.handshakerSocketFactory, "handshakerSocketFactory"); + keepAliveTimeNanos = builder.keepAliveTimeNanos; + keepAliveTimeoutNanos = builder.keepAliveTimeoutNanos; + flowControlWindow = builder.flowControlWindow; + maxInboundMessageSize = builder.maxInboundMessageSize; + maxInboundMetadataSize = builder.maxInboundMetadataSize; + maxConnectionIdleNanos = builder.maxConnectionIdleInNanos; + permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls; + permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos; + maxConnectionAgeInNanos = builder.maxConnectionAgeInNanos; + maxConnectionAgeGraceInNanos = builder.maxConnectionAgeGraceInNanos; + } + } + + /** + * Runnable which reads frames and dispatches them to in flight calls. + */ + class FrameHandler implements FrameReader.Handler, Runnable { + private final OkHttpFrameLogger frameLogger = + new OkHttpFrameLogger(Level.FINE, OkHttpServerTransport.class); + private final FrameReader frameReader; + private boolean receivedSettings; + private int connectionUnacknowledgedBytesRead; + + public FrameHandler(FrameReader frameReader) { + this.frameReader = frameReader; + } + + @Override + public void run() { + String threadName = Thread.currentThread().getName(); + Thread.currentThread().setName("OkHttpServerTransport"); + try { + frameReader.readConnectionPreface(); + if (!frameReader.nextFrame(this)) { + connectionError(ErrorCode.INTERNAL_ERROR, "Failed to read initial SETTINGS"); + return; + } + if (!receivedSettings) { + connectionError(ErrorCode.PROTOCOL_ERROR, + "First HTTP/2 frame must be SETTINGS. RFC7540 section 3.5"); + return; + } + // Read until the underlying socket closes. + while (frameReader.nextFrame(this)) { + if (keepAliveManager != null) { + keepAliveManager.onDataReceived(); + } + } + // frameReader.nextFrame() returns false when the underlying read encounters an IOException, + // it may be triggered by the socket closing, in such case, the startGoAway() will do + // nothing, otherwise, we finish all streams since it's a real IO issue. + Status status; + synchronized (lock) { + status = goAwayStatus; + } + if (status == null) { + status = Status.UNAVAILABLE.withDescription("TCP connection closed or IOException"); + } + abruptShutdown(ErrorCode.INTERNAL_ERROR, "I/O failure", status, false); + } catch (Throwable t) { + log.log(Level.WARNING, "Error decoding HTTP/2 frames", t); + abruptShutdown(ErrorCode.INTERNAL_ERROR, "Error in frame decoder", + Status.INTERNAL.withDescription("Error decoding HTTP/2 frames").withCause(t), false); + } finally { + // Wait for the abrupt shutdown to be processed by AsyncSink and close the socket + try { + GrpcUtil.exhaust(bareSocket.getInputStream()); + } catch (IOException ex) { + // Unable to wait, so just proceed to tear-down. The socket is probably already closed so + // the GOAWAY can't be sent anyway. + } + GrpcUtil.closeQuietly(bareSocket); + terminated(); + Thread.currentThread().setName(threadName); + } + } + + /** + * Handle HTTP2 HEADER and CONTINUATION frames. + */ + @Override + public void headers(boolean outFinished, + boolean inFinished, + int streamId, + int associatedStreamId, + List
    headerBlock, + HeadersMode headersMode) { + frameLogger.logHeaders( + OkHttpFrameLogger.Direction.INBOUND, streamId, headerBlock, inFinished); + // streamId == 0 checking is in HTTP/2 decoder + if ((streamId & 1) == 0) { + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE + connectionError(ErrorCode.PROTOCOL_ERROR, + "Clients cannot open even numbered streams. RFC7540 section 5.1.1"); + return; + } + boolean newStream; + synchronized (lock) { + if (streamId > goAwayStreamId) { + return; + } + newStream = streamId > lastStreamId; + if (newStream) { + lastStreamId = streamId; + } + } + + int metadataSize = headerBlockSize(headerBlock); + if (metadataSize > config.maxInboundMetadataSize) { + respondWithHttpError(streamId, inFinished, 431, Status.Code.RESOURCE_EXHAUSTED, + String.format( + Locale.US, + "Request metadata larger than %d: %d", + config.maxInboundMetadataSize, + metadataSize)); + return; + } + + headerRemove(headerBlock, ByteString.EMPTY); + + ByteString httpMethod = null; + ByteString scheme = null; + ByteString path = null; + ByteString authority = null; + while (headerBlock.size() > 0 && headerBlock.get(0).name.getByte(0) == ':') { + Header header = headerBlock.remove(0); + if (HTTP_METHOD.equals(header.name) && httpMethod == null) { + httpMethod = header.value; + } else if (SCHEME.equals(header.name) && scheme == null) { + scheme = header.value; + } else if (PATH.equals(header.name) && path == null) { + path = header.value; + } else if (AUTHORITY.equals(header.name) && authority == null) { + authority = header.value; + } else { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Unexpected pseudo header. RFC7540 section 8.1.2.1"); + return; + } + } + for (int i = 0; i < headerBlock.size(); i++) { + if (headerBlock.get(i).name.getByte(0) == ':') { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Pseudo header not before regular headers. RFC7540 section 8.1.2.1"); + return; + } + } + if (!CONNECT_METHOD.equals(httpMethod) + && newStream + && (httpMethod == null || scheme == null || path == null)) { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Missing required pseudo header. RFC7540 section 8.1.2.3"); + return; + } + if (headerContains(headerBlock, CONNECTION)) { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Connection-specific headers not permitted. RFC7540 section 8.1.2.2"); + return; + } + + if (!newStream) { + if (inFinished) { + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream == null) { + streamError(streamId, ErrorCode.STREAM_CLOSED, "Received headers for closed stream"); + return; + } + if (stream.hasReceivedEndOfStream()) { + streamError(streamId, ErrorCode.STREAM_CLOSED, + "Received HEADERS for half-closed (remote) stream. RFC7540 section 5.1"); + return; + } + // Ignore the trailers, but still half-close the stream + stream.inboundDataReceived(new Buffer(), 0, true); + return; + } + } else { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Headers disallowed in the middle of the stream. RFC7540 section 8.1"); + return; + } + } + + if (authority == null) { + int i = headerFind(headerBlock, HOST, 0); + if (i != -1) { + if (headerFind(headerBlock, HOST, i + 1) != -1) { + respondWithHttpError(streamId, inFinished, 400, Status.Code.INTERNAL, + "Multiple host headers disallowed. RFC7230 section 5.4"); + return; + } + authority = headerBlock.get(i).value; + } + } + headerRemove(headerBlock, HOST); + + // Remove the leading slash of the path and get the fully qualified method name + if (path.size() == 0 || path.getByte(0) != '/') { + respondWithHttpError(streamId, inFinished, 404, Status.Code.UNIMPLEMENTED, + "Expected path to start with /: " + asciiString(path)); + return; + } + String method = asciiString(path).substring(1); + + ByteString contentType = headerGetRequiredSingle(headerBlock, CONTENT_TYPE); + if (contentType == null) { + respondWithHttpError(streamId, inFinished, 415, Status.Code.INTERNAL, + "Content-Type is missing or duplicated"); + return; + } + String contentTypeString = asciiString(contentType); + if (!GrpcUtil.isGrpcContentType(contentTypeString)) { + respondWithHttpError(streamId, inFinished, 415, Status.Code.INTERNAL, + "Content-Type is not supported: " + contentTypeString); + return; + } + + if (!POST_METHOD.equals(httpMethod)) { + respondWithHttpError(streamId, inFinished, 405, Status.Code.INTERNAL, + "HTTP Method is not supported: " + asciiString(httpMethod)); + return; + } + + ByteString te = headerGetRequiredSingle(headerBlock, TE); + if (!TE_TRAILERS.equals(te)) { + respondWithGrpcError(streamId, inFinished, Status.Code.INTERNAL, + String.format("Expected header TE: %s, but %s is received. " + + "Some intermediate proxy may not support trailers", + asciiString(TE_TRAILERS), te == null ? "" : asciiString(te))); + return; + } + headerRemove(headerBlock, CONTENT_LENGTH); + + Metadata metadata = Utils.convertHeaders(headerBlock); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext(config.streamTracerFactories, method, metadata); + synchronized (lock) { + OkHttpServerStream.TransportState stream = new OkHttpServerStream.TransportState( + OkHttpServerTransport.this, + streamId, + config.maxInboundMessageSize, + statsTraceCtx, + lock, + frameWriter, + outboundFlow, + config.flowControlWindow, + tracer, + method); + OkHttpServerStream streamForApp = new OkHttpServerStream( + stream, + attributes, + authority == null ? null : asciiString(authority), + statsTraceCtx, + tracer); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportActive(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportActive(); + } + } + streams.put(streamId, stream); + listener.streamCreated(streamForApp, method, metadata); + stream.onStreamAllocated(); + if (inFinished) { + stream.inboundDataReceived(new Buffer(), 0, inFinished); + } + } + } + + private int headerBlockSize(List
    headerBlock) { + // Calculate as defined for SETTINGS_MAX_HEADER_LIST_SIZE in RFC 7540 §6.5.2. + long size = 0; + for (int i = 0; i < headerBlock.size(); i++) { + Header header = headerBlock.get(i); + size += 32 + header.name.size() + header.value.size(); + } + size = Math.min(size, Integer.MAX_VALUE); + return (int) size; + } + + /** + * Handle an HTTP2 DATA frame. + */ + @Override + public void data(boolean inFinished, int streamId, BufferedSource in, int length) + throws IOException { + frameLogger.logData( + OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished); + if (streamId == 0) { + connectionError(ErrorCode.PROTOCOL_ERROR, + "Stream 0 is reserved for control messages. RFC7540 section 5.1.1"); + return; + } + if ((streamId & 1) == 0) { + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE + connectionError(ErrorCode.PROTOCOL_ERROR, + "Clients cannot open even numbered streams. RFC7540 section 5.1.1"); + return; + } + + // Wait until the frame is complete. We only support 16 KiB frames, and the max permitted in + // HTTP/2 is 16 MiB. This is verified in OkHttp's Http2 deframer, so we don't need to be + // concerned with the window being exceeded at this point. + in.require(length); + + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream == null) { + in.skip(length); + streamError(streamId, ErrorCode.STREAM_CLOSED, "Received data for closed stream"); + return; + } + if (stream.hasReceivedEndOfStream()) { + in.skip(length); + streamError(streamId, ErrorCode.STREAM_CLOSED, + "Received DATA for half-closed (remote) stream. RFC7540 section 5.1"); + return; + } + if (stream.inboundWindowAvailable() < length) { + in.skip(length); + streamError(streamId, ErrorCode.FLOW_CONTROL_ERROR, + "Received DATA size exceeded window size. RFC7540 section 6.9"); + return; + } + Buffer buf = new Buffer(); + buf.write(in.getBuffer(), length); + stream.inboundDataReceived(buf, length, inFinished); + } + + // connection window update + connectionUnacknowledgedBytesRead += length; + if (connectionUnacknowledgedBytesRead + >= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) { + synchronized (lock) { + frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + frameWriter.flush(); + } + connectionUnacknowledgedBytesRead = 0; + } + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) { + frameLogger.logRstStream(OkHttpFrameLogger.Direction.INBOUND, streamId, errorCode); + // streamId == 0 checking is in HTTP/2 decoder + + if (!(ErrorCode.NO_ERROR.equals(errorCode) + || ErrorCode.CANCEL.equals(errorCode) + || ErrorCode.STREAM_CLOSED.equals(errorCode))) { + log.log(Level.INFO, "Received RST_STREAM: " + errorCode); + } + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription("RST_STREAM"); + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream != null) { + stream.inboundRstReceived(status); + streamClosed(streamId, /*flush=*/ false); + } + } + } + + @Override + public void settings(boolean clearPrevious, Settings settings) { + frameLogger.logSettings(OkHttpFrameLogger.Direction.INBOUND, settings); + synchronized (lock) { + boolean outboundWindowSizeIncreased = false; + if (OkHttpSettingsUtil.isSet(settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE)) { + int initialWindowSize = OkHttpSettingsUtil.get( + settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE); + outboundWindowSizeIncreased = outboundFlow.initialOutboundWindowSize(initialWindowSize); + } + + // The changed settings are not finalized until SETTINGS acknowledgment frame is sent. Any + // writes due to update in settings must be sent after SETTINGS acknowledgment frame, + // otherwise it will cause a stream error (RST_STREAM). + frameWriter.ackSettings(settings); + frameWriter.flush(); + if (!receivedSettings) { + receivedSettings = true; + attributes = listener.transportReady(attributes); + } + + // send any pending bytes / streams + if (outboundWindowSizeIncreased) { + outboundFlow.writeStreams(); + } + } + } + + @Override + public void ping(boolean ack, int payload1, int payload2) { + if (!keepAliveEnforcer.pingAcceptable()) { + abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", + Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); + return; + } + long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); + if (!ack) { + frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload); + synchronized (lock) { + frameWriter.ping(true, payload1, payload2); + frameWriter.flush(); + } + } else { + frameLogger.logPingAck(OkHttpFrameLogger.Direction.INBOUND, payload); + if (KEEPALIVE_PING == payload) { + return; + } + if (GRACEFUL_SHUTDOWN_PING == payload) { + triggerGracefulSecondGoaway(); + return; + } + log.log(Level.INFO, "Received unexpected ping ack: " + payload); + } + } + + @Override + public void ackSettings() {} + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) { + frameLogger.logGoAway( + OkHttpFrameLogger.Direction.INBOUND, lastGoodStreamId, errorCode, debugData); + String description = String.format("Received GOAWAY: %s '%s'", errorCode, debugData.utf8()); + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription(description); + if (!ErrorCode.NO_ERROR.equals(errorCode)) { + log.log( + Level.WARNING, "Received GOAWAY: {0} {1}", new Object[] {errorCode, debugData.utf8()}); + } + synchronized (lock) { + goAwayStatus = status; + } + } + + @Override + public void pushPromise(int streamId, int promisedStreamId, List
    requestHeaders) + throws IOException { + frameLogger.logPushPromise(OkHttpFrameLogger.Direction.INBOUND, + streamId, promisedStreamId, requestHeaders); + // streamId == 0 checking is in HTTP/2 decoder. + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE, and odd streams are not + // peer-initiated. + connectionError(ErrorCode.PROTOCOL_ERROR, + "PUSH_PROMISE only allowed on peer-initiated streams. RFC7540 section 6.6"); + } + + @Override + public void windowUpdate(int streamId, long delta) { + frameLogger.logWindowsUpdate(OkHttpFrameLogger.Direction.INBOUND, streamId, delta); + // delta == 0 checking is in HTTP/2 decoder. And it isn't quite right, as it will always cause + // a GOAWAY. RFC7540 section 6.9 says to use RST_STREAM if the stream id isn't 0. Doesn't + // matter much though. + synchronized (lock) { + if (streamId == Utils.CONNECTION_STREAM_ID) { + outboundFlow.windowUpdate(null, (int) delta); + } else { + StreamState stream = streams.get(streamId); + if (stream != null) { + outboundFlow.windowUpdate(stream.getOutboundFlowState(), (int) delta); + } + } + } + } + + @Override + public void priority(int streamId, int streamDependency, int weight, boolean exclusive) { + frameLogger.logPriority( + OkHttpFrameLogger.Direction.INBOUND, streamId, streamDependency, weight, exclusive); + // streamId == 0 checking is in HTTP/2 decoder. + // Ignore priority change. + } + + @Override + public void alternateService(int streamId, String origin, ByteString protocol, String host, + int port, long maxAge) {} + + /** + * Send GOAWAY to the server, then finish all streams and close the transport. RFC7540 §5.4.1. + */ + private void connectionError(ErrorCode errorCode, String moreDetail) { + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription(String.format("HTTP2 connection error: %s '%s'", errorCode, moreDetail)); + abruptShutdown(errorCode, moreDetail, status, false); + } + + /** + * Respond with RST_STREAM, making sure to kill the associated stream if it exists. Reason will + * rarely be seen. RFC7540 §5.4.2. + */ + private void streamError(int streamId, ErrorCode errorCode, String reason) { + if (errorCode == ErrorCode.PROTOCOL_ERROR) { + log.log( + Level.FINE, "Responding with RST_STREAM {0}: {1}", new Object[] {errorCode, reason}); + } + synchronized (lock) { + frameWriter.rstStream(streamId, errorCode); + frameWriter.flush(); + StreamState stream = streams.get(streamId); + if (stream != null) { + stream.transportReportStatus( + Status.INTERNAL.withDescription( + String.format("Responded with RST_STREAM %s: %s", errorCode, reason))); + streamClosed(streamId, /*flush=*/ false); + } + } + } + + private void respondWithHttpError( + int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg) { + Metadata metadata = new Metadata(); + metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); + metadata.put(InternalStatus.MESSAGE_KEY, msg); + List
    headers = + Headers.createHttpResponseHeaders(httpCode, "text/plain; charset=utf-8", metadata); + Buffer data = new Buffer().writeUtf8(msg); + + synchronized (lock) { + Http2ErrorStreamState stream = + new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportActive(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportActive(); + } + } + streams.put(streamId, stream); + if (inFinished) { + stream.inboundDataReceived(new Buffer(), 0, true); + } + frameWriter.headers(streamId, headers); + outboundFlow.data( + /*outFinished=*/true, stream.getOutboundFlowState(), data, /*flush=*/true); + outboundFlow.notifyWhenNoPendingData( + stream.getOutboundFlowState(), () -> rstOkAtEndOfHttpError(stream)); + } + } + + private void rstOkAtEndOfHttpError(Http2ErrorStreamState stream) { + synchronized (lock) { + if (!stream.hasReceivedEndOfStream()) { + frameWriter.rstStream(stream.streamId, ErrorCode.NO_ERROR); + } + streamClosed(stream.streamId, /*flush=*/ true); + } + } + + private void respondWithGrpcError( + int streamId, boolean inFinished, Status.Code statusCode, String msg) { + Metadata metadata = new Metadata(); + metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); + metadata.put(InternalStatus.MESSAGE_KEY, msg); + List
    headers = Headers.createResponseTrailers(metadata, false); + + synchronized (lock) { + frameWriter.synReply(true, streamId, headers); + if (!inFinished) { + frameWriter.rstStream(streamId, ErrorCode.NO_ERROR); + } + frameWriter.flush(); + } + } + } + + private final class KeepAlivePinger implements KeepAliveManager.KeepAlivePinger { + @Override + public void ping() { + synchronized (lock) { + frameWriter.ping(false, 0, KEEPALIVE_PING); + frameWriter.flush(); + } + tracer.reportKeepAliveSent(); + } + + @Override + public void onPingTimeout() { + synchronized (lock) { + goAwayStatus = Status.UNAVAILABLE + .withDescription("Keepalive failed. Considering connection dead"); + GrpcUtil.closeQuietly(bareSocket); + } + } + } + + interface StreamState { + /** Must be holding 'lock' when calling. */ + void inboundDataReceived(Buffer frame, int windowConsumed, boolean endOfStream); + + /** Must be holding 'lock' when calling. */ + boolean hasReceivedEndOfStream(); + + /** Must be holding 'lock' when calling. */ + int inboundWindowAvailable(); + + /** Must be holding 'lock' when calling. */ + void transportReportStatus(Status status); + + /** Must be holding 'lock' when calling. */ + void inboundRstReceived(Status status); + + OutboundFlowController.StreamState getOutboundFlowState(); + } + + static class Http2ErrorStreamState implements StreamState, OutboundFlowController.Stream { + private final int streamId; + private final Object lock; + private final OutboundFlowController.StreamState outboundFlowState; + @GuardedBy("lock") + private int window; + @GuardedBy("lock") + private boolean receivedEndOfStream; + + Http2ErrorStreamState( + int streamId, Object lock, OutboundFlowController outboundFlow, int initialWindowSize) { + this.streamId = streamId; + this.lock = lock; + this.outboundFlowState = outboundFlow.createState(this, streamId); + this.window = initialWindowSize; + } + + @Override public void onSentBytes(int frameBytes) {} + + @Override public void inboundDataReceived( + Buffer frame, int windowConsumed, boolean endOfStream) { + synchronized (lock) { + if (endOfStream) { + receivedEndOfStream = true; + } + window -= windowConsumed; + try { + frame.skip(frame.size()); // Recycle segments + } catch (IOException ex) { + throw new AssertionError(ex); + } + } + } + + @Override public boolean hasReceivedEndOfStream() { + synchronized (lock) { + return receivedEndOfStream; + } + } + + @Override public int inboundWindowAvailable() { + synchronized (lock) { + return window; + } + } + + @Override public void transportReportStatus(Status status) {} + + @Override public void inboundRstReceived(Status status) {} + + @Override public OutboundFlowController.StreamState getOutboundFlowState() { + synchronized (lock) { + return outboundFlowState; + } + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java index 5df85732ede..1406b39adfd 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java @@ -24,6 +24,8 @@ class OkHttpSettingsUtil { public static final int MAX_CONCURRENT_STREAMS = Settings.MAX_CONCURRENT_STREAMS; public static final int INITIAL_WINDOW_SIZE = Settings.INITIAL_WINDOW_SIZE; + public static final int MAX_HEADER_LIST_SIZE = Settings.MAX_HEADER_LIST_SIZE; + public static final int ENABLE_PUSH = Settings.ENABLE_PUSH; public static boolean isSet(Settings settings, int id) { return settings.isSet(id); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java index c935363213d..2c959ee0768 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java @@ -25,6 +25,8 @@ import com.google.common.base.Preconditions; import io.grpc.okhttp.internal.framed.FrameWriter; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import javax.annotation.Nullable; import okio.Buffer; @@ -33,17 +35,16 @@ * streams. */ class OutboundFlowController { - private final OkHttpClientTransport transport; + private final Transport transport; private final FrameWriter frameWriter; private int initialWindowSize; - private final OutboundFlowState connectionState; + private final StreamState connectionState; - OutboundFlowController( - OkHttpClientTransport transport, FrameWriter frameWriter) { + public OutboundFlowController(Transport transport, FrameWriter frameWriter) { this.transport = Preconditions.checkNotNull(transport, "transport"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); this.initialWindowSize = DEFAULT_WINDOW_SIZE; - connectionState = new OutboundFlowState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE); + connectionState = new StreamState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE, null); } /** @@ -55,22 +56,15 @@ class OutboundFlowController { * * @return true, if new window size is increased, false otherwise. */ - boolean initialOutboundWindowSize(int newWindowSize) { + public boolean initialOutboundWindowSize(int newWindowSize) { if (newWindowSize < 0) { throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); } int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; - for (OkHttpClientStream stream : transport.getActiveStreams()) { - OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); - if (state == null) { - // Create the OutboundFlowState with the new window size. - state = new OutboundFlowState(stream, initialWindowSize); - stream.setOutboundFlowState(state); - } else { - state.incrementStreamWindow(delta); - } + for (StreamState state : transport.getActiveStreams()) { + state.incrementStreamWindow(delta); } return delta > 0; @@ -82,15 +76,14 @@ boolean initialOutboundWindowSize(int newWindowSize) { * *

    Must be called with holding transport lock. */ - int windowUpdate(@Nullable OkHttpClientStream stream, int delta) { + public int windowUpdate(@Nullable StreamState state, int delta) { final int updatedWindow; - if (stream == null) { + if (state == null) { // Update the connection window and write any pending frames for all streams. updatedWindow = connectionState.incrementStreamWindow(delta); writeStreams(); } else { // Update the stream window and write any pending frames for the stream. - OutboundFlowState state = state(stream); updatedWindow = state.incrementStreamWindow(delta); WriteStatus writeStatus = new WriteStatus(); @@ -105,18 +98,9 @@ int windowUpdate(@Nullable OkHttpClientStream stream, int delta) { /** * Must be called with holding transport lock. */ - void data(boolean outFinished, int streamId, Buffer source, boolean flush) { + public void data(boolean outFinished, StreamState state, Buffer source, boolean flush) { Preconditions.checkNotNull(source, "source"); - OkHttpClientStream stream = transport.getStream(streamId); - if (stream == null) { - // This is possible for a stream that has received end-of-stream from server (but hasn't sent - // end-of-stream), and was removed from the transport stream map. - // In such case, we just throw away the data. - return; - } - - OutboundFlowState state = state(stream); int window = state.writableWindow(); boolean framesAlreadyQueued = state.hasPendingData(); int size = (int) source.size(); @@ -130,7 +114,7 @@ void data(boolean outFinished, int streamId, Buffer source, boolean flush) { state.write(source, window, false); } // Queue remaining data in the buffer - state.enqueue(source, (int) source.size(), outFinished); + state.enqueueData(source, (int) source.size(), outFinished); } if (flush) { @@ -138,7 +122,19 @@ void data(boolean outFinished, int streamId, Buffer source, boolean flush) { } } - void flush() { + /** + * Transport lock must be held when calling. + */ + public void notifyWhenNoPendingData(StreamState state, Runnable noPendingDataRunnable) { + Preconditions.checkNotNull(noPendingDataRunnable, "noPendingDataRunnable"); + if (state.hasPendingData()) { + state.notifyWhenNoPendingData(noPendingDataRunnable); + } else { + noPendingDataRunnable.run(); + } + } + + public void flush() { try { frameWriter.flush(); } catch (IOException e) { @@ -146,13 +142,9 @@ void flush() { } } - private OutboundFlowState state(OkHttpClientStream stream) { - OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); - if (state == null) { - state = new OutboundFlowState(stream, initialWindowSize); - stream.setOutboundFlowState(state); - } - return state; + public StreamState createState(Stream stream, int streamId) { + return new StreamState( + streamId, initialWindowSize, Preconditions.checkNotNull(stream, "stream")); } /** @@ -160,15 +152,15 @@ private OutboundFlowState state(OkHttpClientStream stream) { * *

    Must be called with holding transport lock. */ - void writeStreams() { - OkHttpClientStream[] streams = transport.getActiveStreams(); + public void writeStreams() { + StreamState[] states = transport.getActiveStreams(); + Collections.shuffle(Arrays.asList(states)); int connectionWindow = connectionState.window(); - for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) { + for (int numStreams = states.length; numStreams > 0 && connectionWindow > 0;) { int nextNumStreams = 0; int windowSlice = (int) ceil(connectionWindow / (float) numStreams); for (int index = 0; index < numStreams && connectionWindow > 0; ++index) { - OkHttpClientStream stream = streams[index]; - OutboundFlowState state = state(stream); + StreamState state = states[index]; int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice)); if (bytesForStream > 0) { @@ -179,7 +171,7 @@ void writeStreams() { if (state.unallocatedBytes() > 0) { // There is more data to process for this stream. Add it to the next // pass. - streams[nextNumStreams++] = stream; + states[nextNumStreams++] = state; } } numStreams = nextNumStreams; @@ -187,8 +179,7 @@ void writeStreams() { // Now take one last pass through all of the streams and write any allocated bytes. WriteStatus writeStatus = new WriteStatus(); - for (OkHttpClientStream stream : transport.getActiveStreams()) { - OutboundFlowState state = state(stream); + for (StreamState state : transport.getActiveStreams()) { state.writeBytes(state.allocatedBytes(), writeStatus); state.clearAllocatedBytes(); } @@ -213,25 +204,29 @@ boolean hasWritten() { } } + public interface Transport { + StreamState[] getActiveStreams(); + } + + public interface Stream { + void onSentBytes(int frameBytes); + } + /** * The outbound flow control state for a single stream. */ - private final class OutboundFlowState { - final Buffer pendingWriteBuffer; - final int streamId; - int window; - int allocatedBytes; - OkHttpClientStream stream; - boolean pendingBufferHasEndOfStream = false; - - OutboundFlowState(int streamId, int initialWindowSize) { + public final class StreamState { + private final Buffer pendingWriteBuffer = new Buffer(); + private Runnable noPendingDataRunnable; + private final int streamId; + private int window; + private int allocatedBytes; + private final Stream stream; + private boolean pendingBufferHasEndOfStream = false; + + StreamState(int streamId, int initialWindowSize, Stream stream) { this.streamId = streamId; window = initialWindowSize; - pendingWriteBuffer = new Buffer(); - } - - OutboundFlowState(OkHttpClientStream stream, int initialWindowSize) { - this(stream.id(), initialWindowSize); this.stream = stream; } @@ -305,6 +300,10 @@ int writeBytes(int bytes, WriteStatus writeStatus) { // Update the threshold. maxBytes = min(bytes - bytesAttempted, writableWindow()); } + if (!hasPendingData() && noPendingDataRunnable != null) { + noPendingDataRunnable.run(); + noPendingDataRunnable = null; + } return bytesAttempted; } @@ -328,14 +327,20 @@ void write(Buffer buffer, int bytesToSend, boolean endOfStream) { } catch (IOException e) { throw new RuntimeException(e); } - stream.transportState().onSentBytes(frameBytes); + stream.onSentBytes(frameBytes); bytesToWrite -= frameBytes; } while (bytesToWrite > 0); } - void enqueue(Buffer buffer, int size, boolean endOfStream) { + void enqueueData(Buffer buffer, int size, boolean endOfStream) { this.pendingWriteBuffer.write(buffer, size); this.pendingBufferHasEndOfStream |= endOfStream; } + + void notifyWhenNoPendingData(Runnable noPendingDataRunnable) { + Preconditions.checkState( + this.noPendingDataRunnable == null, "pending data notification already requested"); + this.noPendingDataRunnable = noPendingDataRunnable; + } } -} \ No newline at end of file +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java new file mode 100644 index 00000000000..5338536213f --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.SecurityLevel; +import io.grpc.internal.GrpcAttributes; +import java.io.IOException; +import java.net.Socket; + +/** + * No-thrills plaintext handshaker. + */ +final class PlaintextHandshakerSocketFactory implements HandshakerSocketFactory { + @Override + public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException { + attributes = attributes.toBuilder() + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, socket.getLocalSocketAddress()) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, socket.getRemoteSocketAddress()) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) + .build(); + return new HandshakeResult(socket, attributes, null); + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java new file mode 100644 index 00000000000..63c6f33ff79 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java @@ -0,0 +1,60 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.ExperimentalApi; +import io.grpc.okhttp.internal.ConnectionSpec; +import javax.net.ssl.SSLSocketFactory; + +/** A credential with full control over the SSLSocketFactory. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") +public final class SslSocketFactoryServerCredentials { + private SslSocketFactoryServerCredentials() {} + + public static io.grpc.ServerCredentials create(SSLSocketFactory factory) { + return new ServerCredentials(factory); + } + + public static io.grpc.ServerCredentials create( + SSLSocketFactory factory, com.squareup.okhttp.ConnectionSpec connectionSpec) { + return new ServerCredentials(factory, Utils.convertSpec(connectionSpec)); + } + + // Hide implementation detail of how these credentials operate + static final class ServerCredentials extends io.grpc.ServerCredentials { + private final SSLSocketFactory factory; + private final ConnectionSpec connectionSpec; + + ServerCredentials(SSLSocketFactory factory) { + this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC); + } + + ServerCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) { + this.factory = Preconditions.checkNotNull(factory, "factory"); + this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); + } + + public SSLSocketFactory getFactory() { + return factory; + } + + public ConnectionSpec getConnectionSpec() { + return connectionSpec; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java new file mode 100644 index 00000000000..c375d6246cc --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.InternalChannelz; +import io.grpc.SecurityLevel; +import io.grpc.internal.GrpcAttributes; +import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Protocol; +import java.io.IOException; +import java.net.Socket; +import java.util.Arrays; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * TLS handshaker. + */ +final class TlsServerHandshakerSocketFactory implements HandshakerSocketFactory { + private final PlaintextHandshakerSocketFactory delegate = new PlaintextHandshakerSocketFactory(); + private final SSLSocketFactory socketFactory; + private final ConnectionSpec connectionSpec; + + public TlsServerHandshakerSocketFactory( + SslSocketFactoryServerCredentials.ServerCredentials credentials) { + this.socketFactory = credentials.getFactory(); + this.connectionSpec = credentials.getConnectionSpec(); + } + + @Override + public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException { + HandshakeResult result = delegate.handshake(socket, attributes); + socket = socketFactory.createSocket(result.socket, null, -1, true); + if (!(socket instanceof SSLSocket)) { + throw new IOException( + "SocketFactory " + socketFactory + " did not produce an SSLSocket: " + socket.getClass()); + } + SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.setUseClientMode(false); + connectionSpec.apply(sslSocket, false); + Protocol expectedProtocol = Protocol.HTTP_2; + String negotiatedProtocol = OkHttpProtocolNegotiator.get().negotiate( + sslSocket, + null, + connectionSpec.supportsTlsExtensions() ? Arrays.asList(expectedProtocol) : null); + if (!expectedProtocol.toString().equals(negotiatedProtocol)) { + throw new IOException("Expected NPN/ALPN " + expectedProtocol + ": " + negotiatedProtocol); + } + attributes = result.attributes.toBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSocket.getSession()) + .build(); + return new HandshakeResult(socket, attributes, + new InternalChannelz.Security(new InternalChannelz.Tls(sslSocket.getSession()))); + } +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java index 2bda131b346..46011588b16 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java @@ -56,7 +56,7 @@ public class AsyncSinkTest { private final QueueingExecutor queueingExecutor = new QueueingExecutor(); private final TransportExceptionHandler exceptionHandler = mock(TransportExceptionHandler.class); private final AsyncSink sink = - AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler); + AsyncSink.sink(new SerializingExecutor(queueingExecutor), exceptionHandler, 10000); @Test public void noCoalesceRequired() throws IOException { diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java index c26edcd0df8..a9d39088844 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -50,8 +50,7 @@ public class ExceptionHandlingFrameWriterTest { private final TransportExceptionHandler transportExceptionHandler = mock(TransportExceptionHandler.class); private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter = - new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter, - new OkHttpFrameLogger(Level.FINE, logger)); + new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter); @Test public void exception() throws IOException { @@ -194,4 +193,4 @@ public void close() throws SecurityException { logger.removeHandler(handler); } -} \ No newline at end of file +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 0063fc82ca0..691b9b83656 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -258,9 +258,62 @@ public void sslSocketFactoryFrom_tls_mtls() throws Exception { } @Test - public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception { + public void sslSocketFactoryFrom_tls_mtls_keyFile() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST); + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null); + keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()}); + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + + KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); + certStore.load(null); + certStore.setCertificateEntry("mycert", cert.cert()); + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(certStore); + + SSLContext serverContext = SSLContext.getInstance("TLS"); + serverContext.init( + keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null); + final SSLServerSocket serverListenSocket = + (SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0); + serverListenSocket.setNeedClientAuth(true); + final SettableFuture serverSocket = SettableFuture.create(); + new Thread(new Runnable() { + @Override public void run() { + try { + SSLSocket socket = (SSLSocket) serverListenSocket.accept(); + socket.getSession(); // Force handshake + serverSocket.set(socket); + serverListenSocket.close(); + } catch (Throwable t) { + serverSocket.setException(t); + } + } + }).start(); + + ChannelCredentials creds = TlsChannelCredentials.newBuilder() + .keyManager(cert.certificate(), cert.privateKey()) + .trustManager(cert.certificate()) + .build(); + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + SSLSocket socket = + (SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort()); + socket.getSession(); // Force handshake + assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName()) + .isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST); + socket.close(); + serverSocket.get().close(); + } + + @Test + public void sslSocketFactoryFrom_tls_mtls_passwordUnsupported() throws Exception { ChannelCredentials creds = TlsChannelCredentials.newBuilder() - .keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) + .keyManager( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"), "password") .build(); OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(creds); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 1628df4d3c3..969afc00d2d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -21,7 +21,6 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.MISCARRIED; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.ClientStreamListener.RpcProgress.REFUSED; -import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; import static io.grpc.okhttp.Headers.HTTP_SCHEME_HEADER; import static io.grpc.okhttp.Headers.METHOD_HEADER; @@ -46,8 +45,8 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; import com.google.common.base.Ticker; @@ -72,24 +71,28 @@ import io.grpc.internal.AbstractStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; +import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ManagedClientTransport; -import io.grpc.internal.TransportTracer; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpFrameLogger.Direction; -import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Protocol; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameWriter; import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.HeadersMode; import io.grpc.okhttp.internal.framed.Settings; +import io.grpc.okhttp.internal.framed.Variant; import io.grpc.testing.TestMethodDescriptors; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -113,9 +116,9 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.net.SocketFactory; -import javax.net.ssl.HostnameVerifier; -import javax.net.ssl.SSLSocketFactory; import okio.Buffer; +import okio.BufferedSink; +import okio.BufferedSource; import okio.ByteString; import org.junit.After; import org.junit.Before; @@ -145,7 +148,6 @@ public class OkHttpClientTransportTest { private static final Status SHUTDOWN_REASON = Status.UNAVAILABLE.withDescription("for test"); private static final HttpConnectProxiedSocketAddress NO_PROXY = null; private static final int DEFAULT_START_STREAM_ID = 3; - private static final int DEFAULT_MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE; private static final Attributes EAG_ATTRS = Attributes.EMPTY; private static final Logger logger = Logger.getLogger(OkHttpClientTransport.class.getName()); private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { @@ -154,39 +156,36 @@ public class OkHttpClientTransportTest { @Rule public final Timeout globalTimeout = Timeout.seconds(10); - private FrameWriter frameWriter; - private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Mock private ManagedClientTransport.Listener transportListener; - private final SocketFactory socketFactory = null; - private final SSLSocketFactory sslSocketFactory = null; - private final HostnameVerifier hostnameVerifier = null; - private final TransportTracer transportTracer = new TransportTracer(); private final Queue capturedBuffer = new ArrayDeque<>(); private OkHttpClientTransport clientTransport; - private MockFrameReader frameReader; - private Socket socket; + private final MockFrameReader frameReader = new MockFrameReader(); + private final Socket socket = new MockSocket(frameReader); + private final FrameWriter frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo( + new MockFrameWriter(socket, capturedBuffer))); private ExecutorService executor = Executors.newCachedThreadPool(); private long nanoTime; // backs a ticker, for testing ping round-trip time measurement private SettableFuture connectedFuture; - private DelayConnectedCallback delayConnectedCallback; private Runnable tooManyPingsRunnable = new Runnable() { @Override public void run() { throw new AssertionError(); } }; + private OkHttpChannelBuilder channelBuilder = OkHttpChannelBuilder.forAddress("127.0.0.1", 1234) + .usePlaintext() + .executor(new FakeClock().getScheduledExecutorService()) // Executor unused + .scheduledExecutorService(new FakeClock().getScheduledExecutorService()) // Executor unused + .transportExecutor(executor) + .flowControlWindow(INITIAL_WINDOW_SIZE); /** Set up for test. */ @Before public void setUp() { MockitoAnnotations.initMocks(this); - frameReader = new MockFrameReader(); - socket = new MockSocket(frameReader); - MockFrameWriter mockFrameWriter = new MockFrameWriter(socket, capturedBuffer); - frameWriter = mock(FrameWriter.class, AdditionalAnswers.delegatesTo(mockFrameWriter)); } @After @@ -196,26 +195,15 @@ public void tearDown() { private void initTransport() throws Exception { startTransport( - DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, null); + DEFAULT_START_STREAM_ID, null, true, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, null); - } - - private void initTransportAndDelayConnected() throws Exception { - delayConnectedCallback = new DelayConnectedCallback(); - startTransport( - DEFAULT_START_STREAM_ID, - delayConnectedCallback, - false, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, - null); + startTransport(startId, null, true, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, int maxMessageSize, int initialWindowSize, String userAgent) + boolean waitingForConnected, String userAgent) throws Exception { connectedFuture = SettableFuture.create(); final Ticker ticker = new Ticker() { @@ -230,47 +218,35 @@ public Stopwatch get() { return Stopwatch.createUnstarted(ticker); } }; + channelBuilder.socketFactory(new FakeSocketFactory(socket)); clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), userAgent, - executor, - frameReader, - frameWriter, - new OkHttpFrameLogger(Level.FINE, logger), - startId, - socket, stopwatchSupplier, + new FakeVariant(frameReader, frameWriter), connectingCallback, connectedFuture, - maxMessageSize, - initialWindowSize, - tooManyPingsRunnable, - new TransportTracer()); + tooManyPingsRunnable); clientTransport.start(transportListener); if (waitingForConnected) { connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); } + if (startId != DEFAULT_START_STREAM_ID) { + clientTransport.setNextStreamId(startId); + } } @Test public void testToString() throws Exception { InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), address, "hostname", - /*agent=*/ null, + /*userAgent=*/ null, EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, NO_PROXY, - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - transportTracer, - false); + tooManyPingsRunnable); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); @@ -301,6 +277,10 @@ public void close() throws SecurityException { logger.setLevel(Level.ALL); initTransport(); + assertThat(logs).hasSize(1); + LogRecord log = logs.remove(0); + assertThat(log.getMessage()).startsWith(Direction.OUTBOUND + " SETTINGS: ack=false"); + assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -310,7 +290,7 @@ public void close() throws SecurityException { frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); assertThat(logs).hasSize(1); - LogRecord log = logs.remove(0); + log = logs.remove(0); assertThat(log.getMessage()).startsWith(Direction.INBOUND + " HEADERS: streamId=" + 3); assertThat(log.getLevel()).isEqualTo(Level.FINE); @@ -387,8 +367,8 @@ public void close() throws SecurityException { @Test public void maxMessageSizeShouldBeEnforced() throws Exception { - // Allow the response payloads of up to 1 byte. - startTransport(3, null, true, 1, INITIAL_WINDOW_SIZE, null); + channelBuilder.maxInboundMessageSize(1); + initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = @@ -411,10 +391,8 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { @Test public void includeInitialWindowSizeInFirstSettings() throws Exception { - int initialWindowSize = 65535; - startTransport( - DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); + channelBuilder.flowControlWindow(65535); + initTransport(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -427,10 +405,8 @@ public void includeInitialWindowSizeInFirstSettings() throws Exception { */ @Test public void includeInitialWindowSizeInFirstSettings_largeWindowSize() throws Exception { - int initialWindowSize = 75535; // 65535 + 10000 - startTransport( - DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, initialWindowSize, null); - clientTransport.sendConnectionPrefaceAndSettings(); + channelBuilder.flowControlWindow(75535); // 65535 + 10000 + initTransport(); ArgumentCaptor settings = ArgumentCaptor.forClass(Settings.class); verify(frameWriter, timeout(TIME_OUT_MS)).settings(settings.capture()); @@ -697,7 +673,7 @@ public void addDefaultUserAgent() throws Exception { @Test public void overrideDefaultUserAgent() throws Exception { - startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, "fakeUserAgent"); + startTransport(3, null, true, "fakeUserAgent"); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); @@ -751,24 +727,21 @@ public void writeMessage() throws Exception { public void transportTracer_windowSizeDefault() throws Exception { initTransport(); TransportStats stats = getTransportStats(clientTransport); - assertEquals(INITIAL_WINDOW_SIZE, stats.remoteFlowControlWindow); - // okhttp does not track local window sizes - assertEquals(-1, stats.localFlowControlWindow); + assertEquals(INITIAL_WINDOW_SIZE / 2, stats.remoteFlowControlWindow); // Lower bound + assertEquals(INITIAL_WINDOW_SIZE, stats.localFlowControlWindow); } @Test public void transportTracer_windowSize_remote() throws Exception { initTransport(); TransportStats before = getTransportStats(clientTransport); - assertEquals(INITIAL_WINDOW_SIZE, before.remoteFlowControlWindow); - // okhttp does not track local window sizes - assertEquals(-1, before.localFlowControlWindow); + assertEquals(INITIAL_WINDOW_SIZE / 2, before.remoteFlowControlWindow); // Lower bound + assertEquals(INITIAL_WINDOW_SIZE, before.localFlowControlWindow); frameHandler().windowUpdate(0, 1000); TransportStats after = getTransportStats(clientTransport); - assertEquals(INITIAL_WINDOW_SIZE + 1000, after.remoteFlowControlWindow); - // okhttp does not track local window sizes - assertEquals(-1, after.localFlowControlWindow); + assertEquals(INITIAL_WINDOW_SIZE / 2, after.remoteFlowControlWindow); + assertEquals(INITIAL_WINDOW_SIZE + 1000, after.localFlowControlWindow); } @Test @@ -1694,90 +1667,30 @@ public void ping_failsIfTransportFails() throws Exception { shutdownAndVerify(); } - @Test - public void writeBeforeConnected() throws Exception { - initTransportAndDelayConnected(); - final String message = "Hello Server"; - MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); - stream.start(listener); - InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); - stream.writeMessage(input); - stream.flush(); - // The message should be queued. - verifyNoMoreInteractions(frameWriter); - - allowTransportConnected(); - - // The queued message should be sent out. - verify(frameWriter, timeout(TIME_OUT_MS)) - .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); - Buffer sentFrame = capturedBuffer.poll(); - assertEquals(createMessageFrame(message), sentFrame); - stream.cancel(Status.CANCELLED); - shutdownAndVerify(); - } - - @Test - public void cancelBeforeConnected() throws Exception { - initTransportAndDelayConnected(); - final String message = "Hello Server"; - MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); - stream.start(listener); - InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); - stream.writeMessage(input); - stream.flush(); - stream.cancel(Status.CANCELLED); - verifyNoMoreInteractions(frameWriter); - - allowTransportConnected(); - verifyNoMoreInteractions(frameWriter); - shutdownAndVerify(); - } - @Test public void shutdownDuringConnecting() throws Exception { - initTransportAndDelayConnected(); - MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); - stream.start(listener); + SettableFuture delayed = SettableFuture.create(); + Runnable connectingCallback = () -> Futures.getUnchecked(delayed); + startTransport( + DEFAULT_START_STREAM_ID, + connectingCallback, + false, + null); clientTransport.shutdown(SHUTDOWN_REASON); - allowTransportConnected(); - - // The new stream should be failed, but not the pending stream. - assertNewStreamFail(); - verify(frameWriter, timeout(TIME_OUT_MS)) - .synStream(anyBoolean(), anyBoolean(), eq(3), anyInt(), anyListHeader()); - assertEquals(1, activeStreamCount()); - stream.cancel(Status.CANCELLED); - listener.waitUntilStreamClosed(); - assertEquals(Status.CANCELLED.getCode(), listener.status.getCode()); + delayed.set(null); shutdownAndVerify(); } @Test public void invalidAuthorityPropagates() { clientTransport = new OkHttpClientTransport( - new InetSocketAddress("host", 1234), + channelBuilder.buildTransportFactory(), + new InetSocketAddress("localhost", 1234), "invalid_authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, NO_PROXY, - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - transportTracer, - false); + tooManyPingsRunnable); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1789,22 +1702,13 @@ public void invalidAuthorityPropagates() { @Test public void unreachableServer() throws Exception { clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), new InetSocketAddress("localhost", 0), "authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, NO_PROXY, - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - new TransportTracer(), - false); + tooManyPingsRunnable); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1828,22 +1732,13 @@ public void customSocketFactory() throws Exception { clientTransport = new OkHttpClientTransport( + channelBuilder.socketFactory(socketFactory).buildTransportFactory(), new InetSocketAddress("localhost", 0), "authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, NO_PROXY, - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - new TransportTracer(), - false); + tooManyPingsRunnable); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1859,25 +1754,16 @@ public void proxy_200() throws Exception { ServerSocket serverSocket = new ServerSocket(0); InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), targetAddress, "authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - transportTracer, - false); + tooManyPingsRunnable); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1917,25 +1803,16 @@ public void proxy_500() throws Exception { ServerSocket serverSocket = new ServerSocket(0); InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), targetAddress, "authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - transportTracer, - false); + tooManyPingsRunnable); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1974,25 +1851,16 @@ public void proxy_immediateServerClose() throws Exception { ServerSocket serverSocket = new ServerSocket(0); InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), targetAddress, "authority", "userAgent", EAG_ATTRS, - executor, - socketFactory, - sslSocketFactory, - hostnameVerifier, - ConnectionSpec.CLEARTEXT, - DEFAULT_MAX_MESSAGE_SIZE, - INITIAL_WINDOW_SIZE, HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable, - DEFAULT_MAX_INBOUND_METADATA_SIZE, - transportTracer, - false); + tooManyPingsRunnable); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -2009,6 +1877,37 @@ public void proxy_immediateServerClose() throws Exception { verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } + @Test + public void proxy_serverHangs() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + targetAddress, + "authority", + "userAgent", + EAG_ATTRS, + HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress(targetAddress) + .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) + .build(), + tooManyPingsRunnable); + clientTransport.proxySocketTimeout = 10; + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + verify(transportListener, timeout(200)).transportShutdown(any(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + sock.close(); + } + @Test public void goAway_notUtf8() throws Exception { initTransport(); @@ -2385,10 +2284,20 @@ static String getContent(InputStream message) { } private static class MockSocket extends Socket { - MockFrameReader frameReader; + final MockFrameReader frameReader; + private final PipedOutputStream outputStream = new PipedOutputStream(); + private final PipedInputStream outputStreamSink = new PipedInputStream(); + private final PipedOutputStream inputStreamSource = new PipedOutputStream(); + private final PipedInputStream inputStream = new PipedInputStream(); MockSocket(MockFrameReader frameReader) { this.frameReader = frameReader; + try { + outputStreamSink.connect(outputStream); + inputStream.connect(inputStreamSource); + } catch (IOException ex) { + throw new AssertionError(ex); + } } @Override @@ -2400,6 +2309,16 @@ public synchronized void close() { public SocketAddress getLocalSocketAddress() { return InetSocketAddress.createUnresolved("localhost", 4000); } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public InputStream getInputStream() { + return inputStream; + } } static class PingCallbackImpl implements ClientTransport.PingCallback { @@ -2420,10 +2339,6 @@ public void onFailure(Throwable cause) { } } - private void allowTransportConnected() { - delayConnectedCallback.allowConnected(); - } - private void shutdownAndVerify() { clientTransport.shutdown(SHUTDOWN_REASON); assertEquals(0, activeStreamCount()); @@ -2435,19 +2350,6 @@ private void shutdownAndVerify() { frameReader.assertClosed(); } - private static class DelayConnectedCallback implements Runnable { - SettableFuture delayed = SettableFuture.create(); - - @Override - public void run() { - Futures.getUnchecked(delayed); - } - - void allowConnected() { - delayed.set(null); - } - } - private static TransportStats getTransportStats(InternalInstrumented obj) throws ExecutionException, InterruptedException { return obj.getStats().get().data; @@ -2466,10 +2368,6 @@ public MockFrameWriter(Socket socket, Queue capturedBuffer) { this.capturedBuffer = capturedBuffer; } - void setSocket(Socket socket) { - this.socket = socket; - } - @Override public void close() throws IOException { socket.close(); @@ -2559,4 +2457,65 @@ public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddre throw exception; } } + + static class FakeSocketFactory extends SocketFactory { + private Socket socket; + + public FakeSocketFactory(Socket socket) { + this.socket = Preconditions.checkNotNull(socket, "socket"); + } + + @Override public Socket createSocket() { + Preconditions.checkNotNull(this.socket, "socket"); + Socket socket = this.socket; + this.socket = null; + return socket; + } + + @Override public Socket createSocket(InetAddress host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + InetAddress host, int port, InetAddress localAddress, int localPort) { + return createSocket(); + } + + @Override public Socket createSocket(String host, int port) { + return createSocket(); + } + + @Override public Socket createSocket( + String host, int port, InetAddress localHost, int localPort) { + return createSocket(); + } + } + + static class FakeVariant implements Variant { + private FrameReader frameReader; + private FrameWriter frameWriter; + + public FakeVariant(FrameReader frameReader, FrameWriter frameWriter) { + this.frameReader = frameReader; + this.frameWriter = frameWriter; + } + + @Override public Protocol getProtocol() { + return Protocol.HTTP_2; + } + + @Override public FrameReader newReader(BufferedSource source, boolean client) { + Preconditions.checkNotNull(this.frameReader, "frameReader"); + FrameReader frameReader = this.frameReader; + this.frameReader = null; + return frameReader; + } + + @Override public FrameWriter newWriter(BufferedSink sink, boolean client) { + Preconditions.checkNotNull(this.frameWriter, "frameWriter"); + FrameWriter frameWriter = this.frameWriter; + this.frameWriter = null; + return frameWriter; + } + } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java new file mode 100644 index 00000000000..af9b7c12d54 --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -0,0 +1,1421 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Charsets.UTF_8; +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; +import static io.grpc.okhttp.Headers.HTTP_SCHEME_HEADER; +import static io.grpc.okhttp.Headers.METHOD_HEADER; +import static io.grpc.okhttp.Headers.TE_HEADER; +import static org.mockito.AdditionalAnswers.answerVoid; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import com.google.common.io.ByteStreams; +import io.grpc.Attributes; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListener; +import io.grpc.internal.ServerTransportListener; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameReader; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.HeadersMode; +import io.grpc.okhttp.internal.framed.Http2; +import io.grpc.okhttp.internal.framed.Settings; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import okio.Buffer; +import okio.BufferedSource; +import okio.ByteString; +import okio.Okio; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +/** + * Tests for {@link OkHttpServerTransport}. + */ +@RunWith(JUnit4.class) +public class OkHttpServerTransportTest { + private static final int TIME_OUT_MS = 2000; + private static final int INITIAL_WINDOW_SIZE = 65535; + private static final long MAX_CONNECTION_IDLE = TimeUnit.SECONDS.toNanos(1); + + private MockServerTransportListener mockTransportListener = new MockServerTransportListener(); + private ServerTransportListener transportListener + = mock(ServerTransportListener.class, delegatesTo(mockTransportListener)); + private OkHttpServerTransport serverTransport; + private final PipeSocket socket = new PipeSocket(); + private final FrameWriter clientFrameWriter + = new Http2().newWriter(Okio.buffer(Okio.sink(socket.inputStreamSource)), true); + private final FrameReader clientFrameReader + = new Http2().newReader(Okio.buffer(Okio.source(socket.outputStreamSink)), true); + private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class); + private final DataFrameHandler clientDataFrames = mock(DataFrameHandler.class); + private ExecutorService threadPool = Executors.newCachedThreadPool(); + private HandshakerSocketFactory handshakerSocketFactory + = mock(HandshakerSocketFactory.class, delegatesTo(new PlaintextHandshakerSocketFactory())); + private final FakeClock fakeClock = new FakeClock(); + private OkHttpServerBuilder serverBuilder + = new OkHttpServerBuilder(new InetSocketAddress(1234), handshakerSocketFactory) + .executor(new FakeClock().getScheduledExecutorService()) // Executor unused + .scheduledExecutorService(fakeClock.getScheduledExecutorService()) + .transportExecutor(new Executor() { + @Override public void execute(Runnable runnable) { + if (runnable instanceof OkHttpServerTransport.FrameHandler) { + threadPool.execute(runnable); + } else { + // Writing is buffered in the PipeSocket, so AsyncSinc can be executed immediately + runnable.run(); + } + } + }) + .flowControlWindow(INITIAL_WINDOW_SIZE) + .maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS) + .permitKeepAliveWithoutCalls(true) + .permitKeepAliveTime(0, TimeUnit.SECONDS); + + @Rule public final Timeout globalTimeout = Timeout.seconds(10); + + @Before + public void setUp() throws Exception { + doAnswer(answerVoid((Boolean outDone, Integer streamId, BufferedSource in, Integer length) -> { + in.require(length); + Buffer buf = new Buffer(); + buf.write(in.getBuffer(), length); + clientDataFrames.data(outDone, streamId, buf); + })).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt()); + } + + @After + public void tearDown() throws Exception { + threadPool.shutdownNow(); + socket.closeSourceAndSink(); + } + + @Test + public void startThenShutdown() throws Exception { + initTransport(); + handshake(); + shutdownAndTerminate(/*lastStreamId=*/ 0); + } + + @Test + public void maxConnectionAge() throws Exception { + serverBuilder.maxConnectionAge(5, TimeUnit.SECONDS) + .maxConnectionAgeGrace(1, TimeUnit.SECONDS); + initTransport(); + handshake(); + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.synStream(true, false, 1, -1, Arrays.asList( + new Header("some-client-sent-trailer", "trailer-value"))); + pingPong(); + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(6)); // > 1.1 * 5 + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(1)); + verifyGracefulShutdown(1); + } + + @Test + public void maxConnectionIdleTimer() throws Exception { + initTransport(); + handshake(); + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.synStream(true, false, 1, -1, Arrays.asList( + new Header("some-client-sent-trailer", "trailer-value"))); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.messages.peek()).isNull(); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.close(Status.OK, new Metadata()); + + List

    responseTrailers = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + fakeClock.forwardNanos(MAX_CONNECTION_IDLE); + fakeClock.forwardNanos(MAX_CONNECTION_IDLE); + verifyGracefulShutdown(1); + } + + @Test + public void maxConnectionIdleTimer_respondWithError() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("host", "example.com:80"), + new Header("host", "example.com:80"))); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 400, Status.Code.INTERNAL, "Multiple host headers disallowed. RFC7230 section 5.4"); + + pingPong(); + fakeClock.forwardNanos(MAX_CONNECTION_IDLE); + fakeClock.forwardNanos(MAX_CONNECTION_IDLE); + verifyGracefulShutdown(1); + } + + @Test + public void startThenShutdownTwice() throws Exception { + initTransport(); + handshake(); + serverTransport.shutdown(); + shutdownAndTerminate(/*lastStreamId=*/ 0); + } + + @Test + public void shutdownDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + serverTransport.shutdown(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void shutdownNowDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + serverTransport.shutdownNow(Status.UNAVAILABLE.withDescription("shutdown now")); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void clientCloseDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void closeDuringHttp2Preface() throws Exception { + initTransport(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void noSettingsDuringHttp2HandshakeSettings() throws Exception { + initTransport(); + clientFrameWriter.connectionPreface(); + clientFrameWriter.flush(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void noSettingsDuringHttp2Handshake() throws Exception { + initTransport(); + clientFrameWriter.connectionPreface(); + clientFrameWriter.ping(false, 0, 0x1234); + clientFrameWriter.flush(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void startThenClientDisconnect() throws Exception { + initTransport(); + handshake(); + + socket.closeSourceAndSink(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void basicRpc_succeeds() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + assertThat(streamListener.method).isEqualTo("com.example/SimpleService.doit"); + assertThat(streamListener.headers.get( + Metadata.Key.of("Some-Metadata", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("this could be anything"); + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.writeHeaders(metadata("User-Data", "best data")); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.close(Status.OK, metadata("End-Metadata", "bye")); + + List
    responseHeaders = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("user-data", "best data")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseMessageFrame = createMessageFrame("Howdy client"); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + verify(clientDataFrames).data(false, 1, responseMessageFrame); + + List
    responseTrailers = Arrays.asList( + new Header("end-metadata", "bye"), + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + SocketStats stats = serverTransport.getStats().get(); + assertThat(stats.data.streamsStarted).isEqualTo(1); + assertThat(stats.data.streamsSucceeded).isEqualTo(1); + assertThat(stats.data.streamsFailed).isEqualTo(0); + assertThat(stats.data.messagesSent).isEqualTo(1); + assertThat(stats.data.messagesReceived).isEqualTo(1); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void activeRpc_delaysShutdownTermination() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + + serverTransport.shutdown(); + verifyGracefulShutdown(1); + verify(transportListener, never()).transportTerminated(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.writeHeaders(new Metadata()); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + + List
    responseHeaders = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseMessageFrame = createMessageFrame("Howdy client"); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + verify(clientDataFrames).data(false, 1, responseMessageFrame); + pingPong(); + assertThat(serverTransport.getActiveStreams().length).isEqualTo(1); + verify(transportListener, never()).transportTerminated(); + + streamListener.stream.close(Status.OK, new Metadata()); + List
    responseTrailers = Arrays.asList( + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + assertThat(serverTransport.getActiveStreams().length).isEqualTo(0); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersForStream0_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(0, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.INTERNAL_ERROR, + ByteString.encodeUtf8("Error in frame decoder")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersForEvenStream_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(2, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Clients cannot open even numbered streams. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersTooLarge_failsWith431() throws Exception { + initTransport(); + handshake(); + + StringBuilder largeString = new StringBuilder(); + for (int i = 0; i < 100; i++) { + largeString.append( + "Row, row, row your boat, gently down the stream. Merrily, merrily, merrily, merrily, " + + "life is but a dream. "); + } + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("too-large", largeString.toString()))); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 431, Status.Code.RESOURCE_EXHAUSTED, "Request metadata larger than 8192: 10953"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void invalidPseudoHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header(":status", "999"), // Invalid for requests + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void multipleAuthorityHeaders_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_AUTHORITY, "example.com:8080"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void pseudoHeaderAfterRegularHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + CONTENT_TYPE_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingSchemeHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void connectionHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + new Header("connection", "content-type"), + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void trailersAfterEndStream_failsWithRst() throws Exception { + initTransport(); + handshake(); + + List
    headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER); + clientFrameWriter.synStream(true, false, 1, -1, headers); + clientFrameWriter.synStream(true, false, 1, -1, headers); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.STREAM_CLOSED); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void trailers_endStream() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.synStream(true, false, 1, -1, Arrays.asList( + new Header("some-client-sent-trailer", "trailer-value"))); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.messages.peek()).isNull(); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.close(Status.OK, new Metadata()); + + List
    responseTrailers = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void headersInMiddleOfRequest_failsWithRst() throws Exception { + initTransport(); + handshake(); + + List
    headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER); + clientFrameWriter.headers(1, headers); + clientFrameWriter.headers(1, headers); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void multipleHostHeaders_failsWith400() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("host", "example.com:80"), + new Header("host", "example.com:80"))); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 400, Status.Code.INTERNAL, "Multiple host headers disallowed. RFC7230 section 5.4"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void hostWithoutAuthority_usesHost() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("host", "example.com:80"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void authorityAndHost_usesAuthority() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("host", "example2.com:8080"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingAuthorityAndHost_hasNullAuthority() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isNull(); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void emptyPath_failsWith404() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, ""), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 404, Status.Code.UNIMPLEMENTED, "Expected path to start with /: "); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void nonAbsolutePath_failsWith404() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "https://example.com/"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 404, Status.Code.UNIMPLEMENTED, "Expected path to start with /: https://example.com/"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is missing or duplicated"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void repeatedContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is missing or duplicated"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void textContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("content-type", "text/plain"), + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is not supported: text/plain"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void httpGet_failsWith405() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + new Header(":method", "GET"), + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingTeTrailers_failsWithInternal() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER)); + clientFrameWriter.flush(); + + List
    responseHeaders = Arrays.asList( + new Header(":status", "200"), + new Header("content-type", "application/grpc"), + new Header("grpc-status", "" + Status.Code.INTERNAL.value()), + new Header("grpc-message", "Expected header TE: trailers, but is received. " + + "Some intermediate proxy may not support trailers")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.NO_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void httpErrorsAdhereToFlowControl() throws Exception { + Settings settings = new Settings(); + OkHttpSettingsUtil.set(settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, 1); + + initTransport(); + handshake(settings); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + new Header(":method", "GET"), // Invalid + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + String errorDescription = "HTTP Method is not supported: GET"; + List
    responseHeaders = Arrays.asList( + new Header(":status", "405"), + new Header("content-type", "text/plain; charset=utf-8"), + new Header("grpc-status", "" + Status.Code.INTERNAL.value()), + new Header("grpc-message", errorDescription)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(false, 1, responseDataFrame); + + clientFrameWriter.windowUpdate(1, 1000); + clientFrameWriter.flush(); + + responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(true, 1, responseDataFrame); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.NO_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void dataForStream0_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + Buffer requestMessageFrame = createMessageFrame("Nope"); + clientFrameWriter.data(true, 0, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Stream 0 is reserved for control messages. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void dataForEvenStream_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + Buffer requestMessageFrame = createMessageFrame("Nope"); + clientFrameWriter.data(true, 2, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Clients cannot open even numbered streams. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void dataAfterHalfClose_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + requestMessageFrame = createMessageFrame("oh, I forgot"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.STREAM_CLOSED); + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + pingPong(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void pushPromise_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.pushPromise(2, 3, Arrays.asList()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8( + "PUSH_PROMISE only allowed on peer-initiated streams. RFC7540 section 6.6")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void channelzStats() throws Exception { + serverBuilder.flowControlWindow(60000); + initTransport(); + handshake(); + clientFrameWriter.windowUpdate(0, 1000); // connection stream id + pingPong(); + + SocketStats stats = serverTransport.getStats().get(); + assertThat(stats.data.streamsStarted).isEqualTo(0); + assertThat(stats.data.streamsSucceeded).isEqualTo(0); + assertThat(stats.data.streamsFailed).isEqualTo(0); + assertThat(stats.data.messagesSent).isEqualTo(0); + assertThat(stats.data.messagesReceived).isEqualTo(0); + assertThat(stats.data.remoteFlowControlWindow).isEqualTo(30000); // Lower bound + assertThat(stats.data.localFlowControlWindow).isEqualTo(66535); + assertThat(stats.local).isEqualTo(new InetSocketAddress("127.0.0.1", 4000)); + assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000)); + } + + @Test + public void keepAliveEnforcer_enforcesPings() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + + @Test + public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + + streamListener.stream.writeHeaders(metadata("User-Data", "best data")); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + + for (int i = 0; i < 10; i++) { + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + pingPong(); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + } + } + + @Test + public void keepAliveEnforcer_initialIdle() throws Exception { + serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + + @Test + public void keepAliveEnforcer_noticesActive() throws Exception { + serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + for (int i = 0; i < 10; i++) { + pingPong(); + } + verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM), + eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII))); + } + + private void initTransport() throws Exception { + serverTransport = new OkHttpServerTransport( + new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()), + socket); + serverTransport.start(transportListener); + } + + private void handshake() throws Exception { + handshake(new Settings()); + } + + private void handshake(Settings settings) throws Exception { + clientFrameWriter.connectionPreface(); + clientFrameWriter.settings(settings); + clientFrameWriter.flush(); + clientFrameReader.readConnectionPreface(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + ArgumentCaptor settingsCaptor = ArgumentCaptor.forClass(Settings.class); + verify(clientFramesRead).settings(eq(false), settingsCaptor.capture()); + clientFrameWriter.ackSettings(settingsCaptor.getValue()); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ackSettings(); + verify(transportListener, timeout(TIME_OUT_MS)).transportReady(any(Attributes.class)); + } + + private static Buffer createMessageFrame(String stringMessage) { + byte[] message = stringMessage.getBytes(UTF_8); + Buffer buffer = new Buffer(); + buffer.writeByte(0 /* UNCOMPRESSED */); + buffer.writeInt(message.length); + buffer.write(message); + return buffer; + } + + private Metadata metadata(String... keysAndValues) { + Metadata metadata = new Metadata(); + assertThat(keysAndValues.length % 2).isEqualTo(0); + for (int i = 0; i < keysAndValues.length; i += 2) { + metadata.put( + Metadata.Key.of(keysAndValues[i], Metadata.ASCII_STRING_MARSHALLER), + keysAndValues[i + 1]); + } + return metadata; + } + + private void verifyGracefulShutdown(int lastStreamId) + throws IOException { + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(2147483647, ErrorCode.NO_ERROR, ByteString.EMPTY); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ping(false, 0, 0x1111); + clientFrameWriter.ping(true, 0, 0x1111); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(lastStreamId, ErrorCode.NO_ERROR, ByteString.EMPTY); + } + + private void shutdownAndTerminate(int lastStreamId) throws IOException { + assertThat(serverTransport.getActiveStreams().length).isEqualTo(0); + serverTransport.shutdown(); + verifyGracefulShutdown(lastStreamId); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + private int pingPongId = 0; + + /** Send a ping and wait for the ping ack. */ + private void pingPong() throws IOException { + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ping(true, pingPongId, 0); + } + + private void verifyHttpError( + int streamId, int httpCode, Status.Code grpcCode, String errorDescription) throws Exception { + List
    responseHeaders = Arrays.asList( + new Header(":status", "" + httpCode), + new Header("content-type", "text/plain; charset=utf-8"), + new Header("grpc-status", "" + grpcCode.value()), + new Header("grpc-message", errorDescription)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, streamId, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(true, streamId, responseDataFrame); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(streamId, ErrorCode.NO_ERROR); + } + + private static class MockServerTransportListener implements ServerTransportListener { + Deque newStreams = new ArrayDeque<>(); + + @Override public void streamCreated(ServerStream stream, String method, Metadata headers) { + MockStreamListener streamListener = new MockStreamListener(stream, method, headers); + stream.setListener(streamListener); + newStreams.add(streamListener); + } + + @Override public Attributes transportReady(Attributes attributes) { + return attributes; + } + + @Override public void transportTerminated() {} + } + + private static class MockStreamListener implements ServerStreamListener { + final ServerStream stream; + final String method; + final Metadata headers; + + Deque messages = new ArrayDeque<>(); + boolean halfClosedCalled; + boolean onReadyCalled; + Status status; + CountDownLatch closed = new CountDownLatch(1); + + MockStreamListener(ServerStream stream, String method, Metadata headers) { + this.stream = stream; + this.method = method; + this.headers = headers; + } + + @Override + public void messagesAvailable(MessageProducer producer) { + InputStream inputStream; + while ((inputStream = producer.next()) != null) { + try { + String msg = getContent(inputStream); + if (msg != null) { + messages.add(msg); + } + } catch (IOException ex) { + while ((inputStream = producer.next()) != null) { + GrpcUtil.closeQuietly(inputStream); + } + throw new RuntimeException(ex); + } + } + } + + @Override + public void halfClosed() { + halfClosedCalled = true; + } + + @Override + public void closed(Status status) { + this.status = status; + closed.countDown(); + } + + @Override + public void onReady() { + onReadyCalled = true; + } + + boolean isOnReadyCalled() { + boolean value = onReadyCalled; + onReadyCalled = false; + return value; + } + + void waitUntilStreamClosed() throws InterruptedException, TimeoutException { + if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Failed waiting stream to be closed."); + } + } + + static String getContent(InputStream message) throws IOException { + try { + return new String(ByteStreams.toByteArray(message), UTF_8); + } finally { + message.close(); + } + } + } + + private static class PipeSocket extends Socket { + private final PipedOutputStream outputStream = new PipedOutputStream(); + private final PipedInputStream outputStreamSink = new PipedInputStream(); + private final PipedOutputStream inputStreamSource = new PipedOutputStream(); + private final PipedInputStream inputStream = new PipedInputStream(); + + public PipeSocket() { + try { + outputStreamSink.connect(outputStream); + inputStream.connect(inputStreamSource); + } catch (IOException ex) { + throw new AssertionError(ex); + } + } + + @Override + public synchronized void close() throws IOException { + try { + outputStream.close(); + } finally { + inputStream.close(); + // PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is + // a better imitation of Socket.close(). + inputStreamSource.close(); + } + } + + public void closeSourceAndSink() throws IOException { + try { + outputStreamSink.close(); + } finally { + inputStreamSource.close(); + } + } + + @Override + public SocketAddress getLocalSocketAddress() { + return new InetSocketAddress("127.0.0.1", 4000); + } + + @Override + public SocketAddress getRemoteSocketAddress() { + return new InetSocketAddress("127.0.0.2", 5000); + } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public InputStream getInputStream() { + return inputStream; + } + } + + private interface DataFrameHandler { + void data(boolean inFinished, int streamId, Buffer payload); + } +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java index 44e493c259f..076eea3349a 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java @@ -16,6 +16,7 @@ package io.grpc.okhttp; +import io.grpc.InsecureServerCredentials; import io.grpc.ServerStreamTracer; import io.grpc.internal.AbstractTransportTest; import io.grpc.internal.ClientTransportFactory; @@ -23,8 +24,6 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; -import io.grpc.netty.InternalNettyServerBuilder; -import io.grpc.netty.NettyServerBuilder; import java.net.InetSocketAddress; import java.util.List; import java.util.concurrent.TimeUnit; @@ -53,21 +52,17 @@ public void releaseClientFactory() { @Override protected InternalServer newServer( List streamTracerFactories) { - NettyServerBuilder builder = NettyServerBuilder - .forPort(0) - .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW); - InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer); - return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories); + return newServer(0, streamTracerFactories); } @Override protected InternalServer newServer( int port, List streamTracerFactories) { - NettyServerBuilder builder = NettyServerBuilder - .forAddress(new InetSocketAddress(port)) - .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW); - InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer); - return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories); + return OkHttpServerBuilder + .forPort(port, InsecureServerCredentials.create()) + .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) + .setTransportTracerFactory(fakeClockTransportTracer) + .buildTransportServers(streamTracerFactories); } @Override @@ -100,11 +95,4 @@ protected long fakeCurrentTimeNanos() { protected boolean haveTransportTracer() { return true; } - - @Override - @org.junit.Test - @org.junit.Ignore - public void clientChecksInboundMetadataSize_trailer() { - // Server-side is flaky due to https://github.com/netty/netty/pull/8332 - } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java new file mode 100644 index 00000000000..1871ec83f88 --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java @@ -0,0 +1,277 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; + +import com.google.common.base.Throwables; +import io.grpc.ChannelCredentials; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.ServerCredentials; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.testing.TestUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TlsTesting; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.IOException; +import java.io.InputStream; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Verify OkHttp's TLS integration. */ +@RunWith(JUnit4.class) +public class TlsTest { + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + + @Before + public void checkForAlpnApi() throws Exception { + // This checks for the "Java 9 ALPN API" which was backported to Java 8u252. The Kokoro Windows + // CI is on too old of a JDK for us to assume this is available. + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, null, null); + SSLEngine engine = context.createSSLEngine(); + try { + SSLEngine.class.getMethod("getApplicationProtocol").invoke(engine); + } catch (NoSuchMethodException | UnsupportedOperationException ex) { + Assume.assumeNoException(ex); + } + } + + @Test + public void basicTls_succeeds() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void mtls_succeeds() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .trustManager(caCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream clientCertChain = TlsTesting.loadCert("client.pem"); + InputStream clientPrivateKey = TlsTesting.loadCert("client.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .keyManager(clientCertChain, clientPrivateKey) + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void untrustedClient_fails() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .trustManager(caCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream clientCertChain = TlsTesting.loadCert("badclient.pem"); + InputStream clientPrivateKey = TlsTesting.loadCert("badclient.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .keyManager(clientCertChain, clientPrivateKey) + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + assertRpcFails(channel); + } + + @Test + public void missingOptionalClientCert_succeeds() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .trustManager(caCert) + .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); + } + + @Test + public void missingRequiredClientCert_fails() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key"); + InputStream caCert = TlsTesting.loadCert("ca.pem")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .trustManager(caCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + assertRpcFails(channel); + } + + @Test + public void untrustedServer_fails() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds = TlsChannelCredentials.create(); + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + assertRpcFails(channel); + } + + @Test + public void unmatchedServerSubjectAlternativeNames_fails() throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannelBuilder(server, channelCreds) + .overrideAuthority("notgonnamatch.example.com") + .build()); + + assertRpcFails(channel); + } + + private static Server server(ServerCredentials creds) throws IOException { + return OkHttpServerBuilder.forPort(0, creds) + .directExecutor() + .addService(new SimpleServiceImpl()) + .build() + .start(); + } + + private static ManagedChannelBuilder clientChannelBuilder( + Server server, ChannelCredentials creds) { + return OkHttpChannelBuilder.forAddress("localhost", server.getPort(), creds) + .directExecutor() + .overrideAuthority(TestUtils.TEST_SERVER_HOST); + } + + private static ManagedChannel clientChannel(Server server, ChannelCredentials creds) { + return clientChannelBuilder(server, creds).build(); + } + + private static void assertRpcFails(ManagedChannel channel) { + SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel); + try { + stub.unaryRpc(SimpleRequest.getDefaultInstance()); + assertWithMessage("TLS handshake should have failed, but didn't; received RPC response") + .fail(); + } catch (StatusRuntimeException e) { + assertWithMessage(Throwables.getStackTraceAsString(e)) + .that(e.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + } + // We really want to see TRANSIENT_FAILURE here, but if the test runs slowly the 1s backoff + // may be exceeded by the time the failure happens (since it counts from the start of the + // attempt). Even so, CONNECTING is a strong indicator that the handshake failed; otherwise we'd + // expect READY or IDLE. + assertThat(channel.getState(false)) + .isAnyOf(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING); + } + + private static final class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest req, StreamObserver respOb) { + respOb.onNext(SimpleResponse.getDefaultInstance()); + respOb.onCompleted(); + } + } +} diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Credentials.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Credentials.java new file mode 100644 index 00000000000..08a46ada7a7 --- /dev/null +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Credentials.java @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Forked from OkHttp 2.7.0 + */ +package io.grpc.okhttp.internal; + +import java.io.UnsupportedEncodingException; +import okio.ByteString; + +/** Factory for HTTP authorization credentials. */ +public final class Credentials { + private Credentials() { + } + + /** Returns an auth credential for the Basic scheme. */ + public static String basic(String userName, String password) { + try { + String usernameAndPassword = userName + ":" + password; + byte[] bytes = usernameAndPassword.getBytes("ISO-8859-1"); + String encoded = ByteString.of(bytes).base64(); + return "Basic " + encoded; + } catch (UnsupportedEncodingException e) { + throw new AssertionError(); + } + } +} diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Headers.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Headers.java new file mode 100644 index 00000000000..115f9964008 --- /dev/null +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Headers.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Forked from OkHttp 2.7.0 com.squareup.okhttp.Headers + */ +package io.grpc.okhttp.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +/** + * The header fields of a single HTTP message. Values are uninterpreted strings; + * + *

    This class trims whitespace from values. It never returns values with + * leading or trailing whitespace. + * + *

    Instances of this class are immutable. Use {@link Builder} to create + * instances. + */ +public final class Headers { + private final String[] namesAndValues; + + private Headers(Builder builder) { + this.namesAndValues = builder.namesAndValues.toArray(new String[builder.namesAndValues.size()]); + } + + /** Returns the last value corresponding to the specified field, or null. */ + public String get(String name) { + return get(namesAndValues, name); + } + + /** Returns the number of field values. */ + public int size() { + return namesAndValues.length / 2; + } + + /** Returns the field at {@code position} or null if that is out of range. */ + public String name(int index) { + int nameIndex = index * 2; + if (nameIndex < 0 || nameIndex >= namesAndValues.length) { + return null; + } + return namesAndValues[nameIndex]; + } + + /** Returns the value at {@code index} or null if that is out of range. */ + public String value(int index) { + int valueIndex = index * 2 + 1; + if (valueIndex < 0 || valueIndex >= namesAndValues.length) { + return null; + } + return namesAndValues[valueIndex]; + } + + public Builder newBuilder() { + Builder result = new Builder(); + Collections.addAll(result.namesAndValues, namesAndValues); + return result; + } + + @Override public String toString() { + StringBuilder result = new StringBuilder(); + for (int i = 0, size = size(); i < size; i++) { + result.append(name(i)).append(": ").append(value(i)).append("\n"); + } + return result.toString(); + } + + private static String get(String[] namesAndValues, String name) { + for (int i = namesAndValues.length - 2; i >= 0; i -= 2) { + if (name.equalsIgnoreCase(namesAndValues[i])) { + return namesAndValues[i + 1]; + } + } + return null; + } + + public static final class Builder { + private final List namesAndValues = new ArrayList<>(20); + + /** + * Add a field with the specified value without any validation. Only + * appropriate for headers from the remote peer or cache. + */ + Builder addLenient(String name, String value) { + namesAndValues.add(name); + namesAndValues.add(value.trim()); + return this; + } + + public Builder removeAll(String name) { + for (int i = 0; i < namesAndValues.size(); i += 2) { + if (name.equalsIgnoreCase(namesAndValues.get(i))) { + namesAndValues.remove(i); // name + namesAndValues.remove(i); // value + i -= 2; + } + } + return this; + } + + /** + * Set a field with the specified value. If the field is not found, it is + * added. If the field is found, the existing values are replaced. + */ + public Builder set(String name, String value) { + checkNameAndValue(name, value); + removeAll(name); + addLenient(name, value); + return this; + } + + private void checkNameAndValue(String name, String value) { + if (name == null) throw new IllegalArgumentException("name == null"); + if (name.isEmpty()) throw new IllegalArgumentException("name is empty"); + for (int i = 0, length = name.length(); i < length; i++) { + char c = name.charAt(i); + if (c <= '\u001f' || c >= '\u007f') { + throw new IllegalArgumentException(String.format( + Locale.US, + "Unexpected char %#04x at %d in header name: %s", (int) c, i, name)); + } + } + if (value == null) throw new IllegalArgumentException("value == null"); + for (int i = 0, length = value.length(); i < length; i++) { + char c = value.charAt(i); + if (c <= '\u001f' || c >= '\u007f') { + throw new IllegalArgumentException(String.format( + Locale.US, + "Unexpected char %#04x at %d in header value: %s", (int) c, i, value)); + } + } + } + + public Headers build() { + return new Headers(this); + } + } +} diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/StatusLine.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/StatusLine.java new file mode 100644 index 00000000000..ab72ee2d294 --- /dev/null +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/StatusLine.java @@ -0,0 +1,87 @@ +/* + * Forked from OkHttp 2.7.0 + */ +package io.grpc.okhttp.internal; + +import java.io.IOException; +import java.net.ProtocolException; + +/** An HTTP response status line like "HTTP/1.1 200 OK". */ +public final class StatusLine { + /** Numeric status code, 307: Temporary Redirect. */ + public static final int HTTP_TEMP_REDIRECT = 307; + public static final int HTTP_PERM_REDIRECT = 308; + public static final int HTTP_CONTINUE = 100; + + public final Protocol protocol; + public final int code; + public final String message; + + public StatusLine(Protocol protocol, int code, String message) { + this.protocol = protocol; + this.code = code; + this.message = message; + } + + public static StatusLine parse(String statusLine) throws IOException { + // H T T P / 1 . 1 2 0 0 T e m p o r a r y R e d i r e c t + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 + + // Parse protocol like "HTTP/1.1" followed by a space. + int codeStart; + Protocol protocol; + if (statusLine.startsWith("HTTP/1.")) { + if (statusLine.length() < 9 || statusLine.charAt(8) != ' ') { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + int httpMinorVersion = statusLine.charAt(7) - '0'; + codeStart = 9; + if (httpMinorVersion == 0) { + protocol = Protocol.HTTP_1_0; + } else if (httpMinorVersion == 1) { + protocol = Protocol.HTTP_1_1; + } else { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + } else if (statusLine.startsWith("ICY ")) { + // Shoutcast uses ICY instead of "HTTP/1.0". + protocol = Protocol.HTTP_1_0; + codeStart = 4; + } else { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + + // Parse response code like "200". Always 3 digits. + if (statusLine.length() < codeStart + 3) { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + int code; + try { + code = Integer.parseInt(statusLine.substring(codeStart, codeStart + 3)); + } catch (NumberFormatException e) { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + + // Parse an optional response message like "OK" or "Not Modified". If it + // exists, it is separated from the response code by a space. + String message = ""; + if (statusLine.length() > codeStart + 3) { + if (statusLine.charAt(codeStart + 3) != ' ') { + throw new ProtocolException("Unexpected status line: " + statusLine); + } + message = statusLine.substring(codeStart + 4); + } + + return new StatusLine(protocol, code, message); + } + + @Override public String toString() { + StringBuilder result = new StringBuilder(); + result.append(protocol == Protocol.HTTP_1_0 ? "HTTP/1.0" : "HTTP/1.1"); + result.append(' ').append(code); + if (message != null) { + result.append(' ').append(message); + } + return result.toString(); + } +} diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Util.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Util.java index 556d849c705..cad74a026b1 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Util.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Util.java @@ -19,31 +19,15 @@ package io.grpc.okhttp.internal; -import java.io.Closeable; -import java.io.IOException; -import java.io.InterruptedIOException; -import java.io.UnsupportedEncodingException; import java.lang.reflect.Array; -import java.net.ServerSocket; -import java.net.Socket; import java.nio.charset.Charset; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; -import okio.Buffer; -import okio.ByteString; -import okio.Source; /** Junk drawer of utility methods. */ public final class Util { - public static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; public static final String[] EMPTY_STRING_ARRAY = new String[0]; /** A cheap and type-safe constant for the UTF-8 Charset. */ @@ -52,192 +36,16 @@ public final class Util { private Util() { } - public static void checkOffsetAndCount(long arrayLength, long offset, long count) { - if ((offset | count) < 0 || offset > arrayLength || arrayLength - offset < count) { - throw new ArrayIndexOutOfBoundsException(); - } - } - /** Returns true if two possibly-null objects are equal. */ public static boolean equal(Object a, Object b) { return a == b || (a != null && a.equals(b)); } - /** - * Closes {@code closeable}, ignoring any checked exceptions. Does nothing - * if {@code closeable} is null. - */ - public static void closeQuietly(Closeable closeable) { - if (closeable != null) { - try { - closeable.close(); - } catch (RuntimeException rethrown) { - throw rethrown; - } catch (Exception ignored) { - // The method is defined to ignore checked exceptions - } - } - } - - /** - * Closes {@code socket}, ignoring any checked exceptions. Does nothing if - * {@code socket} is null. - */ - public static void closeQuietly(Socket socket) { - if (socket != null) { - try { - socket.close(); - } catch (AssertionError e) { - if (!isAndroidGetsocknameError(e)) throw e; - } catch (RuntimeException rethrown) { - throw rethrown; - } catch (Exception ignored) { - // The method is defined to ignore checked exceptions - } - } - } - - /** - * Closes {@code serverSocket}, ignoring any checked exceptions. Does nothing if - * {@code serverSocket} is null. - */ - public static void closeQuietly(ServerSocket serverSocket) { - if (serverSocket != null) { - try { - serverSocket.close(); - } catch (RuntimeException rethrown) { - throw rethrown; - } catch (Exception ignored) { - // The method is defined to ignore checked exceptions - } - } - } - - /** - * Closes {@code a} and {@code b}. If either close fails, this completes - * the other close and rethrows the first encountered exception. - */ - public static void closeAll(Closeable a, Closeable b) throws IOException { - Throwable thrown = null; - try { - a.close(); - } catch (Throwable e) { - thrown = e; - } - try { - b.close(); - } catch (Throwable e) { - if (thrown == null) thrown = e; - } - if (thrown == null) return; - if (thrown instanceof IOException) throw (IOException) thrown; - if (thrown instanceof RuntimeException) throw (RuntimeException) thrown; - if (thrown instanceof Error) throw (Error) thrown; - throw new AssertionError(thrown); - } - - /** - * Attempts to exhaust {@code source}, returning true if successful. This is useful when reading - * a complete source is helpful, such as when doing so completes a cache body or frees a socket - * connection for reuse. - */ - public static boolean discard(Source source, int timeout, TimeUnit timeUnit) { - try { - return skipAll(source, timeout, timeUnit); - } catch (IOException e) { - return false; - } - } - - /** - * Reads until {@code in} is exhausted or the deadline has been reached. This is careful to not - * extend the deadline if one exists already. - */ - public static boolean skipAll(Source source, int duration, TimeUnit timeUnit) throws IOException { - long now = System.nanoTime(); - long originalDuration = source.timeout().hasDeadline() - ? source.timeout().deadlineNanoTime() - now - : Long.MAX_VALUE; - source.timeout().deadlineNanoTime(now + Math.min(originalDuration, timeUnit.toNanos(duration))); - try { - Buffer skipBuffer = new Buffer(); - while (source.read(skipBuffer, 2048) != -1) { - skipBuffer.clear(); - } - return true; // Success! The source has been exhausted. - } catch (InterruptedIOException e) { - return false; // We ran out of time before exhausting the source. - } finally { - if (originalDuration == Long.MAX_VALUE) { - source.timeout().clearDeadline(); - } else { - source.timeout().deadlineNanoTime(now + originalDuration); - } - } - } - - /** Returns a 32 character string containing an MD5 hash of {@code s}. */ - public static String md5Hex(String s) { - try { - MessageDigest messageDigest = MessageDigest.getInstance("MD5"); - byte[] md5bytes = messageDigest.digest(s.getBytes("UTF-8")); - return ByteString.of(md5bytes).hex(); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } catch (UnsupportedEncodingException e) { - throw new AssertionError(e); - } - } - - /** Returns a Base 64-encoded string containing a SHA-1 hash of {@code s}. */ - public static String shaBase64(String s) { - try { - MessageDigest messageDigest = MessageDigest.getInstance("SHA-1"); - byte[] sha1Bytes = messageDigest.digest(s.getBytes("UTF-8")); - return ByteString.of(sha1Bytes).base64(); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } catch (UnsupportedEncodingException e) { - throw new AssertionError(e); - } - } - - /** Returns a SHA-1 hash of {@code s}. */ - public static ByteString sha1(ByteString s) { - try { - MessageDigest messageDigest = MessageDigest.getInstance("SHA-1"); - byte[] sha1Bytes = messageDigest.digest(s.toByteArray()); - return ByteString.of(sha1Bytes); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } - } - - /** Returns an immutable copy of {@code list}. */ - public static List immutableList(List list) { - return Collections.unmodifiableList(new ArrayList<>(list)); - } - /** Returns an immutable list containing {@code elements}. */ public static List immutableList(T[] elements) { return Collections.unmodifiableList(Arrays.asList(elements.clone())); } - /** Returns an immutable copy of {@code map}. */ - public static Map immutableMap(Map map) { - return Collections.unmodifiableMap(new LinkedHashMap<>(map)); - } - - public static ThreadFactory threadFactory(final String name, final boolean daemon) { - return new ThreadFactory() { - @Override public Thread newThread(Runnable runnable) { - Thread result = new Thread(runnable, name); - result.setDaemon(daemon); - return result; - } - }; - } - /** * Returns an array containing containing only elements found in {@code first} and also in * {@code second}. The returned elements are in the same order as in {@code first}. @@ -264,30 +72,4 @@ private static List intersect(T[] first, T[] second) { } return result; } - - /** Returns {@code s} with control characters and non-ASCII characters replaced with '?'. */ - public static String toHumanReadableAscii(String s) { - for (int i = 0, length = s.length(), c; i < length; i += Character.charCount(c)) { - c = s.codePointAt(i); - if (c > '\u001f' && c < '\u007f') continue; - - Buffer buffer = new Buffer(); - buffer.writeUtf8(s, 0, i); - for (int j = i; j < length; j += Character.charCount(c)) { - c = s.codePointAt(j); - buffer.writeUtf8CodePoint(c > '\u001f' && c < '\u007f' ? c : '?'); - } - return buffer.readUtf8(); - } - return s; - } - - /** - * Returns true if {@code e} is due to a firmware bug fixed after Android 4.2.2. - * https://code.google.com/p/android/issues/detail?id=54072 - */ - public static boolean isAndroidGetsocknameError(AssertionError e) { - return e.getCause() != null && e.getMessage() != null - && e.getMessage().contains("getsockname failed"); - } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java index 197a7f72fc8..3a8c41c6285 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java @@ -23,6 +23,7 @@ import io.grpc.okhttp.internal.Protocol; import java.io.IOException; import java.util.List; +import java.util.Locale; import java.util.logging.Logger; import okio.Buffer; @@ -231,6 +232,7 @@ private void readData(Handler handler, int length, byte flags, int streamId) short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0; length = lengthWithoutPadding(length, flags, padding); + // FIXME: pass padding length to handler because it should be included for flow control handler.data(inFinished, streamId, source, length); source.skip(padding); } @@ -589,12 +591,12 @@ void frameHeader(int streamId, int length, byte type, byte flags) throws IOExcep @FormatMethod private static IllegalArgumentException illegalArgument(String message, Object... args) { - throw new IllegalArgumentException(format(message, args)); + throw new IllegalArgumentException(format(Locale.US, message, args)); } @FormatMethod private static IOException ioException(String message, Object... args) throws IOException { - throw new IOException(format(message, args)); + throw new IOException(format(Locale.US, message, args)); } /** @@ -683,7 +685,7 @@ static final class FrameLogger { static String formatHeader(boolean inbound, int streamId, int length, byte type, byte flags) { String formattedType = type < TYPES.length ? TYPES[type] : format("0x%02x", type); String formattedFlags = formatFlags(type, flags); - return format("%s 0x%08x %5d %-13s %s", inbound ? "<<" : ">>", streamId, length, + return format(Locale.US, "%s 0x%08x %5d %-13s %s", inbound ? "<<" : ">>", streamId, length, formattedType, formattedFlags); } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java index 0d0ecce9982..591b59129ed 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java @@ -46,7 +46,7 @@ public final class Settings { /** spdy/3: Sender's estimate of max outgoing kbps. */ static final int DOWNLOAD_BANDWIDTH = 2; /** HTTP/2: The peer must not send a PUSH_PROMISE frame when this is 0. */ - static final int ENABLE_PUSH = 2; + public static final int ENABLE_PUSH = 2; /** spdy/3: Sender's estimate of millis between sending a request and receiving a response. */ static final int ROUND_TRIP_TIME = 3; /** Sender's maximum number of concurrent streams. */ @@ -58,7 +58,7 @@ public final class Settings { /** spdy/3: Retransmission rate. Percentage */ static final int DOWNLOAD_RETRANS_RATE = 6; /** HTTP/2: Advisory only. Size in bytes of the largest header list the sender will accept. */ - static final int MAX_HEADER_LIST_SIZE = 6; + public static final int MAX_HEADER_LIST_SIZE = 6; /** Window size in bytes. */ public static final int INITIAL_WINDOW_SIZE = 7; /** spdy/3: Size of the client certificate vector. Unsupported. */ diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/HttpUrl.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/HttpUrl.java new file mode 100644 index 00000000000..1a03812f26c --- /dev/null +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/HttpUrl.java @@ -0,0 +1,477 @@ +/* + * Copyright (C) 2015 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Forked from OkHttp 2.7.0 com.squareup.okhttp.HttpUrl + */ +package io.grpc.okhttp.internal.proxy; + +import java.io.EOFException; +import java.net.IDN; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.Locale; +import okio.Buffer; + +/** + * Helper class to build a proxy URL. + */ +public final class HttpUrl { + private static final char[] HEX_DIGITS = + { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' }; + + /** Either "http" or "https". */ + private final String scheme; + + /** Canonical hostname. */ + private final String host; + + /** Either 80, 443 or a user-specified port. In range [1..65535]. */ + private final int port; + + /** Canonical URL. */ + private final String url; + + private HttpUrl(Builder builder) { + this.scheme = builder.scheme; + this.host = builder.host; + this.port = builder.effectivePort(); + this.url = builder.toString(); + } + + /** Returns either "http" or "https". */ + public String scheme() { + return scheme; + } + + public boolean isHttps() { + return scheme.equals("https"); + } + + /** + * Returns the host address suitable for use with {@link InetAddress#getAllByName(String)}. May + * be: + *

      + *
    • A regular host name, like {@code android.com}. + *
    • An IPv4 address, like {@code 127.0.0.1}. + *
    • An IPv6 address, like {@code ::1}. Note that there are no square braces. + *
    • An encoded IDN, like {@code xn--n3h.net}. + *
    + */ + public String host() { + return host; + } + + /** + * Returns the explicitly-specified port if one was provided, or the default port for this URL's + * scheme. For example, this returns 8443 for {@code https://square.com:8443/} and 443 for {@code + * https://square.com/}. The result is in {@code [1..65535]}. + */ + public int port() { + return port; + } + + /** + * Returns 80 if {@code scheme.equals("http")}, 443 if {@code scheme.equals("https")} and -1 + * otherwise. + */ + public static int defaultPort(String scheme) { + if (scheme.equals("http")) { + return 80; + } else if (scheme.equals("https")) { + return 443; + } else { + return -1; + } + } + + public Builder newBuilder() { + Builder result = new Builder(); + result.scheme = scheme; + result.host = host; + // If we're set to a default port, unset it in case of a scheme change. + result.port = port != defaultPort(scheme) ? port : -1; + return result; + } + + @Override public boolean equals(Object o) { + return o instanceof HttpUrl && ((HttpUrl) o).url.equals(url); + } + + @Override public int hashCode() { + return url.hashCode(); + } + + @Override public String toString() { + return url; + } + + public static final class Builder { + String scheme; + String host; + int port = -1; + + public Builder() { + } + + public Builder scheme(String scheme) { + if (scheme == null) { + throw new IllegalArgumentException("scheme == null"); + } else if (scheme.equalsIgnoreCase("http")) { + this.scheme = "http"; + } else if (scheme.equalsIgnoreCase("https")) { + this.scheme = "https"; + } else { + throw new IllegalArgumentException("unexpected scheme: " + scheme); + } + return this; + } + + /** + * @param host either a regular hostname, International Domain Name, IPv4 address, or IPv6 + * address. + */ + public Builder host(String host) { + if (host == null) throw new IllegalArgumentException("host == null"); + String encoded = canonicalizeHost(host, 0, host.length()); + if (encoded == null) throw new IllegalArgumentException("unexpected host: " + host); + this.host = encoded; + return this; + } + + public Builder port(int port) { + if (port <= 0 || port > 65535) throw new IllegalArgumentException("unexpected port: " + port); + this.port = port; + return this; + } + + int effectivePort() { + return port != -1 ? port : defaultPort(scheme); + } + + public HttpUrl build() { + if (scheme == null) throw new IllegalStateException("scheme == null"); + if (host == null) throw new IllegalStateException("host == null"); + return new HttpUrl(this); + } + + @Override public String toString() { + StringBuilder result = new StringBuilder(); + result.append(scheme); + result.append("://"); + + if (host.indexOf(':') != -1) { + // Host is an IPv6 address. + result.append('['); + result.append(host); + result.append(']'); + } else { + result.append(host); + } + + int effectivePort = effectivePort(); + if (effectivePort != defaultPort(scheme)) { + result.append(':'); + result.append(effectivePort); + } + + return result.toString(); + } + + + private static String canonicalizeHost(String input, int pos, int limit) { + // Start by percent decoding the host. The WHATWG spec suggests doing this only after we've + // checked for IPv6 square braces. But Chrome does it first, and that's more lenient. + String percentDecoded = percentDecode(input, pos, limit, false); + + // If the input is encased in square braces "[...]", drop 'em. We have an IPv6 address. + if (percentDecoded.startsWith("[") && percentDecoded.endsWith("]")) { + InetAddress inetAddress = decodeIpv6(percentDecoded, 1, percentDecoded.length() - 1); + if (inetAddress == null) return null; + byte[] address = inetAddress.getAddress(); + if (address.length == 16) return inet6AddressToAscii(address); + throw new AssertionError(); + } + + return domainToAscii(percentDecoded); + } + + /** Decodes an IPv6 address like 1111:2222:3333:4444:5555:6666:7777:8888 or ::1. */ + private static InetAddress decodeIpv6(String input, int pos, int limit) { + byte[] address = new byte[16]; + int b = 0; + int compress = -1; + int groupOffset = -1; + + for (int i = pos; i < limit; ) { + if (b == address.length) return null; // Too many groups. + + // Read a delimiter. + if (i + 2 <= limit && input.regionMatches(i, "::", 0, 2)) { + // Compression "::" delimiter, which is anywhere in the input, including its prefix. + if (compress != -1) return null; // Multiple "::" delimiters. + i += 2; + b += 2; + compress = b; + if (i == limit) break; + } else if (b != 0) { + // Group separator ":" delimiter. + if (input.regionMatches(i, ":", 0, 1)) { + i++; + } else if (input.regionMatches(i, ".", 0, 1)) { + // If we see a '.', rewind to the beginning of the previous group and parse as IPv4. + if (!decodeIpv4Suffix(input, groupOffset, limit, address, b - 2)) return null; + b += 2; // We rewound two bytes and then added four. + break; + } else { + return null; // Wrong delimiter. + } + } + + // Read a group, one to four hex digits. + int value = 0; + groupOffset = i; + for (; i < limit; i++) { + char c = input.charAt(i); + int hexDigit = decodeHexDigit(c); + if (hexDigit == -1) break; + value = (value << 4) + hexDigit; + } + int groupLength = i - groupOffset; + if (groupLength == 0 || groupLength > 4) return null; // Group is the wrong size. + + // We've successfully read a group. Assign its value to our byte array. + address[b++] = (byte) ((value >>> 8) & 0xff); + address[b++] = (byte) (value & 0xff); + } + + // All done. If compression happened, we need to move bytes to the right place in the + // address. Here's a sample: + // + // input: "1111:2222:3333::7777:8888" + // before: { 11, 11, 22, 22, 33, 33, 00, 00, 77, 77, 88, 88, 00, 00, 00, 00 } + // compress: 6 + // b: 10 + // after: { 11, 11, 22, 22, 33, 33, 00, 00, 00, 00, 00, 00, 77, 77, 88, 88 } + // + if (b != address.length) { + if (compress == -1) return null; // Address didn't have compression or enough groups. + System.arraycopy(address, compress, address, address.length - (b - compress), b - compress); + Arrays.fill(address, compress, compress + (address.length - b), (byte) 0); + } + + try { + return InetAddress.getByAddress(address); + } catch (UnknownHostException e) { + throw new AssertionError(); + } + } + + /** Decodes an IPv4 address suffix of an IPv6 address, like 1111::5555:6666:192.168.0.1. */ + private static boolean decodeIpv4Suffix( + String input, int pos, int limit, byte[] address, int addressOffset) { + int b = addressOffset; + + for (int i = pos; i < limit; ) { + if (b == address.length) return false; // Too many groups. + + // Read a delimiter. + if (b != addressOffset) { + if (input.charAt(i) != '.') return false; // Wrong delimiter. + i++; + } + + // Read 1 or more decimal digits for a value in 0..255. + int value = 0; + int groupOffset = i; + for (; i < limit; i++) { + char c = input.charAt(i); + if (c < '0' || c > '9') break; + if (value == 0 && groupOffset != i) return false; // Reject unnecessary leading '0's. + value = (value * 10) + c - '0'; + if (value > 255) return false; // Value out of range. + } + int groupLength = i - groupOffset; + if (groupLength == 0) return false; // No digits. + + // We've successfully read a byte. + address[b++] = (byte) value; + } + + if (b != addressOffset + 4) return false; // Too few groups. We wanted exactly four. + return true; // Success. + } + + /** + * Performs IDN ToASCII encoding and canonicalize the result to lowercase. e.g. This converts + * {@code ☃.net} to {@code xn--n3h.net}, and {@code WwW.GoOgLe.cOm} to {@code www.google.com}. + * {@code null} will be returned if the input cannot be ToASCII encoded or if the result + * contains unsupported ASCII characters. + */ + private static String domainToAscii(String input) { + try { + String result = IDN.toASCII(input).toLowerCase(Locale.US); + if (result.isEmpty()) return null; + + // Confirm that the IDN ToASCII result doesn't contain any illegal characters. + if (containsInvalidHostnameAsciiCodes(result)) { + return null; + } + // TODO: implement all label limits. + return result; + } catch (IllegalArgumentException e) { + return null; + } + } + + private static boolean containsInvalidHostnameAsciiCodes(String hostnameAscii) { + for (int i = 0; i < hostnameAscii.length(); i++) { + char c = hostnameAscii.charAt(i); + // The WHATWG Host parsing rules accepts some character codes which are invalid by + // definition for OkHttp's host header checks (and the WHATWG Host syntax definition). Here + // we rule out characters that would cause problems in host headers. + if (c <= '\u001f' || c >= '\u007f') { + return true; + } + // Check for the characters mentioned in the WHATWG Host parsing spec: + // U+0000, U+0009, U+000A, U+000D, U+0020, "#", "%", "/", ":", "?", "@", "[", "\", and "]" + // (excluding the characters covered above). + if (" #%/:?@[\\]".indexOf(c) != -1) { + return true; + } + } + return false; + } + + private static String inet6AddressToAscii(byte[] address) { + // Go through the address looking for the longest run of 0s. Each group is 2-bytes. + int longestRunOffset = -1; + int longestRunLength = 0; + for (int i = 0; i < address.length; i += 2) { + int currentRunOffset = i; + while (i < 16 && address[i] == 0 && address[i + 1] == 0) { + i += 2; + } + int currentRunLength = i - currentRunOffset; + if (currentRunLength > longestRunLength) { + longestRunOffset = currentRunOffset; + longestRunLength = currentRunLength; + } + } + + // Emit each 2-byte group in hex, separated by ':'. The longest run of zeroes is "::". + Buffer result = new Buffer(); + for (int i = 0; i < address.length; ) { + if (i == longestRunOffset) { + result.writeByte(':'); + i += longestRunLength; + if (i == 16) result.writeByte(':'); + } else { + if (i > 0) result.writeByte(':'); + int group = (address[i] & 0xff) << 8 | (address[i + 1] & 0xff); + result.writeHexadecimalUnsignedLong(group); + i += 2; + } + } + return result.readUtf8(); + } + } + + static String percentDecode(String encoded, int pos, int limit, boolean plusIsSpace) { + for (int i = pos; i < limit; i++) { + char c = encoded.charAt(i); + if (c == '%' || (c == '+' && plusIsSpace)) { + // Slow path: the character at i requires decoding! + Buffer out = new Buffer(); + out.writeUtf8(encoded, pos, i); + percentDecode(out, encoded, i, limit, plusIsSpace); + return out.readUtf8(); + } + } + + // Fast path: no characters in [pos..limit) required decoding. + return encoded.substring(pos, limit); + } + + static void percentDecode(Buffer out, String encoded, int pos, int limit, boolean plusIsSpace) { + int codePoint; + for (int i = pos; i < limit; i += Character.charCount(codePoint)) { + codePoint = encoded.codePointAt(i); + if (codePoint == '%' && i + 2 < limit) { + int d1 = decodeHexDigit(encoded.charAt(i + 1)); + int d2 = decodeHexDigit(encoded.charAt(i + 2)); + if (d1 != -1 && d2 != -1) { + out.writeByte((d1 << 4) + d2); + i += 2; + continue; + } + } else if (codePoint == '+' && plusIsSpace) { + out.writeByte(' '); + continue; + } + out.writeUtf8CodePoint(codePoint); + } + } + + static int decodeHexDigit(char c) { + if (c >= '0' && c <= '9') return c - '0'; + if (c >= 'a' && c <= 'f') return c - 'a' + 10; + if (c >= 'A' && c <= 'F') return c - 'A' + 10; + return -1; + } + + static void canonicalize(Buffer out, String input, int pos, int limit, + String encodeSet, boolean alreadyEncoded, boolean plusIsSpace, boolean asciiOnly) { + Buffer utf8Buffer = null; // Lazily allocated. + int codePoint; + for (int i = pos; i < limit; i += Character.charCount(codePoint)) { + codePoint = input.codePointAt(i); + if (alreadyEncoded + && (codePoint == '\t' || codePoint == '\n' || codePoint == '\f' || codePoint == '\r')) { + // Skip this character. + } else if (codePoint == '+' && plusIsSpace) { + // Encode '+' as '%2B' since we permit ' ' to be encoded as either '+' or '%20'. + out.writeUtf8(alreadyEncoded ? "+" : "%2B"); + } else if (codePoint < 0x20 + || codePoint == 0x7f + || (codePoint >= 0x80 && asciiOnly) + || encodeSet.indexOf(codePoint) != -1 + || (codePoint == '%' && !alreadyEncoded)) { + // Percent encode this character. + if (utf8Buffer == null) { + utf8Buffer = new Buffer(); + } + utf8Buffer.writeUtf8CodePoint(codePoint); + while (!utf8Buffer.exhausted()) { + try { + fakeEofExceptionMethod(); // Okio 2.x can throw EOFException from readByte() + int b = utf8Buffer.readByte() & 0xff; + out.writeByte('%'); + out.writeByte(HEX_DIGITS[(b >> 4) & 0xf]); + out.writeByte(HEX_DIGITS[b & 0xf]); + } catch (EOFException e) { + throw new IndexOutOfBoundsException(e.getMessage()); + } + } + } else { + // This character doesn't need encoding. Just copy it over. + out.writeUtf8CodePoint(codePoint); + } + } + } + + private static void fakeEofExceptionMethod() throws EOFException {} +} diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/Request.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/Request.java new file mode 100644 index 00000000000..aef4f3783d5 --- /dev/null +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/proxy/Request.java @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Forked from OkHttp 2.7.0 com.squareup.okhttp.Request + */ +package io.grpc.okhttp.internal.proxy; + +import io.grpc.okhttp.internal.Headers; + +/** + * An HTTP ProxyRequest. Instances of this class are immutable. + */ +public final class Request { + private final HttpUrl url; + private final Headers headers; + + private Request(Builder builder) { + this.url = builder.url; + this.headers = builder.headers.build(); + } + + public HttpUrl httpUrl() { + return url; + } + + public Headers headers() { + return headers; + } + + public Builder newBuilder() { + return new Builder(); + } + + + @Override public String toString() { + return "Request{" + + "url=" + url + + '}'; + } + + public static class Builder { + private HttpUrl url; + private Headers.Builder headers; + + public Builder() { + this.headers = new Headers.Builder(); + } + + public Builder url(HttpUrl url) { + if (url == null) throw new IllegalArgumentException("url == null"); + this.url = url; + return this; + } + + /** + * Sets the header named {@code name} to {@code value}. If this request + * already has any headers with that name, they are all replaced. + */ + public Builder header(String name, String value) { + headers.set(name, value); + return this; + } + + public Request build() { + if (url == null) throw new IllegalStateException("url == null"); + return new Request(this); + } + } +} diff --git a/protobuf-lite/build.gradle b/protobuf-lite/build.gradle index 7b58309c414..0ada691913c 100644 --- a/protobuf-lite/build.gradle +++ b/protobuf-lite/build.gradle @@ -11,17 +11,17 @@ description = 'gRPC: Protobuf Lite' dependencies { api project(':grpc-api'), - libraries.protobuf_lite + libraries.protobuf.javalite implementation libraries.jsr305, libraries.guava testImplementation project(':grpc-core') - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + signature libraries.signature.java + signature libraries.signature.android } -compileTestJava { +tasks.named("compileTestJava").configure { options.compilerArgs += [ "-Xlint:-cast" ] @@ -33,7 +33,7 @@ protobuf { if (project.hasProperty('protoc')) { path = project.protoc } else { - artifact = "com.google.protobuf:protoc:${protocVersion}" + artifact = libs.protobuf.protoc.get() } } generateProtoTasks { diff --git a/protobuf/build.gradle b/protobuf/build.gradle index bb8546dc701..93b379daf43 100644 --- a/protobuf/build.gradle +++ b/protobuf/build.gradle @@ -12,10 +12,10 @@ description = 'gRPC: Protobuf' dependencies { api project(':grpc-api'), libraries.jsr305, - libraries.protobuf + libraries.protobuf.java implementation libraries.guava - api (libraries.google_api_protos) { + api (libraries.google.api.protos) { // 'com.google.api:api-common' transitively depends on auto-value, which breaks our // annotations. exclude group: 'com.google.api', module: 'api-common' @@ -25,7 +25,10 @@ dependencies { exclude group: 'com.google.protobuf', module: 'protobuf-javalite' } - signature "org.codehaus.mojo.signature:java17:1.0@signature" + signature libraries.signature.java + signature libraries.signature.android } -javadoc.options.links 'https://developers.google.com/protocol-buffers/docs/reference/java/' +tasks.named("javadoc").configure { + options.links 'https://developers.google.com/protocol-buffers/docs/reference/java/' +} diff --git a/repositories.bzl b/repositories.bzl index 8ad5f109e6f..cced1d29bee 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -11,38 +11,41 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # ) IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.android:annotations:4.1.1.4", - "com.google.api.grpc:proto-google-common-protos:2.0.1", + "com.google.api.grpc:proto-google-common-protos:2.9.0", "com.google.auth:google-auth-library-credentials:0.22.0", "com.google.auth:google-auth-library-oauth2-http:0.22.0", + "com.google.auto.value:auto-value-annotations:1.9", + "com.google.auto.value:auto-value:1.9", "com.google.code.findbugs:jsr305:3.0.2", - "com.google.code.gson:gson:jar:2.8.9", - "com.google.auto.value:auto-value:1.7.4", - "com.google.auto.value:auto-value-annotations:1.7.4", + "com.google.code.gson:gson:2.9.0", "com.google.errorprone:error_prone_annotations:2.9.0", "com.google.guava:failureaccess:1.0.1", - "com.google.guava:guava:30.1.1-android", + "com.google.guava:guava:31.0.1-android", "com.google.j2objc:j2objc-annotations:1.3", + "com.google.re2j:re2j:1.6", "com.google.truth:truth:1.0.1", - "com.squareup.okhttp:okhttp:2.7.4", + "com.squareup.okhttp:okhttp:2.7.5", "com.squareup.okio:okio:1.17.5", - "io.netty:netty-buffer:4.1.72.Final", - "io.netty:netty-codec-http2:4.1.72.Final", - "io.netty:netty-codec-http:4.1.72.Final", - "io.netty:netty-codec-socks:4.1.72.Final", - "io.netty:netty-codec:4.1.72.Final", - "io.netty:netty-common:4.1.72.Final", - "io.netty:netty-handler-proxy:4.1.72.Final", - "io.netty:netty-handler:4.1.72.Final", - "io.netty:netty-resolver:4.1.72.Final", - "io.netty:netty-tcnative-boringssl-static:2.0.46.Final", - "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.72.Final", - "io.netty:netty-transport:4.1.72.Final", + "io.netty:netty-buffer:4.1.79.Final", + "io.netty:netty-codec-http2:4.1.79.Final", + "io.netty:netty-codec-http:4.1.79.Final", + "io.netty:netty-codec-socks:4.1.79.Final", + "io.netty:netty-codec:4.1.79.Final", + "io.netty:netty-common:4.1.79.Final", + "io.netty:netty-handler-proxy:4.1.79.Final", + "io.netty:netty-handler:4.1.79.Final", + "io.netty:netty-resolver:4.1.79.Final", + "io.netty:netty-tcnative-boringssl-static:2.0.54.Final", + "io.netty:netty-tcnative-classes:2.0.54.Final", + "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.79.Final", + "io.netty:netty-transport-native-unix-common:4.1.79.Final", + "io.netty:netty-transport:4.1.79.Final", "io.opencensus:opencensus-api:0.24.0", "io.opencensus:opencensus-contrib-grpc-metrics:0.24.0", - "io.perfmark:perfmark-api:0.23.0", + "io.perfmark:perfmark-api:0.25.0", "junit:junit:4.12", "org.apache.tomcat:annotations-api:6.0.53", - "org.codehaus.mojo:animal-sniffer-annotations:1.19", + "org.codehaus.mojo:animal-sniffer-annotations:1.21", ] # For use with maven_install's override_targets. @@ -73,24 +76,73 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "io.grpc:grpc-census": "@io_grpc_grpc_java//census", "io.grpc:grpc-context": "@io_grpc_grpc_java//context", "io.grpc:grpc-core": "@io_grpc_grpc_java//core:core_maven", + "io.grpc:grpc-googleapis": "@io_grpc_grpc_java//googleapis", "io.grpc:grpc-grpclb": "@io_grpc_grpc_java//grpclb", "io.grpc:grpc-netty": "@io_grpc_grpc_java//netty", "io.grpc:grpc-netty-shaded": "@io_grpc_grpc_java//netty:shaded_maven", "io.grpc:grpc-okhttp": "@io_grpc_grpc_java//okhttp", "io.grpc:grpc-protobuf": "@io_grpc_grpc_java//protobuf", "io.grpc:grpc-protobuf-lite": "@io_grpc_grpc_java//protobuf-lite", + "io.grpc:grpc-rls": "@io_grpc_grpc_java//rls", + "io.grpc:grpc-services": "@io_grpc_grpc_java//services:services_maven", "io.grpc:grpc-stub": "@io_grpc_grpc_java//stub", "io.grpc:grpc-testing": "@io_grpc_grpc_java//testing", + "io.grpc:grpc-xds": "@io_grpc_grpc_java//xds:xds_maven", } def grpc_java_repositories(): """Imports dependencies for grpc-java.""" + if not native.existing_rule("com_github_cncf_xds"): + http_archive( + name = "com_github_cncf_xds", + strip_prefix = "xds-d92e9ce0af512a73a3a126b32fa4920bee12e180", + sha256 = "27be88b1ff2844885d3b2d0d579546f3a8b3f26b4871eed89082c9709e49a4bd", + urls = [ + "https://github.com/cncf/xds/archive/d92e9ce0af512a73a3a126b32fa4920bee12e180.tar.gz", + ], + ) + if not native.existing_rule("com_github_grpc_grpc"): + http_archive( + name = "com_github_grpc_grpc", + strip_prefix = "grpc-1.46.0", + sha256 = "67423a4cd706ce16a88d1549297023f0f9f0d695a96dd684adc21e67b021f9bc", + urls = [ + "https://github.com/grpc/grpc/archive/v1.46.0.tar.gz", + ], + ) if not native.existing_rule("com_google_protobuf"): com_google_protobuf() if not native.existing_rule("com_google_protobuf_javalite"): com_google_protobuf_javalite() + if not native.existing_rule("com_google_googleapis"): + http_archive( + name = "com_google_googleapis", + sha256 = "49930468563dd48283e8301e8d4e71436bf6d27ac27c235224cc1a098710835d", + strip_prefix = "googleapis-ca1372c6d7bcb199638ebfdb40d2b2660bab7b88", + urls = [ + "https://github.com/googleapis/googleapis/archive/ca1372c6d7bcb199638ebfdb40d2b2660bab7b88.tar.gz", + ], + ) + if not native.existing_rule("io_bazel_rules_go"): + http_archive( + name = "io_bazel_rules_go", + sha256 = "ab21448cef298740765f33a7f5acee0607203e4ea321219f2a4c85a6e0fb0a27", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", + "https://github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", + ], + ) if not native.existing_rule("io_grpc_grpc_proto"): io_grpc_grpc_proto() + if not native.existing_rule("envoy_api"): + http_archive( + name = "envoy_api", + sha256 = "a0c58442cc2038ccccad9616dd1bab5ff1e65da2bbc0ae41020ef6010119eb0e", + strip_prefix = "data-plane-api-869b00336913138cad96a653458aab650c4e70ea", + urls = [ + "https://github.com/envoyproxy/data-plane-api/archive/869b00336913138cad96a653458aab650c4e70ea.tar.gz", + ], + ) def com_google_protobuf(): # proto_library rules implicitly depend on @com_google_protobuf//:protoc, @@ -98,18 +150,18 @@ def com_google_protobuf(): # This statement defines the @com_google_protobuf repo. http_archive( name = "com_google_protobuf", - sha256 = "9ceef0daf7e8be16cd99ac759271eb08021b53b1c7b6edd399953a76390234cd", - strip_prefix = "protobuf-3.19.2", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.2.zip"], + sha256 = "c72840a5081484c4ac20789ea5bb5d5de6bc7c477ad76e7109fda2bc4e630fe6", + strip_prefix = "protobuf-3.21.7", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.7.zip"], ) def com_google_protobuf_javalite(): # java_lite_proto_library rules implicitly depend on @com_google_protobuf_javalite http_archive( name = "com_google_protobuf_javalite", - sha256 = "9ceef0daf7e8be16cd99ac759271eb08021b53b1c7b6edd399953a76390234cd", - strip_prefix = "protobuf-3.19.2", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.2.zip"], + sha256 = "c72840a5081484c4ac20789ea5bb5d5de6bc7c477ad76e7109fda2bc4e630fe6", + strip_prefix = "protobuf-3.21.7", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.7.zip"], ) def io_grpc_grpc_proto(): diff --git a/rls/BUILD.bazel b/rls/BUILD.bazel index 4daa7029560..d4af3569240 100644 --- a/rls/BUILD.bazel +++ b/rls/BUILD.bazel @@ -7,8 +7,8 @@ java_library( ]), visibility = ["//visibility:public"], deps = [ - ":autovalue", ":rls_java_grpc", + "//:auto_value_annotations", "//api", "//core", "//core:internal", @@ -17,27 +17,8 @@ java_library( "@com_google_auto_value_auto_value_annotations//jar", "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", - "@io_grpc_grpc_proto//:rls_java_proto", "@io_grpc_grpc_proto//:rls_config_java_proto", - ], -) - -java_plugin( - name = "autovalue_plugin", - processor_class = "com.google.auto.value.processor.AutoValueProcessor", - deps = [ - "@com_google_auto_value_auto_value//jar", - ], -) - -java_library( - name = "autovalue", - exported_plugins = [ - ":autovalue_plugin", - ], - neverlink = 1, - exports = [ - "@com_google_auto_value_auto_value//jar", + "@io_grpc_grpc_proto//:rls_java_proto", ], ) diff --git a/rls/build.gradle b/rls/build.gradle index 45f17fb71c3..ddd8cb65870 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -14,19 +14,19 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), - libraries.autovalue_annotation, + libraries.auto.value.annotations, libraries.guava - annotationProcessor libraries.autovalue - compileOnly libraries.javax_annotation + annotationProcessor libraries.auto.value + compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-grpclb'), project(':grpc-testing'), project(':grpc-testing-proto'), project(':grpc-core').sourceSets.test.output // for FakeClock - signature "org.codehaus.mojo.signature:java17:1.0@signature" + signature libraries.signature.java } -[compileJava].each() { +tasks.named("compileJava").configure { it.options.compilerArgs += [ // only has AutoValue annotation processor "-Xlint:-processing", @@ -37,7 +37,7 @@ dependencies { "|") } -javadoc { +tasks.named("javadoc").configure { // Do not publish javadoc since currently there is no public API. failOnError false // no public or protected classes found to document exclude 'io/grpc/lookup/v1/**' @@ -45,7 +45,7 @@ javadoc { exclude 'io/grpc/rls/Internal*' } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { classDirectories.from = sourceSets.main.output.collect { fileTree(dir: it, exclude: ['**/io/grpc/lookup/**']) } diff --git a/rls/src/main/java/io/grpc/rls/AdaptiveThrottler.java b/rls/src/main/java/io/grpc/rls/AdaptiveThrottler.java index 55f3f72453e..576b234b0c3 100644 --- a/rls/src/main/java/io/grpc/rls/AdaptiveThrottler.java +++ b/rls/src/main/java/io/grpc/rls/AdaptiveThrottler.java @@ -21,7 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import io.grpc.internal.TimeProvider; +import com.google.common.base.Ticker; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLongFieldUpdater; @@ -44,7 +44,7 @@ final class AdaptiveThrottler implements Throttler { private static final int DEFAULT_HISTORY_SECONDS = 30; private static final int DEFAULT_REQUEST_PADDING = 8; - private static final float DEFAULT_RATIO_FOR_ACCEPT = 1.2f; + private static final float DEFAULT_RATIO_FOR_ACCEPT = 2.0f; /** * The duration of history of calls used by Adaptive Throttler. @@ -60,7 +60,7 @@ final class AdaptiveThrottler implements Throttler { * is currently accepting. */ private final float ratioForAccepts; - private final TimeProvider timeProvider; + private final Ticker ticker; /** * The number of requests attempted by the client during the Adaptive Throttler instance's * history of calls. This includes requests throttled at the client. The history period defaults @@ -79,10 +79,10 @@ private AdaptiveThrottler(Builder builder) { this.historySeconds = builder.historySeconds; this.requestsPadding = builder.requestsPadding; this.ratioForAccepts = builder.ratioForAccepts; - this.timeProvider = builder.timeProvider; + this.ticker = builder.ticker; long internalNanos = TimeUnit.SECONDS.toNanos(historySeconds); - this.requestStat = new TimeBasedAccumulator(internalNanos, timeProvider); - this.throttledStat = new TimeBasedAccumulator(internalNanos, timeProvider); + this.requestStat = new TimeBasedAccumulator(internalNanos, ticker); + this.throttledStat = new TimeBasedAccumulator(internalNanos, ticker); } @Override @@ -92,7 +92,7 @@ public boolean shouldThrottle() { @VisibleForTesting boolean shouldThrottle(float random) { - long nowNanos = timeProvider.currentTimeNanos(); + long nowNanos = ticker.read(); if (getThrottleProbability(nowNanos) <= random) { return false; } @@ -118,7 +118,7 @@ float getThrottleProbability(long nowNanos) { @Override public void registerBackendResponse(boolean throttled) { - long now = timeProvider.currentTimeNanos(); + long now = ticker.read(); requestStat.increment(now); if (throttled) { throttledStat.increment(now); @@ -150,7 +150,7 @@ static final class Builder { private float ratioForAccepts = DEFAULT_RATIO_FOR_ACCEPT; private int historySeconds = DEFAULT_HISTORY_SECONDS; private int requestsPadding = DEFAULT_REQUEST_PADDING; - private TimeProvider timeProvider = TimeProvider.SYSTEM_TIME_PROVIDER; + private Ticker ticker = Ticker.systemTicker(); public Builder setRatioForAccepts(float ratioForAccepts) { this.ratioForAccepts = ratioForAccepts; @@ -167,8 +167,8 @@ public Builder setRequestsPadding(int requestsPadding) { return this; } - public Builder setTimeProvider(TimeProvider timeProvider) { - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); + public Builder setTicker(Ticker ticker) { + this.ticker = checkNotNull(ticker, "ticker"); return this; } @@ -205,9 +205,6 @@ void increment() { } } - // Represents a slot which is not initialized and is unusable. - private static final Slot NULL_SLOT = new Slot(-1); - /** The array of slots. */ private final AtomicReferenceArray slots = new AtomicReferenceArray<>(NUM_SLOTS); @@ -224,7 +221,7 @@ void increment() { */ private volatile int currentIndex; - private final TimeProvider timeProvider; + private final Ticker ticker; /** * Interval constructor. @@ -232,7 +229,7 @@ void increment() { * @param internalNanos is the stat interval in nanoseconds * @throws IllegalArgumentException if the supplied interval is too small to be effective */ - TimeBasedAccumulator(long internalNanos, TimeProvider timeProvider) { + TimeBasedAccumulator(long internalNanos, Ticker ticker) { checkArgument( internalNanos >= NUM_SLOTS, "Interval must be greater than %s", @@ -240,30 +237,27 @@ void increment() { this.interval = internalNanos; this.slotNanos = internalNanos / NUM_SLOTS; this.currentIndex = 0; - for (int i = 0; i < NUM_SLOTS; i++) { - slots.set(i, NULL_SLOT); - } - this.timeProvider = checkNotNull(timeProvider, "ticker"); + this.ticker = checkNotNull(ticker, "ticker"); } /** Gets the current slot. */ private Slot getSlot(long now) { Slot currentSlot = slots.get(currentIndex); - if (now < currentSlot.endNanos) { + if (currentSlot != null && now - currentSlot.endNanos < 0) { return currentSlot; } else { long slotBoundary = getSlotEndTime(now); synchronized (this) { int index = currentIndex; currentSlot = slots.get(index); - if (now < currentSlot.endNanos) { + if (currentSlot != null && now - currentSlot.endNanos < 0) { return currentSlot; } int newIndex = (index == NUM_SLOTS - 1) ? 0 : index + 1; Slot nextSlot = new Slot(slotBoundary); slots.set(newIndex, nextSlot); // Set currentIndex only after assigning the new slot to slots, otherwise - // racing readers will see NULL_SLOT or an old slot. + // racing readers will see null or an old slot. currentIndex = newIndex; return nextSlot; } @@ -294,7 +288,7 @@ long getInterval() { * * @param now is the time used to increment the count */ - final void increment(long now) { + void increment(long now) { getSlot(now).increment(); } @@ -304,28 +298,33 @@ final void increment(long now) { * @param now the current time * @return the statistic count */ - final long get(long now) { + long get(long now) { long intervalEnd = getSlotEndTime(now); long intervalStart = intervalEnd - interval; // This is the point at which increments to new slots will be ignored. int index = currentIndex; long accumulated = 0L; - long prevSlotEnd = Long.MAX_VALUE; + Long prevSlotEnd = null; for (int i = 0; i < NUM_SLOTS; i++) { if (index < 0) { index = NUM_SLOTS - 1; } Slot currentSlot = slots.get(index); index--; + if (currentSlot == null) { + continue; + } + long currentSlotEnd = currentSlot.endNanos; - if (currentSlotEnd <= intervalStart || currentSlotEnd > prevSlotEnd) { + if (currentSlotEnd - intervalStart <= 0 + || (prevSlotEnd != null && currentSlotEnd - prevSlotEnd > 0)) { break; } prevSlotEnd = currentSlotEnd; - if (currentSlotEnd > intervalEnd) { + if (currentSlotEnd - intervalEnd > 0) { continue; } accumulated = accumulated + currentSlot.count; @@ -337,7 +336,7 @@ final long get(long now) { public String toString() { return MoreObjects.toStringHelper(this) .add("interval", interval) - .add("current_count", get(timeProvider.currentTimeNanos())) + .add("current_count", get(ticker.read())) .toString(); } } diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index da388be0503..e8595961613 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -23,6 +23,7 @@ import com.google.common.base.Converter; import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.ChannelLogger; @@ -41,7 +42,6 @@ import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.TimeProvider; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.lookup.v1.RouteLookupServiceGrpc.RouteLookupServiceStub; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; @@ -60,6 +60,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -80,19 +81,24 @@ final class CachingRlsLbClient { REQUEST_CONVERTER = new RlsProtoConverters.RouteLookupRequestConverter().reverse(); private static final Converter RESPONSE_CONVERTER = new RouteLookupResponseConverter().reverse(); + public static final long MIN_EVICTION_TIME_DELTA_NANOS = TimeUnit.SECONDS.toNanos(5); + public static final int BYTES_PER_CHAR = 2; + public static final int STRING_OVERHEAD_BYTES = 38; + /** Minimum bytes for a Java Object. */ + public static final int OBJ_OVERHEAD_B = 16; // All cache status changes (pending, backoff, success) must be under this lock private final Object lock = new Object(); // LRU cache based on access order (BACKOFF and actual data will be here) @GuardedBy("lock") - private final LinkedHashLruCache linkedHashLruCache; + private final RlsAsyncLruCache linkedHashLruCache; // any RPC on the fly will cached in this map @GuardedBy("lock") private final Map pendingCallCache = new HashMap<>(); private final SynchronizationContext synchronizationContext; private final ScheduledExecutorService scheduledExecutorService; - private final TimeProvider timeProvider; + private final Ticker ticker; private final Throttler throttler; private final LbPolicyConfiguration lbPolicyConfig; @@ -118,14 +124,14 @@ private CachingRlsLbClient(Builder builder) { maxAgeNanos = rlsConfig.maxAgeInNanos(); staleAgeNanos = rlsConfig.staleAgeInNanos(); callTimeoutNanos = rlsConfig.lookupServiceTimeoutInNanos(); - timeProvider = checkNotNull(builder.timeProvider, "timeProvider"); + ticker = checkNotNull(builder.ticker, "ticker"); throttler = checkNotNull(builder.throttler, "throttler"); linkedHashLruCache = new RlsAsyncLruCache( rlsConfig.cacheSizeBytes(), builder.evictionListener, scheduledExecutorService, - timeProvider, + ticker, lock); logger = helper.getChannelLogger(); String serverHost = null; @@ -175,6 +181,19 @@ private CachingRlsLbClient(Builder builder) { logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient created"); } + /** + * Convert the status to UNAVAILBLE and enhance the error message. + * @param status status as provided by server + * @param serverName Used for error description + * @return Transformed status + */ + static Status convertRlsServerStatus(Status status, String serverName) { + return Status.UNAVAILABLE.withCause(status.getCause()).withDescription( + String.format("Unable to retrieve RLS targets from RLS server %s. " + + "RLS server returned: %s: %s", + serverName, status.getCode(), status.getDescription())); + } + @CheckReturnValue private ListenableFuture asyncRlsCall(RouteLookupRequest request) { final SettableFuture response = SettableFuture.create(); @@ -199,13 +218,13 @@ public void onNext(io.grpc.lookup.v1.RouteLookupResponse value) { public void onError(Throwable t) { logger.log(ChannelLogLevel.DEBUG, "Error looking up route:", t); response.setException(t); - throttler.registerBackendResponse(false); + throttler.registerBackendResponse(true); helper.propagateRlsError(); } @Override public void onCompleted() { - throttler.registerBackendResponse(true); + throttler.registerBackendResponse(false); } }); return response; @@ -229,7 +248,7 @@ final CachedRouteLookupResponse get(final RouteLookupRequest request) { // cache hit, initiate async-refresh if entry is staled logger.log(ChannelLogLevel.DEBUG, "Cache hit for the request"); DataCacheEntry dataEntry = ((DataCacheEntry) cacheEntry); - if (dataEntry.isStaled(timeProvider.currentTimeNanos())) { + if (dataEntry.isStaled(ticker.read())) { dataEntry.maybeRefresh(); } return CachedRouteLookupResponse.dataEntry((DataCacheEntry) cacheEntry); @@ -273,12 +292,12 @@ private CachedRouteLookupResponse handleNewRequest(RouteLookupRequest request) { try { RouteLookupResponse response = asyncCall.get(); DataCacheEntry dataEntry = new DataCacheEntry(request, response); - linkedHashLruCache.cache(request, dataEntry); + linkedHashLruCache.cacheAndClean(request, dataEntry); return CachedRouteLookupResponse.dataEntry(dataEntry); } catch (Exception e) { BackoffCacheEntry backoffEntry = new BackoffCacheEntry(request, Status.fromThrowable(e), backoffProvider.get()); - linkedHashLruCache.cache(request, backoffEntry); + linkedHashLruCache.cacheAndClean(request, backoffEntry); return CachedRouteLookupResponse.backoffEntry(backoffEntry); } } @@ -322,6 +341,10 @@ public void run() { } }); } + + void triggerPendingRpcProcessing() { + super.updateBalancingState(state, picker); + } } /** @@ -372,6 +395,15 @@ ChildPolicyWrapper getChildPolicyWrapper() { return dataCacheEntry.getChildPolicyWrapper(); } + @VisibleForTesting + @Nullable + ChildPolicyWrapper getChildPolicyWrapper(String target) { + if (!hasData()) { + return null; + } + return dataCacheEntry.getChildPolicyWrapper(target); + } + @Nullable String getHeaderData() { if (!hasData()) { @@ -465,14 +497,15 @@ private void transitionToDataEntry(RouteLookupResponse routeLookupResponse) { ChannelLogLevel.DEBUG, "Transition to data cache: routeLookupResponse={0}", routeLookupResponse); - linkedHashLruCache.cache(request, new DataCacheEntry(request, routeLookupResponse)); + linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, routeLookupResponse)); } } private void transitionToBackOff(Status status) { synchronized (lock) { logger.log(ChannelLogLevel.DEBUG, "Transition to back off: status={0}", status); - linkedHashLruCache.cache(request, new BackoffCacheEntry(request, status, backoffPolicy)); + linkedHashLruCache.cacheAndClean(request, + new BackoffCacheEntry(request, status, backoffPolicy)); } } @@ -496,30 +529,40 @@ abstract class CacheEntry { abstract int getSizeBytes(); final boolean isExpired() { - return isExpired(timeProvider.currentTimeNanos()); + return isExpired(ticker.read()); } abstract boolean isExpired(long now); abstract void cleanup(); + + protected long getMinEvictionTime() { + return 0L; + } + + protected void triggerPendingRpcProcessing() { + helper.triggerPendingRpcProcessing(); + } } /** Implementation of {@link CacheEntry} contains valid data. */ final class DataCacheEntry extends CacheEntry { private final RouteLookupResponse response; + private final long minEvictionTime; private final long expireTime; private final long staleTime; - private final ChildPolicyWrapper childPolicyWrapper; + private final List childPolicyWrappers; // GuardedBy CachingRlsLbClient.lock DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { super(request); this.response = checkNotNull(response, "response"); - // TODO(creamsoup) fallback to other targets if first one is not available - childPolicyWrapper = + checkState(!response.targets().isEmpty(), "No targets returned by RLS"); + childPolicyWrappers = refCountedChildPolicyWrapperFactory - .createOrGet(response.targets().get(0)); - long now = timeProvider.currentTimeNanos(); + .createOrGet(response.targets()); + long now = ticker.read(); + minEvictionTime = now + MIN_EVICTION_TIME_DELTA_NANOS; expireTime = now + maxAgeNanos; staleTime = now + staleAgeNanos; } @@ -551,47 +594,78 @@ void maybeRefresh() { // async call returned finished future is most likely throttled try { RouteLookupResponse response = asyncCall.get(); - linkedHashLruCache.cache(request, new DataCacheEntry(request, response)); + linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, response)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (Exception e) { BackoffCacheEntry backoffEntry = new BackoffCacheEntry(request, Status.fromThrowable(e), backoffProvider.get()); - linkedHashLruCache.cache(request, backoffEntry); + linkedHashLruCache.cacheAndClean(request, backoffEntry); } } } } + @VisibleForTesting + ChildPolicyWrapper getChildPolicyWrapper(String target) { + for (ChildPolicyWrapper childPolicyWrapper : childPolicyWrappers) { + if (childPolicyWrapper.getTarget().equals(target)) { + return childPolicyWrapper; + } + } + + throw new RuntimeException("Target not found:" + target); + } + @Nullable ChildPolicyWrapper getChildPolicyWrapper() { - return childPolicyWrapper; + for (ChildPolicyWrapper childPolicyWrapper : childPolicyWrappers) { + if (childPolicyWrapper.getState() != ConnectivityState.TRANSIENT_FAILURE) { + return childPolicyWrapper; + } + } + return childPolicyWrappers.get(0); } String getHeaderData() { return response.getHeaderData(); } + // Assume UTF-16 (2 bytes) and overhead of a String object is 38 bytes + int calcStringSize(String target) { + return target.length() * BYTES_PER_CHAR + STRING_OVERHEAD_BYTES; + } + @Override int getSizeBytes() { - // size of strings and java object overhead, actual memory usage is more than this. - return - (response.targets().get(0).length() + response.getHeaderData().length()) * 2 + 38 * 2; + int targetSize = 0; + for (String target : response.targets()) { + targetSize += calcStringSize(target); + } + return targetSize + calcStringSize(response.getHeaderData()) + OBJ_OVERHEAD_B // response size + + Long.SIZE * 2 + OBJ_OVERHEAD_B; // Other fields } @Override boolean isExpired(long now) { - return expireTime <= now; + return expireTime - now <= 0; } boolean isStaled(long now) { - return staleTime <= now; + return staleTime - now <= 0; + } + + @Override + protected long getMinEvictionTime() { + return minEvictionTime; } @Override void cleanup() { synchronized (lock) { - refCountedChildPolicyWrapperFactory.release(childPolicyWrapper); + for (ChildPolicyWrapper policyWrapper : childPolicyWrappers) { + refCountedChildPolicyWrapperFactory.release(policyWrapper); + } } } @@ -602,7 +676,7 @@ public String toString() { .add("response", response) .add("expireTime", expireTime) .add("staleTime", staleTime) - .add("childPolicyWrapper", childPolicyWrapper) + .add("childPolicyWrappers", childPolicyWrappers) .toString(); } } @@ -624,7 +698,7 @@ private final class BackoffCacheEntry extends CacheEntry { this.status = checkNotNull(status, "status"); this.backoffPolicy = checkNotNull(backoffPolicy, "backoffPolicy"); long delayNanos = backoffPolicy.nextBackoffNanos(); - this.expireNanos = timeProvider.currentTimeNanos() + delayNanos; + this.expireNanos = ticker.read() + delayNanos; this.scheduledHandle = synchronizationContext.schedule( new Runnable() { @@ -659,11 +733,11 @@ private void transitionToPending() { } else { try { RouteLookupResponse response = call.get(); - linkedHashLruCache.cache(request, new DataCacheEntry(request, response)); + linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, response)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (Exception e) { - linkedHashLruCache.cache( + linkedHashLruCache.cacheAndClean( request, new BackoffCacheEntry(request, Status.fromThrowable(e), backoffPolicy)); } @@ -677,12 +751,12 @@ Status getStatus() { @Override int getSizeBytes() { - return 0; + return OBJ_OVERHEAD_B * 3 + Long.SIZE + 8; // 3 java objects, 1 long and a boolean } @Override boolean isExpired(long now) { - return expireNanos <= now; + return expireNanos - now <= 0; } @Override @@ -717,7 +791,7 @@ static final class Builder { private LbPolicyConfiguration lbPolicyConfig; private Throttler throttler = new HappyThrottler(); private ResolvedAddressFactory resolvedAddressFactory; - private TimeProvider timeProvider = TimeProvider.SYSTEM_TIME_PROVIDER; + private Ticker ticker = Ticker.systemTicker(); private EvictionListener evictionListener; private BackoffPolicy.Provider backoffProvider = new ExponentialBackoffPolicy.Provider(); @@ -746,8 +820,8 @@ Builder setResolvedAddressesFactory( return this; } - Builder setTimeProvider(TimeProvider timeProvider) { - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); + Builder setTicker(Ticker ticker) { + this.ticker = checkNotNull(ticker, "ticker"); return this; } @@ -811,14 +885,14 @@ private static final class RlsAsyncLruCache RlsAsyncLruCache(long maxEstimatedSizeBytes, @Nullable EvictionListener evictionListener, - ScheduledExecutorService ses, TimeProvider timeProvider, Object lock) { + ScheduledExecutorService ses, Ticker ticker, Object lock) { super( maxEstimatedSizeBytes, new AutoCleaningEvictionListener(evictionListener), 1, TimeUnit.MINUTES, ses, - timeProvider, + ticker, lock); } @@ -835,8 +909,22 @@ protected int estimateSizeOf(RouteLookupRequest key, CacheEntry value) { @Override protected boolean shouldInvalidateEldestEntry( RouteLookupRequest eldestKey, CacheEntry eldestValue) { + if (eldestValue.getMinEvictionTime() > now()) { + return false; + } + // eldest entry should be evicted if size limit exceeded - return true; + return this.estimatedSizeBytes() > this.estimatedMaxSizeBytes(); + } + + public CacheEntry cacheAndClean(RouteLookupRequest key, CacheEntry value) { + CacheEntry newEntry = cache(key, value); + + // force cleanup if new entry pushed cache over max size (in bytes) + if (fitToLimit()) { + value.triggerPendingRpcProcessing(); + } + return newEntry; } } @@ -898,23 +986,20 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { boolean hasFallback = defaultTarget != null && !defaultTarget.isEmpty(); if (response.hasData()) { ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); - SubchannelPicker picker = childPolicyWrapper.getPicker(); + SubchannelPicker picker = + (childPolicyWrapper != null) ? childPolicyWrapper.getPicker() : null; if (picker == null) { return PickResult.withNoResult(); } - PickResult result = picker.pickSubchannel(args); - if (result.getStatus().isOk()) { - return result; - } - if (hasFallback) { - return useFallback(args); - } - return PickResult.withError(result.getStatus()); + // Happy path + return picker.pickSubchannel(args); } else if (response.hasError()) { if (hasFallback) { return useFallback(args); } - return PickResult.withError(response.getStatus()); + return PickResult.withError( + convertRlsServerStatus(response.getStatus(), + lbPolicyConfig.getRouteLookupConfig().lookupService())); } else { return PickResult.withNoResult(); } @@ -958,4 +1043,5 @@ public String toString() { .toString(); } } + } diff --git a/rls/src/main/java/io/grpc/rls/ChildLbResolvedAddressFactory.java b/rls/src/main/java/io/grpc/rls/ChildLbResolvedAddressFactory.java index 884447ec878..73b2e7591e4 100644 --- a/rls/src/main/java/io/grpc/rls/ChildLbResolvedAddressFactory.java +++ b/rls/src/main/java/io/grpc/rls/ChildLbResolvedAddressFactory.java @@ -16,7 +16,6 @@ package io.grpc.rls; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.Attributes; @@ -33,8 +32,7 @@ final class ChildLbResolvedAddressFactory implements ResolvedAddressFactory { ChildLbResolvedAddressFactory( List addresses, Attributes attributes) { - checkArgument(addresses != null && !addresses.isEmpty(), "Address must be provided"); - this.addresses = Collections.unmodifiableList(addresses); + this.addresses = Collections.unmodifiableList(checkNotNull(addresses, "addresses")); this.attributes = checkNotNull(attributes, "attributes"); } diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index f40cde9fd81..0abef014f8a 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -247,6 +247,15 @@ ChildPolicyWrapper createOrGet(String target) { } } + // GuardedBy CachingRlsLbClient.lock + List createOrGet(List targets) { + List retVal = new ArrayList<>(); + for (String target : targets) { + retVal.add(createOrGet(target)); + } + return retVal; + } + // GuardedBy CachingRlsLbClient.lock void release(ChildPolicyWrapper childPolicyWrapper) { checkNotNull(childPolicyWrapper, "childPolicyWrapper"); @@ -312,6 +321,10 @@ ChildPolicyReportingHelper getHelper() { return helper; } + public ConnectivityState getState() { + return state; + } + void refreshState() { helper.getSynchronizationContext().execute( new Runnable() { diff --git a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java index 9c1a24a0d1e..77376a374f4 100644 --- a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java +++ b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java @@ -21,7 +21,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.base.MoreObjects; -import io.grpc.internal.TimeProvider; +import com.google.common.base.Ticker; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -53,7 +53,7 @@ abstract class LinkedHashLruCache implements LruCache { @GuardedBy("lock") private final LinkedHashMap delegate; private final PeriodicCleaner periodicCleaner; - private final TimeProvider timeProvider; + private final Ticker ticker; private final EvictionListener evictionListener; private final AtomicLong estimatedSizeBytes = new AtomicLong(); private long estimatedMaxSizeBytes; @@ -64,13 +64,13 @@ abstract class LinkedHashLruCache implements LruCache { int cleaningInterval, TimeUnit cleaningIntervalUnit, ScheduledExecutorService ses, - final TimeProvider timeProvider, + final Ticker ticker, Object lock) { checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive"); this.estimatedMaxSizeBytes = estimatedMaxSizeBytes; this.lock = checkNotNull(lock, "lock"); this.evictionListener = new SizeHandlingEvictionListener(evictionListener); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); + this.ticker = checkNotNull(ticker, "ticker"); delegate = new LinkedHashMap( // rough estimate or minimum hashmap default Math.max((int) (estimatedMaxSizeBytes / 1000), 16), @@ -83,7 +83,7 @@ protected boolean removeEldestEntry(Map.Entry eldest) { } // first, remove at most 1 expired entry - boolean removed = cleanupExpiredEntries(1, timeProvider.currentTimeNanos()); + boolean removed = cleanupExpiredEntries(1, ticker.read()); // handles size based eviction if necessary no expired entry boolean shouldRemove = !removed && shouldInvalidateEldestEntry(eldest.getKey(), eldest.getValue().value); @@ -118,6 +118,10 @@ protected int estimateSizeOf(K key, V value) { return 1; } + protected long estimatedMaxSizeBytes() { + return estimatedMaxSizeBytes; + } + /** Updates size for given key if entry exists. It is useful if the cache value is mutated. */ public void updateEntrySize(K key) { synchronized (lock) { @@ -174,7 +178,7 @@ private SizedValue readInternal(K key) { checkNotNull(key, "key"); synchronized (lock) { SizedValue existing = delegate.get(key); - if (existing != null && isExpired(key, existing.value, timeProvider.currentTimeNanos())) { + if (existing != null && isExpired(key, existing.value, ticker.read())) { invalidate(key, EvictionType.EXPIRED); return null; } @@ -233,30 +237,50 @@ public final List values() { } } + protected long now() { + return ticker.read(); + } + /** - * Resizes cache. If new size is smaller than current estimated size, it will free up space by + * Cleans up cache if needed to fit into max size bytes by * removing expired entries and removing oldest entries by LRU order. + * Returns TRUE if any unexpired entries were removed */ - public final void resize(int newSizeBytes) { - long now = timeProvider.currentTimeNanos(); + protected final boolean fitToLimit() { + boolean removedAnyUnexpired = false; synchronized (lock) { - this.estimatedMaxSizeBytes = newSizeBytes; - if (estimatedSizeBytes.get() <= newSizeBytes) { + if (estimatedSizeBytes.get() <= estimatedMaxSizeBytes) { // new size is larger no need to do cleanup - return; + return false; } // cleanup expired entries - cleanupExpiredEntries(now); + cleanupExpiredEntries(now()); // cleanup eldest entry until new size limit Iterator> lruIter = delegate.entrySet().iterator(); while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes.get()) { Map.Entry entry = lruIter.next(); + if (!shouldInvalidateEldestEntry(entry.getKey(), entry.getValue().value)) { + break; // Violates some constraint like minimum age so stop our cleanup + } lruIter.remove(); // eviction listener will update the estimatedSizeBytes evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); + removedAnyUnexpired = true; } } + return removedAnyUnexpired; + } + + /** + * Resizes cache. If new size is smaller than current estimated size, it will free up space by + * removing expired entries and removing oldest entries by LRU order. + */ + public final void resize(long newSizeBytes) { + synchronized (lock) { + this.estimatedMaxSizeBytes = newSizeBytes; + fitToLimit(); + } } @Override @@ -331,7 +355,7 @@ private class CleaningTask implements Runnable { @Override public void run() { - cleanupExpiredEntries(timeProvider.currentTimeNanos()); + cleanupExpiredEntries(ticker.read()); } } } diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index 2aac96cadcf..5d4e749087d 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -49,7 +49,7 @@ final class RlsLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(ChannelLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); LbPolicyConfiguration lbPolicyConfiguration = (LbPolicyConfiguration) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -78,6 +78,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { // not required. this.lbPolicyConfiguration = lbPolicyConfiguration; } + return true; } @Override diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancerProvider.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancerProvider.java index 54ef848498c..d78f2d67fc0 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancerProvider.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancerProvider.java @@ -73,7 +73,7 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalanc new LbPolicyConfiguration(routeLookupConfig, routeLookupChannelServiceConfig, lbPolicy)); } catch (Exception e) { return ConfigOrError.fromError( - Status.INVALID_ARGUMENT + Status.UNAVAILABLE .withDescription("can't parse config: " + e.getMessage()) .withCause(e)); } diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index 14e3e50bd49..cd164f5e2a7 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -109,20 +109,35 @@ static final class RouteLookupConfigConverter @Override protected RouteLookupConfig doForward(Map json) { - ImmutableList grpcKeyBuilders = + ImmutableList grpcKeybuilders = GrpcKeyBuilderConverter.covertAll( - checkNotNull(JsonUtil.getListOfObjects(json, "grpcKeyBuilders"), "grpcKeyBuilders")); - checkArgument(!grpcKeyBuilders.isEmpty(), "must have at least one GrpcKeyBuilder"); + checkNotNull(JsonUtil.getListOfObjects(json, "grpcKeybuilders"), "grpcKeybuilders")); + + // Validate grpc_keybuilders + checkArgument(!grpcKeybuilders.isEmpty(), "must have at least one GrpcKeyBuilder"); Set names = new HashSet<>(); - for (GrpcKeyBuilder keyBuilder : grpcKeyBuilders) { + for (GrpcKeyBuilder keyBuilder : grpcKeybuilders) { for (Name name : keyBuilder.names()) { checkArgument(names.add(name), "duplicate names in grpc_keybuilders: " + name); } + + Set keys = new HashSet<>(); + for (NameMatcher header : keyBuilder.headers()) { + checkKeys(keys, header.key(), "header"); + } + for (String key : keyBuilder.constantKeys().keySet()) { + checkKeys(keys, key, "constant"); + } + String extraKeyStr = keyToString(keyBuilder.extraKeys()); + checkArgument(keys.add(extraKeyStr), + "duplicate extra key in grpc_keybuilders: " + extraKeyStr); } + + // Validate lookup_service String lookupService = JsonUtil.getString(json, "lookupService"); checkArgument(!Strings.isNullOrEmpty(lookupService), "lookupService must not be empty"); try { - new URI(lookupService); + URI unused = new URI(lookupService); } catch (URISyntaxException e) { throw new IllegalArgumentException( "The lookupService field is not valid URI: " + lookupService, e); @@ -147,7 +162,7 @@ protected RouteLookupConfig doForward(Map json) { cacheSize = Math.min(cacheSize, MAX_CACHE_SIZE); String defaultTarget = Strings.emptyToNull(JsonUtil.getString(json, "defaultTarget")); return RouteLookupConfig.builder() - .grpcKeyBuilders(grpcKeyBuilders) + .grpcKeybuilders(grpcKeybuilders) .lookupService(lookupService) .lookupServiceTimeoutInNanos(timeout) .maxAgeInNanos(maxAge) @@ -157,6 +172,11 @@ protected RouteLookupConfig doForward(Map json) { .build(); } + private static String keyToString(ExtraKeys extraKeys) { + return String.format("host: %s, service: %s, method: %s", + extraKeys.host(), extraKeys.service(), extraKeys.method()); + } + private static T orDefault(@Nullable T value, T defaultValue) { if (value == null) { return checkNotNull(defaultValue, "defaultValue"); @@ -170,6 +190,12 @@ protected Map doBackward(RouteLookupConfig routeLookupConfig) { } } + private static void checkKeys(Set keys, String key, String keyType) { + checkArgument(key != null, "unset " + keyType + " key"); + checkArgument(!key.isEmpty(), "Empty string for " + keyType + " key"); + checkArgument(keys.add(key), "duplicate " + keyType + " key in grpc_keybuilders: " + key); + } + private static final class GrpcKeyBuilderConverter { public static ImmutableList covertAll(List> keyBuilders) { ImmutableList.Builder keyBuilderList = ImmutableList.builder(); diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoData.java b/rls/src/main/java/io/grpc/rls/RlsProtoData.java index 929b7800759..49f32c6b6e3 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoData.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoData.java @@ -75,7 +75,7 @@ abstract static class RouteLookupConfig { * keyed by name. If no GrpcKeyBuilder matches, an empty key_map will be sent to the lookup * service; it should likely reply with a global default route and raise an alert. */ - abstract ImmutableList grpcKeyBuilders(); + abstract ImmutableList grpcKeybuilders(); /** * Returns the name of the lookup service as a gRPC URI. Typically, this will be a subdomain of @@ -119,7 +119,7 @@ static Builder builder() { @AutoValue.Builder abstract static class Builder { - abstract Builder grpcKeyBuilders(ImmutableList grpcKeyBuilders); + abstract Builder grpcKeybuilders(ImmutableList grpcKeybuilders); abstract Builder lookupService(String lookupService); diff --git a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java index 4a32bece2b2..a6ca0137ff1 100644 --- a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java +++ b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java @@ -50,10 +50,10 @@ final class RlsRequestFactory { private static Map createKeyBuilderTable( RouteLookupConfig config) { Map table = new HashMap<>(); - for (GrpcKeyBuilder grpcKeyBuilder : config.grpcKeyBuilders()) { + for (GrpcKeyBuilder grpcKeyBuilder : config.grpcKeybuilders()) { for (Name name : grpcKeyBuilder.names()) { - boolean hasMethod = name.method() == null || name.method().isEmpty(); - String method = hasMethod ? "*" : name.method(); + boolean noMethod = name.method() == null || name.method().isEmpty(); + String method = noMethod ? "*" : name.method(); String path = "/" + name.service() + "/" + method; table.put(path, grpcKeyBuilder); } @@ -89,7 +89,7 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { rlsRequestHeaders.put(extraKeys.method(), method); } rlsRequestHeaders.putAll(constantKeys); - return RouteLookupRequest.create(rlsRequestHeaders.build()); + return RouteLookupRequest.create(rlsRequestHeaders.buildOrThrow()); } private ImmutableMap.Builder createRequestHeaders( diff --git a/rls/src/test/java/io/grpc/rls/AdaptiveThrottlerTest.java b/rls/src/test/java/io/grpc/rls/AdaptiveThrottlerTest.java index 8da8e9b0210..6852b2479d5 100644 --- a/rls/src/test/java/io/grpc/rls/AdaptiveThrottlerTest.java +++ b/rls/src/test/java/io/grpc/rls/AdaptiveThrottlerTest.java @@ -18,8 +18,8 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.common.base.Ticker; import io.grpc.internal.FakeClock; -import io.grpc.internal.TimeProvider; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,21 +30,23 @@ public class AdaptiveThrottlerTest { private static final float TOLERANCE = 0.0001f; private final FakeClock fakeClock = new FakeClock(); - private final TimeProvider fakeTimeProvider = fakeClock.getTimeProvider(); + private final Ticker fakeTicker = fakeClock.getTicker(); private final AdaptiveThrottler throttler = new AdaptiveThrottler.Builder() .setHistorySeconds(1) .setRatioForAccepts(1.0f) .setRequestsPadding(1) - .setTimeProvider(fakeTimeProvider) + .setTicker(fakeTicker) .build(); @Test public void shouldThrottle() { + long startTime = fakeClock.currentTimeMillis(); + // initial states - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(0L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(0L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())).isEqualTo(0L); + assertThat(throttler.throttledStat.get(fakeTicker.read())).isEqualTo(0L); + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE).of(0.0f); // Request 1, allowed by all. @@ -52,10 +54,10 @@ public void shouldThrottle() { fakeClock.forwardTime(1L, TimeUnit.MILLISECONDS); throttler.registerBackendResponse(false); - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())) .isEqualTo(1L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(0L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.throttledStat.get(fakeTicker.read())).isEqualTo(0L); + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE).of(0.0f); // Request 2, throttled by backend @@ -63,25 +65,26 @@ public void shouldThrottle() { fakeClock.forwardTime(1L, TimeUnit.MILLISECONDS); throttler.registerBackendResponse(true); - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())) .isEqualTo(2L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.throttledStat.get(fakeTicker.read())) .isEqualTo(1L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE) .of(1.0f / 3.0f); - // Skip half a second (half the duration). - fakeClock.forwardTime(500 - fakeClock.currentTimeMillis(), TimeUnit.MILLISECONDS); + // Skip to half second mark from the beginning (half the duration). + fakeClock.forwardTime(500 - (fakeClock.currentTimeMillis() - startTime), + TimeUnit.MILLISECONDS); // Request 3, throttled by backend assertThat(throttler.shouldThrottle(0.4f)).isFalse(); fakeClock.forwardTime(1L, TimeUnit.MILLISECONDS); throttler.registerBackendResponse(true); - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(3L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(2L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())).isEqualTo(3L); + assertThat(throttler.throttledStat.get(fakeTicker.read())).isEqualTo(2L); + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE) .of(2.0f / 4.0f); @@ -89,19 +92,33 @@ public void shouldThrottle() { assertThat(throttler.shouldThrottle(0.4f)).isTrue(); fakeClock.forwardTime(1L, TimeUnit.MILLISECONDS); - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(4L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(3L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())).isEqualTo(4L); + assertThat(throttler.throttledStat.get(fakeTicker.read())).isEqualTo(3L); + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE) .of(3.0f / 5.0f); // Skip to the point where only requests 3 and 4 are visible. - fakeClock.forwardTime(1250 - fakeClock.currentTimeMillis(), TimeUnit.MILLISECONDS); + fakeClock.forwardTime( + 1250 - (fakeClock.currentTimeMillis() - startTime), TimeUnit.MILLISECONDS); - assertThat(throttler.requestStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(2L); - assertThat(throttler.throttledStat.get(fakeTimeProvider.currentTimeNanos())).isEqualTo(2L); - assertThat(throttler.getThrottleProbability(fakeTimeProvider.currentTimeNanos())) + assertThat(throttler.requestStat.get(fakeTicker.read())).isEqualTo(2L); + assertThat(throttler.throttledStat.get(fakeTicker.read())).isEqualTo(2L); + assertThat(throttler.getThrottleProbability(fakeTicker.read())) .isWithin(TOLERANCE) .of(2.0f / 3.0f); } + + /** + * Check that when the ticker returns a negative value for now that the slot detection logic + * is correctly handled and then when the value transitions from negative to positive that things + * continue to work correctly. + */ + @Test + public void negativeTickerValues() { + long rewindAmount = TimeUnit.MILLISECONDS.toNanos(300) + fakeClock.getTicker().read(); + fakeClock.forwardTime(-1 * rewindAmount, TimeUnit.NANOSECONDS); + assertThat(fakeClock.getTicker().read()).isEqualTo(TimeUnit.MILLISECONDS.toNanos(-300)); + shouldThrottle(); + } } diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index 588bddcaa61..120a486dec6 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -21,9 +21,10 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.rls.CachingRlsLbClient.RLS_DATA_KEY; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -55,12 +56,12 @@ import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.rls.CachingRlsLbClient.CacheEntry; import io.grpc.rls.CachingRlsLbClient.CachedRouteLookupResponse; import io.grpc.rls.CachingRlsLbClient.RlsPicker; -import io.grpc.rls.DoNotUseDirectScheduledExecutorService.FakeTimeProvider; import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LruCache.EvictionListener; @@ -79,8 +80,10 @@ import java.io.IOException; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; @@ -130,13 +133,11 @@ public void uncaughtException(Thread t, Throwable e) { new ChildLbResolvedAddressFactory( ImmutableList.of(new EquivalentAddressGroup(socketAddress)), Attributes.EMPTY); private final TestLoadBalancerProvider lbProvider = new TestLoadBalancerProvider(); - private final DoNotUseDirectScheduledExecutorService fakeScheduledExecutorService = - mock(DoNotUseDirectScheduledExecutorService.class, CALLS_REAL_METHODS); - private final FakeTimeProvider fakeTimeProvider = - fakeScheduledExecutorService.getFakeTimeProvider(); + private final FakeClock fakeClock = new FakeClock(); private final StaticFixedDelayRlsServerImpl rlsServerImpl = new StaticFixedDelayRlsServerImpl( - TimeUnit.MILLISECONDS.toNanos(SERVER_LATENCY_MILLIS), fakeScheduledExecutorService); + TimeUnit.MILLISECONDS.toNanos(SERVER_LATENCY_MILLIS), + fakeClock.getScheduledExecutorService()); private final ChildLoadBalancingPolicy childLbPolicy = new ChildLoadBalancingPolicy("target", Collections.emptyMap(), lbProvider); private final Helper helper = @@ -150,6 +151,7 @@ public void uncaughtException(Thread t, Throwable e) { private String rlsChannelOverriddenAuthority; private void setUpRlsLbClient() { + fakeThrottler.resetCounts(); rlsLbClient = CachingRlsLbClient.newBuilder() .setBackoffProvider(fakeBackoffProvider) @@ -158,7 +160,7 @@ private void setUpRlsLbClient() { .setHelper(helper) .setLbPolicyConfig(lbPolicyConfiguration) .setThrottler(fakeThrottler) - .setTimeProvider(fakeTimeProvider) + .setTicker(fakeClock.getTicker()) .build(); } @@ -200,19 +202,19 @@ public void get_noError_lifeCycle() throws Exception { assertThat(resp.isPending()).isTrue(); // server response - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.hasData()).isTrue(); // cache hit for staled entry - fakeTimeProvider.forwardTime(ROUTE_LOOKUP_CONFIG.staleAgeInNanos(), TimeUnit.NANOSECONDS); + fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.staleAgeInNanos(), TimeUnit.NANOSECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.hasData()).isTrue(); // async refresh finishes - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); inOrder .verify(evictionListener) .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.REPLACED)); @@ -222,7 +224,7 @@ public void get_noError_lifeCycle() throws Exception { assertThat(resp.hasData()).isTrue(); // existing cache expired - fakeTimeProvider.forwardTime(ROUTE_LOOKUP_CONFIG.maxAgeInNanos(), TimeUnit.NANOSECONDS); + fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.maxAgeInNanos(), TimeUnit.NANOSECONDS); resp = getInSyncContext(routeLookupRequest); @@ -256,7 +258,7 @@ public void rls_withCustomRlsChannelServiceConfig() throws Exception { .setHelper(helper) .setLbPolicyConfig(lbPolicyConfiguration) .setThrottler(fakeThrottler) - .setTimeProvider(fakeTimeProvider) + .setTicker(fakeClock.getTicker()) .build(); RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); @@ -270,7 +272,7 @@ public void rls_withCustomRlsChannelServiceConfig() throws Exception { assertThat(resp.isPending()).isTrue(); // server response - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.hasData()).isTrue(); @@ -296,7 +298,7 @@ public void get_throttledAndRecover() throws Exception { assertThat(resp.hasError()).isTrue(); - fakeTimeProvider.forwardTime(10, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); // initially backed off entry is backed off again verify(evictionListener) .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.REPLACED)); @@ -307,14 +309,14 @@ public void get_throttledAndRecover() throws Exception { // let it pass throttler fakeThrottler.nextResult = false; - fakeTimeProvider.forwardTime(10, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.isPending()).isTrue(); // server responses - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); @@ -337,7 +339,7 @@ public void get_updatesLbState() throws Exception { // valid channel CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); assertThat(resp.isPending()).isTrue(); - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.hasData()).isTrue(); @@ -361,6 +363,8 @@ public void get_updatesLbState() throws Exception { assertThat(pickResult.getStatus().isOk()).isTrue(); assertThat(pickResult.getSubchannel()).isNotNull(); assertThat(headers.get(RLS_DATA_KEY)).isEqualTo("header-rls-data-value"); + assertThat(fakeThrottler.getNumThrottled()).isEqualTo(0); + assertThat(fakeThrottler.getNumUnthrottled()).isEqualTo(1); // move backoff further back to only test error behavior fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); @@ -369,7 +373,7 @@ public void get_updatesLbState() throws Exception { RouteLookupRequest.create(ImmutableMap.of()); CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.isPending()).isTrue(); - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.hasError()).isTrue(); @@ -386,7 +390,98 @@ public void get_updatesLbState() throws Exception { headers, CallOptions.DEFAULT)); assertThat(pickResult.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(pickResult.getStatus().getDescription()).isEqualTo("fallback not available"); + assertThat(pickResult.getStatus().getDescription()).contains("fallback not available"); + assertThat(fakeThrottler.getNumThrottled()).isEqualTo(1); + assertThat(fakeThrottler.getNumUnthrottled()).isEqualTo(1); + } + + @Test + public void get_withAdaptiveThrottler() throws Exception { + AdaptiveThrottler adaptiveThrottler = + new AdaptiveThrottler.Builder() + .setHistorySeconds(1) + .setRatioForAccepts(1.0f) + .setRequestsPadding(1) + .setTicker(fakeClock.getTicker()) + .build(); + + this.rlsLbClient = + CachingRlsLbClient.newBuilder() + .setBackoffProvider(fakeBackoffProvider) + .setResolvedAddressesFactory(resolvedAddressFactory) + .setEvictionListener(evictionListener) + .setHelper(helper) + .setLbPolicyConfig(lbPolicyConfiguration) + .setThrottler(adaptiveThrottler) + .setTicker(fakeClock.getTicker()) + .build(); + InOrder inOrder = inOrder(helper); + RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + rlsServerImpl.setLookupTable( + ImmutableMap.of( + routeLookupRequest, + RouteLookupResponse.create( + ImmutableList.of("primary.cloudbigtable.googleapis.com"), + "header-rls-data-value"))); + + // valid channel + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + assertThat(resp.isPending()).isTrue(); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + + resp = getInSyncContext(routeLookupRequest); + assertThat(resp.hasData()).isTrue(); + + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor stateCaptor = + ArgumentCaptor.forClass(ConnectivityState.class); + inOrder.verify(helper, times(2)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + + Metadata headers = new Metadata(); + PickResult pickResult = pickerCaptor.getValue().pickSubchannel( + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod().toBuilder().setFullMethodName("service1/create") + .build(), + headers, + CallOptions.DEFAULT)); + assertThat(pickResult.getSubchannel()).isNotNull(); + assertThat(headers.get(RLS_DATA_KEY)).isEqualTo("header-rls-data-value"); + + // move backoff further back to only test error behavior + fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); + // try to get invalid + RouteLookupRequest invalidRouteLookupRequest = + RouteLookupRequest.create(ImmutableMap.of()); + CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); + assertThat(errorResp.isPending()).isTrue(); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + + errorResp = getInSyncContext(invalidRouteLookupRequest); + assertThat(errorResp.hasError()).isTrue(); + + // Channel is still READY because the subchannel for method /service1/create is still READY. + // Method /doesn/exists will use fallback child balancer and fail immediately. + inOrder.verify(helper) + .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); + PickSubchannelArgsImpl invalidArgs = getInvalidArgs(headers); + pickResult = pickerCaptor.getValue().pickSubchannel(invalidArgs); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).contains("fallback not available"); + long time = fakeClock.getTicker().read(); + assertThat(adaptiveThrottler.requestStat.get(time)).isEqualTo(2L); + assertThat(adaptiveThrottler.throttledStat.get(time)).isEqualTo(1L); + } + + private PickSubchannelArgsImpl getInvalidArgs(Metadata headers) { + PickSubchannelArgsImpl invalidArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod().toBuilder() + .setFullMethodName("doesn/exists") + .build(), + headers, + CallOptions.DEFAULT); + return invalidArgs; } @Test @@ -405,20 +500,21 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); assertThat(resp.isPending()).isTrue(); - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp = getInSyncContext(routeLookupRequest); assertThat(resp.hasData()).isTrue(); assertThat(resp.getHeaderData()).isEqualTo("header"); ChildPolicyWrapper childPolicyWrapper = resp.getChildPolicyWrapper(); + assertNotNull(childPolicyWrapper); assertThat(childPolicyWrapper.getTarget()).isEqualTo("target"); assertThat(childPolicyWrapper.getPicker()).isNotInstanceOf(RlsPicker.class); // request2 has same target, it should reuse childPolicyWrapper CachedRouteLookupResponse resp2 = getInSyncContext(routeLookupRequest2); assertThat(resp2.isPending()).isTrue(); - fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); resp2 = getInSyncContext(routeLookupRequest2); assertThat(resp2.hasData()).isTrue(); @@ -426,9 +522,75 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { assertThat(resp2.getChildPolicyWrapper()).isEqualTo(resp.getChildPolicyWrapper()); } + @Test + public void get_childPolicyWrapper_multiTarget() throws Exception { + setUpRlsLbClient(); + RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + rlsServerImpl.setLookupTable( + ImmutableMap.of( + routeLookupRequest, + RouteLookupResponse.create( + ImmutableList.of("target1", "target2", "target3"), + "header"))); + + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + assertThat(resp.isPending()).isTrue(); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + + resp = getInSyncContext(routeLookupRequest); + assertThat(resp.hasData()).isTrue(); + List policyWrappers = new ArrayList<>(); + + for (int i = 1; i <= 3; i++) { + String target = "target" + i; + policyWrappers.add(resp.getChildPolicyWrapper(target)); + } + + // Set to states: null, READY, null + setState(policyWrappers.get(1), ConnectivityState.READY); + ChildPolicyWrapper childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(0), childPolicy); + + // Set to states: null, CONNECTING, null + setState(policyWrappers.get(1), ConnectivityState.CONNECTING); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(0), childPolicy); + + // Set to states: null, CONNECTING, READY + setState(policyWrappers.get(2), ConnectivityState.READY); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(0), childPolicy); + + // Set to states: READY, CONNECTING, READY + setState(policyWrappers.get(0), ConnectivityState.READY); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(0), childPolicy); + + // Set to states: TRANSIENT_FAILURE, CONNECTING, READY + setState(policyWrappers.get(0), ConnectivityState.TRANSIENT_FAILURE); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(1), childPolicy); + + // Set to states: TRANSIENT_FAILURE, TRANSIENT_FAILURE, TRANSIENT_FAILURE + setState(policyWrappers.get(1), ConnectivityState.TRANSIENT_FAILURE); + setState(policyWrappers.get(2), ConnectivityState.TRANSIENT_FAILURE); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(0), childPolicy); + + // Set to states: TRANSIENT_FAILURE, TRANSIENT_FAILURE, READY + setState(policyWrappers.get(2), ConnectivityState.READY); + childPolicy = resp.getChildPolicyWrapper(); + assertSame(policyWrappers.get(2), childPolicy); + } + + private void setState(ChildPolicyWrapper policyWrapper, ConnectivityState newState) { + policyWrapper.getHelper().updateBalancingState(newState, policyWrapper.getPicker()); + } + private static RouteLookupConfig getRouteLookupConfig() { return RouteLookupConfig.builder() - .grpcKeyBuilders(ImmutableList.of( + .grpcKeybuilders(ImmutableList.of( GrpcKeyBuilder.create( ImmutableList.of(Name.create("service1", "create")), ImmutableList.of( @@ -437,7 +599,7 @@ private static RouteLookupConfig getRouteLookupConfig() { ExtraKeys.create("server", "service-key", "method-key"), ImmutableMap.of()))) .lookupService("service1") - .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(2)) + .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(10)) .maxAgeInNanos(TimeUnit.SECONDS.toNanos(300)) .staleAgeInNanos(TimeUnit.SECONDS.toNanos(240)) .cacheSizeBytes(1000) @@ -500,7 +662,7 @@ public LoadBalancer newLoadBalancer(final Helper helper) { LoadBalancer loadBalancer = new LoadBalancer() { @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { Map config = (Map) resolvedAddresses.getLoadBalancingPolicyConfig(); if (DEFAULT_TARGET.equals(config.get("target"))) { helper.updateBalancingState( @@ -509,7 +671,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { return PickResult.withError( - Status.UNAVAILABLE.withDescription("fallback not available")); + Status.UNAVAILABLE.withDescription("fallback not available")); } }); } else { @@ -522,6 +684,8 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } }); } + + return true; } @Override @@ -670,7 +834,7 @@ public ChannelCredentials withoutBearerTokens() { @Override public ScheduledExecutorService getScheduledExecutorService() { - return fakeScheduledExecutorService; + return fakeClock.getScheduledExecutorService(); } @Override @@ -685,6 +849,8 @@ public ChannelLogger getChannelLogger() { } private static final class FakeThrottler implements Throttler { + int numUnthrottled; + int numThrottled; private boolean nextResult = false; @@ -695,7 +861,24 @@ public boolean shouldThrottle() { @Override public void registerBackendResponse(boolean throttled) { - // no-op + if (throttled) { + numThrottled++; + } else { + numUnthrottled++; + } + } + + public int getNumUnthrottled() { + return numUnthrottled; + } + + public int getNumThrottled() { + return numThrottled; + } + + public void resetCounts() { + numThrottled = 0; + numUnthrottled = 0; } } } diff --git a/rls/src/test/java/io/grpc/rls/DoNotUseDirectScheduledExecutorService.java b/rls/src/test/java/io/grpc/rls/DoNotUseDirectScheduledExecutorService.java deleted file mode 100644 index 9fa06a3a722..00000000000 --- a/rls/src/test/java/io/grpc/rls/DoNotUseDirectScheduledExecutorService.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.rls; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; -import static org.mockito.Mockito.mock; - -import com.google.common.base.MoreObjects; -import io.grpc.internal.TimeProvider; -import java.util.Comparator; -import java.util.PriorityQueue; -import java.util.concurrent.Delayed; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicReference; - -/** - * A fake minimal implementation of {@link ScheduledExecutorService} *only* supports - * {@link ScheduledExecutorService#scheduleAtFixedRate(Runnable, long, long, TimeUnit)} (at most 1 - * task is allowed) and {@link ScheduledExecutorService#schedule(Runnable, long, TimeUnit)}. It is - * directExecutor equivalent for {@link ScheduledExecutorService}. - * - *

    Example: - *

    - * import static org.mockito.Mockito.CALLS_REAL_METHODS;
    - * import static org.mockito.Mockito.mock;
    - *
    - * private final DoNotUseDirectScheduledExecutorService fakeScheduledService =
    - *     mock(DoNotUseDirectScheduledExecutorService.class, CALLS_REAL_METHODS);
    - * 
    - * - *

    Note: This class is only intended to be used in this test with CALL_REAL_METHODS mock. This - * implementation is not thread-safe. Not safe to use elsewhere. - */ -abstract class DoNotUseDirectScheduledExecutorService implements ScheduledExecutorService { - - private long currTimeNanos; - private long period; - private long nextRun; - private AtomicReference repeatedCommand; - private PriorityQueue scheduledCommands; - private boolean initialized; - - private DoNotUseDirectScheduledExecutorService() { - throw new UnsupportedOperationException("this class is for mock only"); - } - - /** - * Note: CALLS_REAL_METHODS doesn't initialize instance variables, all the methods need to call - * maybeInit if they access instance variables. - */ - private void maybeInit() { - if (initialized) { - return; - } - - initialized = true; - repeatedCommand = new AtomicReference<>(); - scheduledCommands = new PriorityQueue<>(11, new ScheduledRunnableComparator()); - } - - @Override - public final ScheduledFuture scheduleAtFixedRate( - Runnable command, long initialDelay, long period, TimeUnit unit) { - maybeInit(); - checkArgument(period > 0, "period should be positive"); - checkArgument(initialDelay >= 0, "initial delay should be >= 0"); - checkState(this.repeatedCommand.get() == null, "only can schedule one"); - if (initialDelay == 0) { - initialDelay = period; - command.run(); - } - this.repeatedCommand.set(checkNotNull(command, "command")); - this.nextRun = checkNotNull(unit, "unit").toNanos(initialDelay) + currTimeNanos; - this.period = unit.toNanos(period); - return mock(ScheduledFuture.class); - } - - @Override - public final ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { - maybeInit(); - checkNotNull(command, "command"); - checkNotNull(unit, "unit"); - checkArgument(delay > 0, "delay must be positive"); - ScheduledRunnable scheduledRunnable = - new ScheduledRunnable(currTimeNanos + TimeUnit.NANOSECONDS.convert(delay, unit), command); - scheduledCommands.add(scheduledRunnable); - return scheduledRunnable.scheduledFuture; - } - - final FakeTimeProvider getFakeTimeProvider() { - maybeInit(); - return new FakeTimeProvider(); - } - - private void forwardTime(long delta, TimeUnit unit) { - maybeInit(); - checkNotNull(unit, "unit"); - checkArgument(delta > 0, "delta must be positive"); - long finalTime = currTimeNanos + unit.toNanos(delta); - - if (repeatedCommand.get() != null) { - while (finalTime >= nextRun) { - scheduledCommands.add(new ScheduledRunnable(nextRun, repeatedCommand.get())); - nextRun += period; - } - } - - while (!scheduledCommands.isEmpty() - && scheduledCommands.peek().scheduledTimeNanos <= finalTime) { - ScheduledRunnable scheduledCommand = scheduledCommands.poll(); - try { - // pretend to run at the scheduled time - currTimeNanos = scheduledCommand.scheduledTimeNanos; - scheduledCommand.run(); - } catch (Throwable t) { - throw new RuntimeException("failed to run scheduled command: " + scheduledCommand, t); - } - } - - this.currTimeNanos = finalTime; - } - - private final class ScheduledRunnable implements Runnable { - private final long scheduledTimeNanos; - private final Runnable command; - private final ScheduledFuture scheduledFuture = new ScheduledRunnable.FakeScheduledFuture(); - private boolean running = false; - private boolean done = false; - - public ScheduledRunnable(long scheduledTimeNanos, Runnable command) { - this.scheduledTimeNanos = scheduledTimeNanos; - this.command = checkNotNull(command, "command"); - } - - @Override - public void run() { - if (!scheduledFuture.isCancelled()) { - running = true; - command.run(); - done = true; - } - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("scheduledTimeNanos", scheduledTimeNanos) - .add("command", command) - .add("scheduledFuture", scheduledFuture) - .add("running", running) - .add("done", done) - .toString(); - } - - private final class FakeScheduledFuture implements ScheduledFuture { - boolean cancelled = false; - - @Override - public long getDelay(TimeUnit unit) { - return unit.convert(scheduledTimeNanos - currTimeNanos, TimeUnit.NANOSECONDS); - } - - @Override - public int compareTo(Delayed unused) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - if (running) { - return false; - } - cancelled = true; - return true; - } - - @Override - public boolean isCancelled() { - return cancelled; - } - - @Override - public boolean isDone() { - return done; - } - - @Override - public Object get() throws InterruptedException, ExecutionException { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(long timeout, TimeUnit unit) - throws InterruptedException, ExecutionException, TimeoutException { - throw new UnsupportedOperationException(); - } - } - } - - private static final class ScheduledRunnableComparator - implements Comparator { - @Override - public int compare(ScheduledRunnable o1, ScheduledRunnable o2) { - return Long.compare(o1.scheduledTimeNanos, o2.scheduledTimeNanos); - } - } - - final class FakeTimeProvider implements TimeProvider { - - @Override - public long currentTimeNanos() { - return currTimeNanos; - } - - void forwardTime(long delta, TimeUnit unit) { - DoNotUseDirectScheduledExecutorService.this.forwardTime(delta, unit); - } - } -} diff --git a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java index 60266e15998..19b3a012151 100644 --- a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java +++ b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java @@ -19,11 +19,10 @@ import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.CALLS_REAL_METHODS; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import io.grpc.rls.DoNotUseDirectScheduledExecutorService.FakeTimeProvider; +import com.google.common.base.Ticker; +import io.grpc.internal.FakeClock; import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; import java.util.Objects; @@ -45,9 +44,8 @@ public class LinkedHashLruCacheTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private final DoNotUseDirectScheduledExecutorService fakeScheduledService = - mock(DoNotUseDirectScheduledExecutorService.class, CALLS_REAL_METHODS); - private final FakeTimeProvider timeProvider = fakeScheduledService.getFakeTimeProvider(); + private final FakeClock fakeClock = new FakeClock(); + private final Ticker ticker = fakeClock.getTicker(); @Mock private EvictionListener evictionListener; @@ -60,8 +58,8 @@ public void setUp() { evictionListener, 10, TimeUnit.NANOSECONDS, - fakeScheduledService, - timeProvider, + fakeClock.getScheduledExecutorService(), + fakeClock.getTicker(), new Object()) { @Override protected boolean isExpired(Integer key, Entry value, long nowNanos) { @@ -88,8 +86,8 @@ public void eviction_size() { @Test public void size() { - Entry entry1 = new Entry("Entry0", timeProvider.currentTimeNanos() + 10); - Entry entry2 = new Entry("Entry1", timeProvider.currentTimeNanos() + 20); + Entry entry1 = new Entry("Entry0", ticker.read() + 10); + Entry entry2 = new Entry("Entry1", ticker.read() + 20); cache.cache(0, entry1); cache.cache(1, entry2); assertThat(cache.estimatedSize()).isEqualTo(2); @@ -103,22 +101,22 @@ public void size() { @Test public void eviction_expire() { - Entry toBeEvicted = new Entry("Entry0", timeProvider.currentTimeNanos() + 10); - Entry survivor = new Entry("Entry1", timeProvider.currentTimeNanos() + 20); + Entry toBeEvicted = new Entry("Entry0", ticker.read() + 10); + Entry survivor = new Entry("Entry1", ticker.read() + 20); cache.cache(0, toBeEvicted); cache.cache(1, survivor); - timeProvider.forwardTime(10, TimeUnit.NANOSECONDS); + fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); verify(evictionListener).onEviction(0, toBeEvicted, EvictionType.EXPIRED); - timeProvider.forwardTime(10, TimeUnit.NANOSECONDS); + fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); verify(evictionListener).onEviction(1, survivor, EvictionType.EXPIRED); } @Test public void eviction_explicit() { - Entry toBeEvicted = new Entry("Entry0", timeProvider.currentTimeNanos() + 10); - Entry survivor = new Entry("Entry1", timeProvider.currentTimeNanos() + 20); + Entry toBeEvicted = new Entry("Entry0", ticker.read() + 10); + Entry survivor = new Entry("Entry1", ticker.read() + 20); cache.cache(0, toBeEvicted); cache.cache(1, survivor); @@ -129,8 +127,8 @@ public void eviction_explicit() { @Test public void eviction_replaced() { - Entry toBeEvicted = new Entry("Entry0", timeProvider.currentTimeNanos() + 10); - Entry survivor = new Entry("Entry1", timeProvider.currentTimeNanos() + 20); + Entry toBeEvicted = new Entry("Entry0", ticker.read() + 10); + Entry survivor = new Entry("Entry1", ticker.read() + 20); cache.cache(0, toBeEvicted); cache.cache(0, survivor); @@ -141,7 +139,7 @@ public void eviction_replaced() { public void eviction_size_shouldEvictAlreadyExpired() { for (int i = 1; i <= MAX_SIZE; i++) { // last two entries are <= current time (already expired) - cache.cache(i, new Entry("Entry" + i, timeProvider.currentTimeNanos() + MAX_SIZE - i - 1)); + cache.cache(i, new Entry("Entry" + i, ticker.read() + MAX_SIZE - i - 1)); } cache.cache(MAX_SIZE + 1, new Entry("should kick the first", Long.MAX_VALUE)); @@ -155,7 +153,7 @@ public void eviction_size_shouldEvictAlreadyExpired() { public void eviction_get_shouldNotReturnAlreadyExpired() { for (int i = 1; i <= MAX_SIZE; i++) { // last entry is already expired when added - cache.cache(i, new Entry("Entry" + i, timeProvider.currentTimeNanos() + MAX_SIZE - i)); + cache.cache(i, new Entry("Entry" + i, ticker.read() + MAX_SIZE - i)); } assertThat(cache.estimatedSize()).isEqualTo(MAX_SIZE); @@ -166,7 +164,7 @@ public void eviction_get_shouldNotReturnAlreadyExpired() { @Test public void updateEntrySize() { - Entry entry = new Entry("Entry", timeProvider.currentTimeNanos() + 10); + Entry entry = new Entry("Entry", ticker.read() + 10); cache.cache(1, entry); @@ -185,8 +183,8 @@ public void updateEntrySize() { @Test public void updateEntrySize_multipleEntries() { - Entry entry1 = new Entry("Entry", timeProvider.currentTimeNanos() + 10, 2); - Entry entry2 = new Entry("Entry2", timeProvider.currentTimeNanos() + 10, 3); + Entry entry1 = new Entry("Entry", ticker.read() + 10, 2); + Entry entry2 = new Entry("Entry2", ticker.read() + 10, 3); cache.cache(1, entry1); cache.cache(2, entry2); @@ -202,8 +200,8 @@ public void updateEntrySize_multipleEntries() { @Test public void invalidateAll() { - Entry entry1 = new Entry("Entry", timeProvider.currentTimeNanos() + 10); - Entry entry2 = new Entry("Entry2", timeProvider.currentTimeNanos() + 10); + Entry entry1 = new Entry("Entry", ticker.read() + 10); + Entry entry2 = new Entry("Entry2", ticker.read() + 10); cache.cache(1, entry1); cache.cache(2, entry2); @@ -217,9 +215,9 @@ public void invalidateAll() { @Test public void resize() { - Entry entry1 = new Entry("Entry", timeProvider.currentTimeNanos() + 10); - Entry entry2 = new Entry("Entry2", timeProvider.currentTimeNanos() + 10); - Entry entry3 = new Entry("Entry3", timeProvider.currentTimeNanos() + 10); + Entry entry1 = new Entry("Entry", ticker.read() + 10); + Entry entry2 = new Entry("Entry2", ticker.read() + 10); + Entry entry3 = new Entry("Entry3", ticker.read() + 10); cache.cache(1, entry1); cache.cache(2, entry2); diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index 5dfdf948d4b..9f95200d503 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -21,7 +21,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -56,6 +55,7 @@ import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.lookup.v1.RouteLookupServiceGrpc; @@ -74,6 +74,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import javax.annotation.Nonnull; import org.junit.After; import org.junit.Before; @@ -98,8 +99,7 @@ public class RlsLoadBalancerTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private final RlsLoadBalancerProvider provider = new RlsLoadBalancerProvider(); - private final DoNotUseDirectScheduledExecutorService fakeScheduledExecutorService = - mock(DoNotUseDirectScheduledExecutorService.class, CALLS_REAL_METHODS); + private final FakeClock fakeClock = new FakeClock(); private final SynchronizationContext syncContext = new SynchronizationContext(new UncaughtExceptionHandler() { @Override @@ -111,6 +111,7 @@ public void uncaughtException(Thread t, Throwable e) { mock(Helper.class, AdditionalAnswers.delegatesTo(new FakeHelper())); private final FakeRlsServerImpl fakeRlsServerImpl = new FakeRlsServerImpl(); private final Deque subchannels = new LinkedList<>(); + private final FakeThrottler fakeThrottler = new FakeThrottler(); @Mock private Marshaller mockMarshaller; @Captor @@ -121,7 +122,7 @@ public void uncaughtException(Thread t, Throwable e) { private String defaultTarget = "defaultTarget"; @Before - public void setUp() throws Exception { + public void setUp() { MockitoAnnotations.initMocks(this); fakeSearchMethod = @@ -155,19 +156,51 @@ public void setUp() throws Exception { rlsLb.cachingRlsLbClientBuilderProvider = new CachingRlsLbClientBuilderProvider() { @Override public CachingRlsLbClient.Builder get() { - // using default throttler which doesn't throttle - return CachingRlsLbClient.newBuilder(); + // using fake throttler to allow enablement of throttler + return CachingRlsLbClient.newBuilder() + .setThrottler(fakeThrottler) + .setTicker(fakeClock.getTicker()); } }; } @After - public void tearDown() throws Exception { + public void tearDown() { rlsLb.shutdown(); } @Test - public void lb_working_withDefaultTarget() throws Exception { + public void lb_serverStatusCodeConversion() throws Exception { + deliverResolvedAddresses(); + InOrder inOrder = inOrder(helper); + inOrder.verify(helper) + .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + Metadata headers = new Metadata(); + PickSubchannelArgsImpl fakeSearchMethodArgs = + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT); + PickResult res = picker.pickSubchannel(fakeSearchMethodArgs); + FakeSubchannel subchannel = (FakeSubchannel) res.getSubchannel(); + assertThat(subchannel).isNotNull(); + + // Ensure happy path is unaffected + subchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + res = picker.pickSubchannel(fakeSearchMethodArgs); + assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); + + // Check on conversion + Throwable cause = new Throwable("cause"); + Status aborted = Status.ABORTED.withCause(cause).withDescription("base desc"); + Status serverStatus = CachingRlsLbClient.convertRlsServerStatus(aborted, "conv.test"); + assertThat(serverStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(serverStatus.getCause()).isEqualTo(cause); + assertThat(serverStatus.getDescription()).contains("RLS server returned: "); + assertThat(serverStatus.getDescription()).endsWith("ABORTED: base desc"); + assertThat(serverStatus.getDescription()).contains("RLS server conv.test"); + } + + @Test + public void lb_working_withDefaultTarget_rlsResponding() throws Exception { deliverResolvedAddresses(); InOrder inOrder = inOrder(helper); inOrder.verify(helper) @@ -207,7 +240,7 @@ public void lb_working_withDefaultTarget() throws Exception { FakeSubchannel rescueSubchannel = subchannels.getLast(); // search subchannel is down, rescue subchannel is connecting - searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); + searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -216,40 +249,79 @@ public void lb_working_withDefaultTarget() throws Exception { inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - // search again, use pending fallback because searchSubchannel is in failure mode + // search again, verify that it doesn't use fallback, since RLS server responded, even though + // subchannel is in failure mode res = picker.pickSubchannel( new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getStatus().isOk()).isTrue(); + assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); + } + @Test + public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { + fakeThrottler.nextResult = true; + + deliverResolvedAddresses(); + InOrder inOrder = inOrder(helper); + inOrder.verify(helper) + .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + Metadata headers = new Metadata(); + PickResult res; + + // Search that when the RLS server doesn't respond, that fallback is used + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + FakeSubchannel fallbackSubchannel = (FakeSubchannel) res.getSubchannel(); + assertThat(fallbackSubchannel).isNotNull(); + + assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - assertThat(subchannels).hasSize(3); - FakeSubchannel fallbackSubchannel = subchannels.getLast(); fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(1)) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); inOrder.verifyNoMoreInteractions(); res = picker.pickSubchannel( new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); + assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel); res = picker.pickSubchannel( new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel().getAddresses()).isEqualTo(rescueSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()).isEqualTo(rescueSubchannel.getAttributes()); + assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel); + + // Make sure that when RLS starts communicating that default stops being used + fakeThrottler.nextResult = false; + fakeClock.forwardTime(2, TimeUnit.SECONDS); // Expires backoff cache entries + // Create search subchannel + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); + FakeSubchannel searchSubchannel = (FakeSubchannel) res.getSubchannel(); + assertThat(searchSubchannel).isNotNull(); + searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + // create rescue subchannel + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); + assertThat(res.getSubchannel()).isNotSameInstanceAs(searchSubchannel); + FakeSubchannel rescueSubchannel = (FakeSubchannel) res.getSubchannel(); + assertThat(rescueSubchannel).isNotNull(); + rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); // all channels are failed - rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - fallbackSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - inOrder.verifyNoMoreInteractions(); + rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + fallbackSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(res.getSubchannel()).isNull(); } @Test @@ -373,7 +445,7 @@ private void deliverResolvedAddresses() throws Exception { ConfigOrError parsedConfigOrError = provider.parseLoadBalancingPolicyConfig(getServiceConfig()); assertThat(parsedConfigOrError.getConfig()).isNotNull(); - rlsLb.handleResolvedAddresses(ResolvedAddresses.newBuilder() + rlsLb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mock(SocketAddress.class)))) .setLoadBalancingPolicyConfig(parsedConfigOrError.getConfig()) .build()); @@ -392,7 +464,7 @@ private Map getServiceConfig() throws IOException { private String getRlsConfigJsonStr() { return "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -501,7 +573,7 @@ public ChannelCredentials withoutBearerTokens() { @Override public ScheduledExecutorService getScheduledExecutorService() { - return fakeScheduledExecutorService; + return fakeClock.getScheduledExecutorService(); } @Override @@ -547,7 +619,7 @@ private static final class FakeSubchannel extends Subchannel { private final Attributes attributes; private List eags; private SubchannelStateListener listener; - private boolean isReady; + private volatile boolean isReady; public FakeSubchannel(List eags, Attributes attributes) { this.eags = Collections.unmodifiableList(eags); @@ -591,4 +663,20 @@ public void updateState(ConnectivityStateInfo newState) { private static boolean subchannelIsReady(Subchannel subchannel) { return subchannel instanceof FakeSubchannel && ((FakeSubchannel) subchannel).isReady; } + + private static final class FakeThrottler implements Throttler { + + private boolean nextResult = false; + + @Override + public boolean shouldThrottle() { + return nextResult; + } + + @Override + public void registerBackendResponse(boolean throttled) { + // no-op + } + } + } diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index 215f8f2ac04..98b7101fd5e 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -101,7 +101,7 @@ public void convert_toResponseObject() { @Test public void convert_jsonRlsConfig() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -177,7 +177,7 @@ public void convert_jsonRlsConfig() throws IOException { RouteLookupConfig expectedConfig = RouteLookupConfig.builder() - .grpcKeyBuilders(ImmutableList.of( + .grpcKeybuilders(ImmutableList.of( GrpcKeyBuilder.create( ImmutableList.of(Name.create("service1", "create")), ImmutableList.of( @@ -216,7 +216,7 @@ public void convert_jsonRlsConfig() throws IOException { @Test public void convert_jsonRlsConfig_emptyKeyBuilders() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [],\n" + + " \"grpcKeybuilders\": [],\n" + " \"lookupService\": \"service1\",\n" + " \"lookupServiceTimeout\": \"2s\",\n" + " \"maxAge\": \"300s\",\n" @@ -240,7 +240,7 @@ public void convert_jsonRlsConfig_emptyKeyBuilders() throws IOException { @Test public void convert_jsonRlsConfig_namesNotUnique() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -329,7 +329,7 @@ public void convert_jsonRlsConfig_namesNotUnique() throws IOException { @Test public void convert_jsonRlsConfig_defaultValues() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -345,7 +345,7 @@ public void convert_jsonRlsConfig_defaultValues() throws IOException { RouteLookupConfig expectedConfig = RouteLookupConfig.builder() - .grpcKeyBuilders(ImmutableList.of( + .grpcKeybuilders(ImmutableList.of( GrpcKeyBuilder.create( ImmutableList.of(Name.create("service1", null)), ImmutableList.of(), @@ -369,7 +369,7 @@ public void convert_jsonRlsConfig_defaultValues() throws IOException { @Test public void convert_jsonRlsConfig_staleAgeCappedByMaxAge() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -402,7 +402,7 @@ public void convert_jsonRlsConfig_staleAgeCappedByMaxAge() throws IOException { RouteLookupConfig expectedConfig = RouteLookupConfig.builder() - .grpcKeyBuilders(ImmutableList.of( + .grpcKeybuilders(ImmutableList.of( GrpcKeyBuilder.create( ImmutableList.of(Name.create("service1", "create")), ImmutableList.of( @@ -428,7 +428,7 @@ public void convert_jsonRlsConfig_staleAgeCappedByMaxAge() throws IOException { @Test public void convert_jsonRlsConfig_staleAgeGivenWithoutMaxAge() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -472,7 +472,7 @@ public void convert_jsonRlsConfig_staleAgeGivenWithoutMaxAge() throws IOExceptio @Test public void convert_jsonRlsConfig_keyBuilderWithoutName() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"headers\": [\n" + " {\n" @@ -510,7 +510,7 @@ public void convert_jsonRlsConfig_keyBuilderWithoutName() throws IOException { @Test public void convert_jsonRlsConfig_nameWithoutService() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" @@ -553,7 +553,7 @@ public void convert_jsonRlsConfig_nameWithoutService() throws IOException { @Test public void convert_jsonRlsConfig_keysNotUnique() throws IOException { String jsonStr = "{\n" - + " \"grpcKeyBuilders\": [\n" + + " \"grpcKeybuilders\": [\n" + " {\n" + " \"names\": [\n" + " {\n" diff --git a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java index 82e8416563f..6ee2c01af8a 100644 --- a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java @@ -37,7 +37,7 @@ public class RlsRequestFactoryTest { private static final RouteLookupConfig RLS_CONFIG = RouteLookupConfig.builder() - .grpcKeyBuilders(ImmutableList.of( + .grpcKeybuilders(ImmutableList.of( GrpcKeyBuilder.create( ImmutableList.of(Name.create("com.google.service1", "Create")), ImmutableList.of( diff --git a/services/BUILD.bazel b/services/BUILD.bazel index c1dd3ad353f..f8cc6ad7620 100644 --- a/services/BUILD.bazel +++ b/services/BUILD.bazel @@ -2,15 +2,76 @@ load("//:java_grpc_library.bzl", "java_grpc_library") package(default_visibility = ["//visibility:public"]) +# Mirrors the dependencies included in the artifact on Maven Central for usage +# with maven_install's override_targets. Should only be used as a dep for +# pre-compiled binaries on Maven Central. +java_library( + name = "services_maven", + exports = [ + ":admin", + ":binarylog", + ":channelz", + ":health", + ":healthlb", + ":metrics", + ":metrics_internal", + ":reflection", + ], +) + +java_library( + name = "admin", + srcs = [ + "src/main/java/io/grpc/services/AdminInterface.java", + ], + deps = [ + ":channelz", + "//api", + "@com_google_code_findbugs_jsr305//jar", + ], +) + +java_library( + name = "metrics", + srcs = [ + "src/main/java/io/grpc/services/CallMetricRecorder.java", + "src/main/java/io/grpc/services/MetricRecorder.java", + "src/main/java/io/grpc/services/MetricReport.java", + ], + deps = [ + "//api", + "//context", + "@com_google_code_findbugs_jsr305//jar", + "@com_google_errorprone_error_prone_annotations//jar", + "@com_google_guava_guava//jar", + ], +) + +java_library( + name = "metrics_internal", + srcs = [ + "src/main/java/io/grpc/services/InternalCallMetricRecorder.java", + "src/main/java/io/grpc/services/InternalMetricRecorder.java", + ], + visibility = ["//:__subpackages__"], + deps = [ + ":metrics", + "//api", + "//context", + ], +) + java_library( name = "binarylog", srcs = [ "src/main/java/io/grpc/protobuf/services/BinaryLogProvider.java", "src/main/java/io/grpc/protobuf/services/BinaryLogProviderImpl.java", "src/main/java/io/grpc/protobuf/services/BinaryLogSink.java", + "src/main/java/io/grpc/protobuf/services/BinaryLogs.java", "src/main/java/io/grpc/protobuf/services/BinlogHelper.java", "src/main/java/io/grpc/protobuf/services/InetAddressUtil.java", "src/main/java/io/grpc/protobuf/services/TempFileSink.java", + "src/main/java/io/grpc/services/BinaryLogs.java", ], deps = [ "//api", @@ -33,7 +94,6 @@ java_library( deps = [ ":_channelz_java_grpc", "//api", - "//context", "//stub", "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", @@ -51,9 +111,6 @@ java_library( deps = [ ":_reflection_java_grpc", "//api", - "//context", - "//core:internal", - "//core:util", "//protobuf", "//stub", "@com_google_code_findbugs_jsr305//jar", @@ -82,6 +139,27 @@ java_library( ], ) +java_library( + name = "healthlb", + srcs = [ + "src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java", + "src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerUtil.java", + "src/main/java/io/grpc/protobuf/services/internal/HealthCheckingRoundRobinLoadBalancerProvider.java", + ], + resources = [ + "src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider", + ], + deps = [ + ":_health_java_grpc", + "//api", + "//core:internal", + "//core:util", + "@com_google_code_findbugs_jsr305//jar", + "@com_google_guava_guava//jar", + "@io_grpc_grpc_proto//:health_java_proto", + ], +) + # These shouldn't be here, but this is better than having # a circular dependency on grpc-proto and grpc-java. diff --git a/services/build.gradle b/services/build.gradle index 2de9418c3c2..b6d945c7e98 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -8,7 +8,7 @@ plugins { description = "gRPC: Services" -[compileJava].each() { +tasks.named("compileJava").configure { // v1alpha of reflection.proto is deprecated at the file level. // Without this workaround, the project can not compile. it.options.compilerArgs += [ @@ -22,26 +22,28 @@ dependencies { api project(':grpc-protobuf'), project(':grpc-stub'), project(':grpc-core') - implementation libraries.protobuf_util, + implementation libraries.protobuf.java.util, libraries.guava - runtimeOnly libraries.errorprone - compileOnly libraries.javax_annotation + runtimeOnly libraries.errorprone.annotations + + compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), - libraries.netty_epoll, // for DomainSocketAddress + libraries.netty.transport.epoll, // for DomainSocketAddress project(':grpc-core').sourceSets.test.output // for FakeClock - testCompileOnly libraries.javax_annotation - signature "org.codehaus.mojo.signature:java17:1.0@signature" + testCompileOnly libraries.javax.annotation + signature libraries.signature.java } configureProtoCompilation() -javadoc { +tasks.named("javadoc").configure { exclude 'io/grpc/services/Internal*.java' exclude 'io/grpc/services/internal/*' + exclude 'io/grpc/protobuf/services/internal/*' } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { classDirectories.from = sourceSets.main.output.collect { fileTree(dir: it, exclude: [ diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java index a1933c632fd..6ce602b9295 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java @@ -18,8 +18,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; -import io.grpc.Context.CancellationListener; import io.grpc.Context; +import io.grpc.Context.CancellationListener; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.health.v1.HealthCheckRequest; diff --git a/services/src/main/java/io/grpc/services/BinaryLogs.java b/services/src/main/java/io/grpc/services/BinaryLogs.java index e4831dc5bfa..ad8a27c3290 100644 --- a/services/src/main/java/io/grpc/services/BinaryLogs.java +++ b/services/src/main/java/io/grpc/services/BinaryLogs.java @@ -22,6 +22,8 @@ import java.io.IOException; /** + * Utility class to create BinaryLog instances. + * * @deprecated Use {@link io.grpc.protobuf.services.BinaryLogs} instead. */ @Deprecated diff --git a/services/src/main/java/io/grpc/services/CallMetricRecorder.java b/services/src/main/java/io/grpc/services/CallMetricRecorder.java index f9b6a71a4b8..d93f93606f9 100644 --- a/services/src/main/java/io/grpc/services/CallMetricRecorder.java +++ b/services/src/main/java/io/grpc/services/CallMetricRecorder.java @@ -17,6 +17,7 @@ package io.grpc.services; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.InlineMe; import io.grpc.Context; import io.grpc.ExperimentalApi; import java.util.Collections; @@ -36,13 +37,14 @@ public final class CallMetricRecorder { static final Context.Key CONTEXT_KEY = Context.key("io.grpc.services.CallMetricRecorder"); - private final AtomicReference> metrics = + private final AtomicReference> utilizationMetrics = new AtomicReference<>(); + private final AtomicReference> requestCostMetrics = + new AtomicReference<>(); + private double cpuUtilizationMetric = 0; + private double memoryUtilizationMetric = 0; private volatile boolean disabled; - CallMetricRecorder() { - } - /** * Returns the call metric recorder attached to the current {@link Context}. If there is none, * returns a no-op recorder. @@ -62,41 +64,135 @@ public static CallMetricRecorder getCurrent() { } /** - * Records a call metric measurement. If RPC has already finished, this method is no-op. + * Records a call metric measurement for utilization. + * If RPC has already finished, this method is no-op. * *

    A latter record will overwrite its former name-sakes. * * @return this recorder object * @since 1.23.0 */ + public CallMetricRecorder recordUtilizationMetric(String name, double value) { + if (disabled) { + return this; + } + if (utilizationMetrics.get() == null) { + // The chance of race of creation of the map should be very small, so it should be fine + // to create these maps that might be discarded. + utilizationMetrics.compareAndSet(null, new ConcurrentHashMap()); + } + utilizationMetrics.get().put(name, value); + return this; + } + + /** + * Records a call metric measurement for request cost. + * If RPC has already finished, this method is no-op. + * + *

    A latter record will overwrite its former name-sakes. + * + * @return this recorder object + * @since 1.47.0 + * @deprecated use {@link #recordRequestCostMetric} instead. + * This method will be removed in the future. + */ + @Deprecated + @InlineMe(replacement = "this.recordRequestCostMetric(name, value)") public CallMetricRecorder recordCallMetric(String name, double value) { + return recordRequestCostMetric(name, value); + } + + /** + * Records a call metric measurement for request cost. + * If RPC has already finished, this method is no-op. + * + *

    A latter record will overwrite its former name-sakes. + * + * @return this recorder object + * @since 1.48.1 + */ + public CallMetricRecorder recordRequestCostMetric(String name, double value) { if (disabled) { return this; } - if (metrics.get() == null) { + if (requestCostMetrics.get() == null) { // The chance of race of creation of the map should be very small, so it should be fine // to create these maps that might be discarded. - metrics.compareAndSet(null, new ConcurrentHashMap()); + requestCostMetrics.compareAndSet(null, new ConcurrentHashMap()); } - metrics.get().put(name, value); + requestCostMetrics.get().put(name, value); return this; } /** - * Returns all save metric values. No more metric values will be recorded after this method is - * called. Calling this method multiple times returns the same collection of metric values. + * Records a call metric measurement for CPU utilization. + * If RPC has already finished, this method is no-op. + * + *

    A latter record will overwrite its former name-sakes. + * + * @return this recorder object + * @since 1.47.0 + */ + public CallMetricRecorder recordCpuUtilizationMetric(double value) { + if (disabled) { + return this; + } + cpuUtilizationMetric = value; + return this; + } + + /** + * Records a call metric measurement for memory utilization. + * If RPC has already finished, this method is no-op. + * + *

    A latter record will overwrite its former name-sakes. + * + * @return this recorder object + * @since 1.47.0 + */ + public CallMetricRecorder recordMemoryUtilizationMetric(double value) { + if (disabled) { + return this; + } + memoryUtilizationMetric = value; + return this; + } + + + /** + * Returns all request cost metric values. No more metric values will be recorded after this + * method is called. Calling this method multiple times returns the same collection of metric + * values. * * @return a map containing all saved metric name-value pairs. */ Map finalizeAndDump() { disabled = true; - Map savedMetrics = metrics.get(); + Map savedMetrics = requestCostMetrics.get(); if (savedMetrics == null) { return Collections.emptyMap(); } return Collections.unmodifiableMap(savedMetrics); } + /** + * Returns all save metric values. No more metric values will be recorded after this method is + * called. Calling this method multiple times returns the same collection of metric values. + * + * @return a per-request ORCA reports containing all saved metrics. + */ + MetricReport finalizeAndDump2() { + Map savedRequestCostMetrics = finalizeAndDump(); + Map savedUtilizationMetrics = utilizationMetrics.get(); + if (savedUtilizationMetrics == null) { + savedUtilizationMetrics = Collections.emptyMap(); + } + return new MetricReport(cpuUtilizationMetric, + memoryUtilizationMetric, Collections.unmodifiableMap(savedRequestCostMetrics), + Collections.unmodifiableMap(savedUtilizationMetrics) + ); + } + @VisibleForTesting boolean isDisabled() { return disabled; diff --git a/services/src/main/java/io/grpc/services/InternalCallMetricRecorder.java b/services/src/main/java/io/grpc/services/InternalCallMetricRecorder.java index c68759e8b38..97e5e5a0aa6 100644 --- a/services/src/main/java/io/grpc/services/InternalCallMetricRecorder.java +++ b/services/src/main/java/io/grpc/services/InternalCallMetricRecorder.java @@ -40,4 +40,14 @@ public static CallMetricRecorder newCallMetricRecorder() { public static Map finalizeAndDump(CallMetricRecorder recorder) { return recorder.finalizeAndDump(); } + + public static MetricReport finalizeAndDump2(CallMetricRecorder recorder) { + return recorder.finalizeAndDump2(); + } + + public static MetricReport createMetricReport(double cpuUtilization, double memoryUtilization, + Map requestCostMetrics, Map utilizationMetrics) { + return new MetricReport(cpuUtilization, memoryUtilization, + requestCostMetrics, utilizationMetrics); + } } diff --git a/services/src/main/java/io/grpc/services/InternalMetricRecorder.java b/services/src/main/java/io/grpc/services/InternalMetricRecorder.java new file mode 100644 index 00000000000..cd36c425ac8 --- /dev/null +++ b/services/src/main/java/io/grpc/services/InternalMetricRecorder.java @@ -0,0 +1,35 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.services; + +import io.grpc.Internal; + +/** + * Internal {@link CallMetricRecorder} accessor. This is intended for usage internal to the gRPC + * team. If you *really* think you need to use this, contact the gRPC team first. + */ +@Internal +public final class InternalMetricRecorder { + + // Prevent instantiation. + private InternalMetricRecorder() { + } + + public static MetricReport getMetricReport(MetricRecorder recorder) { + return recorder.getMetricReport(); + } +} diff --git a/services/src/main/java/io/grpc/services/MetricRecorder.java b/services/src/main/java/io/grpc/services/MetricRecorder.java new file mode 100644 index 00000000000..a576386e98b --- /dev/null +++ b/services/src/main/java/io/grpc/services/MetricRecorder.java @@ -0,0 +1,93 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.services; + +import io.grpc.ExperimentalApi; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implements the service/APIs for Out-of-Band metrics reporting, only for utilization metrics. + * A user should use the public set-APIs to update the server machine's utilization metrics data. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9006") +public final class MetricRecorder { + private volatile ConcurrentHashMap metricsData = new ConcurrentHashMap<>(); + private volatile double cpuUtilization; + private volatile double memoryUtilization; + + public static MetricRecorder newInstance() { + return new MetricRecorder(); + } + + private MetricRecorder() {} + + /** + * Update the metrics value corresponding to the specified key. + */ + public void putUtilizationMetric(String key, double value) { + metricsData.put(key, value); + } + + /** + * Replace the whole metrics data using the specified map. + */ + public void setAllUtilizationMetrics(Map metrics) { + metricsData = new ConcurrentHashMap<>(metrics); + } + + /** + * Remove the metrics data entry corresponding to the specified key. + */ + public void removeUtilizationMetric(String key) { + metricsData.remove(key); + } + + /** + * Update the CPU utilization metrics data. + */ + public void setCpuUtilizationMetric(double value) { + cpuUtilization = value; + } + + /** + * Clear the CPU utilization metrics data. + */ + public void clearCpuUtilizationMetric() { + cpuUtilization = 0; + } + + /** + * Update the memory utilization metrics data. + */ + public void setMemoryUtilizationMetric(double value) { + memoryUtilization = value; + } + + /** + * Clear the memory utilization metrics data. + */ + public void clearMemoryUtilizationMetric() { + memoryUtilization = 0; + } + + MetricReport getMetricReport() { + return new MetricReport(cpuUtilization, memoryUtilization, + Collections.emptyMap(), Collections.unmodifiableMap(metricsData)); + } +} diff --git a/services/src/main/java/io/grpc/services/MetricReport.java b/services/src/main/java/io/grpc/services/MetricReport.java new file mode 100644 index 00000000000..56ab150f8af --- /dev/null +++ b/services/src/main/java/io/grpc/services/MetricReport.java @@ -0,0 +1,70 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.services; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.MoreObjects; +import io.grpc.ExperimentalApi; +import java.util.Map; + +/** + * A gRPC object of orca load report. LB policies listening at per-rpc or oob orca load reports + * will be notified of the metrics data in this data format. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9381") +public final class MetricReport { + private double cpuUtilization; + private double memoryUtilization; + private Map requestCostMetrics; + private Map utilizationMetrics; + + MetricReport(double cpuUtilization, double memoryUtilization, + Map requestCostMetrics, + Map utilizationMetrics) { + this.cpuUtilization = cpuUtilization; + this.memoryUtilization = memoryUtilization; + this.requestCostMetrics = checkNotNull(requestCostMetrics, "requestCostMetrics"); + this.utilizationMetrics = checkNotNull(utilizationMetrics, "utilizationMetrics"); + } + + public double getCpuUtilization() { + return cpuUtilization; + } + + public double getMemoryUtilization() { + return memoryUtilization; + } + + public Map getRequestCostMetrics() { + return requestCostMetrics; + } + + public Map getUtilizationMetrics() { + return utilizationMetrics; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("cpuUtilization", cpuUtilization) + .add("memoryUtilization", memoryUtilization) + .add("requestCost", requestCostMetrics) + .add("utilization", utilizationMetrics) + .toString(); + } +} diff --git a/services/src/main/proto/grpc/channelz/v1/channelz.proto b/services/src/main/proto/grpc/channelz/v1/channelz.proto index f0b3b10837e..d0781094ea8 100644 --- a/services/src/main/proto/grpc/channelz/v1/channelz.proto +++ b/services/src/main/proto/grpc/channelz/v1/channelz.proto @@ -35,7 +35,7 @@ option java_outer_classname = "ChannelzProto"; // Channel is a logical grouping of channels, subchannels, and sockets. message Channel { - // The identifier for this channel. This should bet set. + // The identifier for this channel. This should be set. ChannelRef ref = 1; // Data specific to this channel. ChannelData data = 2; diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java index 5c5b30f336d..87d4ac29be8 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java @@ -20,8 +20,8 @@ import static org.junit.Assert.fail; import io.grpc.BindableService; -import io.grpc.Context.CancellableContext; import io.grpc.Context; +import io.grpc.Context.CancellableContext; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.health.v1.HealthCheckRequest; diff --git a/services/src/test/java/io/grpc/services/CallMetricRecorderTest.java b/services/src/test/java/io/grpc/services/CallMetricRecorderTest.java index fe7a9c54df8..9811d1da92e 100644 --- a/services/src/test/java/io/grpc/services/CallMetricRecorderTest.java +++ b/services/src/test/java/io/grpc/services/CallMetricRecorderTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.common.truth.Truth; import io.grpc.Context; import java.util.Map; import org.junit.Test; @@ -37,32 +38,51 @@ public void dumpGivesEmptyResultWhenNoSavedMetricValues() { @Test public void dumpDumpsAllSavedMetricValues() { - recorder.recordCallMetric("cost1", 154353.423); - recorder.recordCallMetric("cost2", 0.1367); - recorder.recordCallMetric("cost3", 1437.34); + recorder.recordUtilizationMetric("util1", 154353.423); + recorder.recordUtilizationMetric("util2", 0.1367); + recorder.recordUtilizationMetric("util3", 1437.34); + recorder.recordRequestCostMetric("cost1", 37465.12); + recorder.recordRequestCostMetric("cost2", 10293.0); + recorder.recordRequestCostMetric("cost3", 1.0); + recorder.recordCpuUtilizationMetric(0.1928); + recorder.recordMemoryUtilizationMetric(47.4); - Map dump = recorder.finalizeAndDump(); - assertThat(dump) - .containsExactly("cost1", 154353.423, "cost2", 0.1367, "cost3", 1437.34); + MetricReport dump = recorder.finalizeAndDump2(); + Truth.assertThat(dump.getUtilizationMetrics()) + .containsExactly("util1", 154353.423, "util2", 0.1367, "util3", 1437.34); + Truth.assertThat(dump.getRequestCostMetrics()) + .containsExactly("cost1", 37465.12, "cost2", 10293.0, "cost3", 1.0); + Truth.assertThat(dump.getCpuUtilization()).isEqualTo(0.1928); + Truth.assertThat(dump.getMemoryUtilization()).isEqualTo(47.4); } @Test public void noMetricsRecordedAfterSnapshot() { Map initDump = recorder.finalizeAndDump(); - recorder.recordCallMetric("cost", 154353.423); + recorder.recordUtilizationMetric("cost", 154353.423); assertThat(recorder.finalizeAndDump()).isEqualTo(initDump); } @Test public void lastValueWinForMetricsWithSameName() { - recorder.recordCallMetric("cost1", 3412.5435); - recorder.recordCallMetric("cost2", 6441.341); - recorder.recordCallMetric("cost1", 6441.341); - recorder.recordCallMetric("cost1", 4654.67); - recorder.recordCallMetric("cost2", 75.83); - Map dump = recorder.finalizeAndDump(); - assertThat(dump) + recorder.recordRequestCostMetric("cost1", 3412.5435); + recorder.recordRequestCostMetric("cost2", 6441.341); + recorder.recordRequestCostMetric("cost1", 6441.341); + recorder.recordRequestCostMetric("cost1", 4654.67); + recorder.recordRequestCostMetric("cost2", 75.83); + recorder.recordMemoryUtilizationMetric(1.3); + recorder.recordMemoryUtilizationMetric(3.1); + recorder.recordUtilizationMetric("util1", 28374.21); + recorder.recordMemoryUtilizationMetric(9384.0); + recorder.recordUtilizationMetric("util1", 84323.3); + + MetricReport dump = recorder.finalizeAndDump2(); + Truth.assertThat(dump.getRequestCostMetrics()) .containsExactly("cost1", 4654.67, "cost2", 75.83); + Truth.assertThat(dump.getMemoryUtilization()).isEqualTo(9384.0); + Truth.assertThat(dump.getUtilizationMetrics()) + .containsExactly("util1", 84323.3); + Truth.assertThat(dump.getCpuUtilization()).isEqualTo(0); } @Test diff --git a/servlet/build.gradle b/servlet/build.gradle new file mode 100644 index 00000000000..f5ef32ae11e --- /dev/null +++ b/servlet/build.gradle @@ -0,0 +1,111 @@ +plugins { + id "java-library" + id "maven-publish" +} + +description = "gRPC: Servlet" + +// javax.servlet-api 4.0 requires a minimum of Java 8, so we might as well use that source level +sourceCompatibility = 1.8 +targetCompatibility = 1.8 + +def jettyVersion = '10.0.7' + +configurations { + itImplementation.extendsFrom(implementation) + undertowTestImplementation.extendsFrom(itImplementation) + tomcatTestImplementation.extendsFrom(itImplementation) + jettyTestImplementation.extendsFrom(itImplementation) +} + +sourceSets { + // Create a test sourceset for each classpath - could be simplified if we made new test directories + undertowTest {} + tomcatTest {} + + // Only compile these tests if java 11+ is being used + if (JavaVersion.current().isJava11Compatible()) { + jettyTest {} + } +} + +dependencies { + api project(':grpc-api') + compileOnly 'javax.servlet:javax.servlet-api:4.0.1', + libraries.javax.annotation // java 9, 10 needs it + + implementation project(':grpc-core'), + libraries.guava + + testImplementation 'javax.servlet:javax.servlet-api:4.0.1', + 'org.jetbrains.kotlinx:lincheck:2.14.1' + + itImplementation project(':grpc-servlet'), + project(':grpc-netty'), + project(':grpc-core').sourceSets.test.runtimeClasspath, + libraries.junit + itImplementation(project(':grpc-interop-testing')) { + // Avoid grpc-netty-shaded dependency + exclude group: 'io.grpc', module: 'grpc-alts' + exclude group: 'io.grpc', module: 'grpc-xds' + } + + undertowTestImplementation 'io.undertow:undertow-servlet:2.2.14.Final' + + tomcatTestImplementation 'org.apache.tomcat.embed:tomcat-embed-core:9.0.56' + + jettyTestImplementation "org.eclipse.jetty:jetty-servlet:${jettyVersion}", + "org.eclipse.jetty.http2:http2-server:${jettyVersion}", + "org.eclipse.jetty:jetty-client:${jettyVersion}" + project(':grpc-testing') +} + +test { + if (JavaVersion.current().isJava9Compatible()) { + jvmArgs += [ + // required for Lincheck + '--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED', + '--add-exports=java.base/jdk.internal.util=ALL-UNNAMED', + ] + } +} + +// Set up individual classpaths for each test, to avoid any mismatch, +// and ensure they are only used when supported by the current jvm +check.dependsOn(tasks.register('undertowTest', Test) { + classpath = sourceSets.undertowTest.runtimeClasspath + testClassesDirs = sourceSets.undertowTest.output.classesDirs +}) +check.dependsOn(tasks.register('tomcat9Test', Test) { + classpath = sourceSets.tomcatTest.runtimeClasspath + testClassesDirs = sourceSets.tomcatTest.output.classesDirs + + // Provide a temporary directory for tomcat to be deleted after test finishes + def tomcatTempDir = "$buildDir/tomcat_catalina_base" + systemProperty 'catalina.base', tomcatTempDir + doLast { + file(tomcatTempDir).deleteDir() + } + + // tomcat-embed-core 9 presently performs illegal reflective access on + // java.io.ObjectStreamClass$Caches.localDescs and sun.rmi.transport.Target.ccl, + // see https://lists.apache.org/thread/s0xr7tk2kfkkxfjps9n7dhh4cypfdhyy + if (JavaVersion.current().isJava9Compatible()) { + jvmArgs += ['--add-opens=java.base/java.io=ALL-UNNAMED', '--add-opens=java.rmi/sun.rmi.transport=ALL-UNNAMED'] + } +}) + +// Only run these tests if java 11+ is being used +if (JavaVersion.current().isJava11Compatible()) { + check.dependsOn(tasks.register('jettyTest', Test) { + classpath = sourceSets.jettyTest.runtimeClasspath + testClassesDirs = sourceSets.jettyTest.output.classesDirs + }) +} + +jacocoTestReport { + executionData undertowTest, tomcat9Test + if (JavaVersion.current().isJava11Compatible()) { + executionData jettyTest + } +} diff --git a/servlet/jakarta/build.gradle b/servlet/jakarta/build.gradle new file mode 100644 index 00000000000..59f5ac78d80 --- /dev/null +++ b/servlet/jakarta/build.gradle @@ -0,0 +1,128 @@ +plugins { + id "java-library" + id "maven-publish" +} + +description = "gRPC: Jakarta Servlet" +sourceCompatibility = 1.8 +targetCompatibility = 1.8 + +// Set up classpaths and source directories for different servlet tests +configurations { + itImplementation.extendsFrom(implementation) + jettyTestImplementation.extendsFrom(itImplementation) + tomcatTestImplementation.extendsFrom(itImplementation) + undertowTestImplementation.extendsFrom(itImplementation) +} + +sourceSets { + undertowTest { + java { + include '**/Undertow*.java' + } + } + tomcatTest { + java { + include '**/Tomcat*.java' + } + } + // Only run these tests if java 11+ is being used + if (JavaVersion.current().isJava11Compatible()) { + jettyTest { + java { + include '**/Jetty*.java' + } + } + } +} + +// Mechanically transform sources from grpc-servlet to use the corrected packages +def migrate(String name, String inputDir, SourceSet sourceSet) { + def outputDir = layout.buildDirectory.dir('generated/sources/jakarta-' + name) + sourceSet.java.srcDir outputDir + return tasks.register('migrateSources' + name.capitalize(), Sync) { task -> + into(outputDir) + from("$inputDir/io/grpc/servlet") { + into('io/grpc/servlet/jakarta') + filter { String line -> + line.replaceAll('javax\\.servlet', 'jakarta.servlet') + .replaceAll('io\\.grpc\\.servlet', 'io.grpc.servlet.jakarta') + } + } + } +} + +compileJava.dependsOn migrate('main', '../src/main/java', sourceSets.main) + +sourcesJar.dependsOn migrateSourcesMain + +// Build the set of sourceSets and classpaths to modify, since Jetty 11 requires Java 11 +// and must be skipped +compileUndertowTestJava.dependsOn(migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest)) +compileTomcatTestJava.dependsOn(migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest)) +if (JavaVersion.current().isJava11Compatible()) { + compileJettyTestJava.dependsOn(migrate('jettyTest', '../src/jettyTest/java', sourceSets.jettyTest)) +} + +// Disable checkstyle for this project, since it consists only of generated code +tasks.withType(Checkstyle) { + enabled = false +} + +dependencies { + api project(':grpc-api') + compileOnly 'jakarta.servlet:jakarta.servlet-api:5.0.0', + libraries.javax.annotation + + implementation project(':grpc-core'), + libraries.guava + + itImplementation project(':grpc-servlet-jakarta'), + project(':grpc-netty'), + project(':grpc-core').sourceSets.test.runtimeClasspath, + libraries.junit + itImplementation(project(':grpc-interop-testing')) { + // Avoid grpc-netty-shaded dependency + exclude group: 'io.grpc', module: 'grpc-alts' + exclude group: 'io.grpc', module: 'grpc-xds' + } + + tomcatTestImplementation 'org.apache.tomcat.embed:tomcat-embed-core:10.0.14' + + jettyTestImplementation "org.eclipse.jetty:jetty-servlet:11.0.7", + "org.eclipse.jetty.http2:http2-server:11.0.7" + + undertowTestImplementation 'io.undertow:undertow-servlet-jakartaee9:2.2.13.Final' +} + +// Set up individual classpaths for each test, to avoid any mismatch, +// and ensure they are only used when supported by the current jvm +check.dependsOn(tasks.register('undertowTest', Test) { + classpath = sourceSets.undertowTest.runtimeClasspath + testClassesDirs = sourceSets.undertowTest.output.classesDirs +}) +check.dependsOn(tasks.register('tomcat10Test', Test) { + classpath = sourceSets.tomcatTest.runtimeClasspath + testClassesDirs = sourceSets.tomcatTest.output.classesDirs + + // Provide a temporary directory for tomcat to be deleted after test finishes + def tomcatTempDir = "$buildDir/tomcat_catalina_base" + systemProperty 'catalina.base', tomcatTempDir + doLast { + file(tomcatTempDir).deleteDir() + } + + // tomcat-embed-core 10 presently performs illegal reflective access on + // java.io.ObjectStreamClass$Caches.localDescs and sun.rmi.transport.Target.ccl, + // see https://lists.apache.org/thread/s0xr7tk2kfkkxfjps9n7dhh4cypfdhyy + if (JavaVersion.current().isJava9Compatible()) { + jvmArgs += ['--add-opens=java.base/java.io=ALL-UNNAMED', '--add-opens=java.rmi/sun.rmi.transport=ALL-UNNAMED'] + } +}) +// Only run these tests if java 11+ is being used +if (JavaVersion.current().isJava11Compatible()) { + check.dependsOn(tasks.register('jetty11Test', Test) { + classpath = sourceSets.jettyTest.runtimeClasspath + testClassesDirs = sourceSets.jettyTest.output.classesDirs + }) +} diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/GrpcServletSmokeTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/GrpcServletSmokeTest.java new file mode 100644 index 00000000000..0208645706e --- /dev/null +++ b/servlet/src/jettyTest/java/io/grpc/servlet/GrpcServletSmokeTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.grpc.BindableService; +import io.grpc.Channel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import io.grpc.testing.integration.TestServiceGrpc; +import io.grpc.testing.integration.TestServiceImpl; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.http2.parser.RateControl; +import org.eclipse.jetty.http2.server.HTTP2CServerConnectionFactory; +import org.eclipse.jetty.server.HttpConfiguration; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Smoke test for {@link GrpcServlet}. */ +@RunWith(JUnit4.class) +public class GrpcServletSmokeTest { + private static final String HOST = "localhost"; + private static final String MYAPP = "/grpc.testing.TestService"; + + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private final ScheduledExecutorService scheduledExecutorService = + Executors.newSingleThreadScheduledExecutor(); + private int port; + private Server server; + + @Before + public void startServer() { + BindableService service = new TestServiceImpl(scheduledExecutorService); + GrpcServlet grpcServlet = new GrpcServlet(ImmutableList.of(service)); + server = new Server(0); + ServerConnector sc = (ServerConnector)server.getConnectors()[0]; + HTTP2CServerConnectionFactory factory = + new HTTP2CServerConnectionFactory(new HttpConfiguration()); + + // Explicitly disable safeguards against malicious clients, as some unit tests trigger this + factory.setRateControlFactory(new RateControl.Factory() {}); + + sc.addConnectionFactory(factory); + ServletContextHandler context = + new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath(MYAPP); + context.addServlet(new ServletHolder(grpcServlet), "/*"); + server.setHandler(context); + + try { + server.start(); + } catch (Exception e) { + throw new AssertionError(e); + } + + port = sc.getLocalPort(); + } + + @After + public void tearDown() { + scheduledExecutorService.shutdown(); + try { + server.stop(); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @Test + public void unaryCall() { + Channel channel = cleanupRule.register( + ManagedChannelBuilder.forAddress(HOST, port).usePlaintext().build()); + SimpleResponse response = TestServiceGrpc.newBlockingStub(channel).unaryCall( + SimpleRequest.newBuilder() + .setResponseSize(1234) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFromUtf8("hello foo"))) + .build()); + assertThat(response.getPayload().getBody().size()).isEqualTo(1234); + } + + @Test + public void httpGetRequest() throws Exception { + HttpClient httpClient = new HttpClient(); + try { + httpClient.start(); + ContentResponse response = + httpClient.GET("http://" + HOST + ":" + port + MYAPP + "/UnaryCall"); + assertThat(response.getStatus()).isEqualTo(405); + assertThat(response.getContentAsString()).contains("GET method not supported"); + } finally { + httpClient.stop(); + } + } +} diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyInteropTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyInteropTest.java new file mode 100644 index 00000000000..ebdf029fe27 --- /dev/null +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyInteropTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.testing.integration.AbstractInteropTest; +import org.eclipse.jetty.http2.parser.RateControl; +import org.eclipse.jetty.http2.server.HTTP2CServerConnectionFactory; +import org.eclipse.jetty.server.HttpConfiguration; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.After; + +public class JettyInteropTest extends AbstractInteropTest { + + private static final String HOST = "localhost"; + private static final String MYAPP = "/grpc.testing.TestService"; + private int port; + private Server server; + + @After + @Override + public void tearDown() { + super.tearDown(); + try { + server.stop(); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @Override + protected ServerBuilder getServerBuilder() { + return new ServletServerBuilder().maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + } + + @Override + protected void startServer(ServerBuilder builer) { + GrpcServlet grpcServlet = + new GrpcServlet(((ServletServerBuilder) builer).buildServletAdapter()); + server = new Server(0); + ServerConnector sc = (ServerConnector)server.getConnectors()[0]; + HTTP2CServerConnectionFactory factory = + new HTTP2CServerConnectionFactory(new HttpConfiguration()); + + // Explicitly disable safeguards against malicious clients, as some unit tests trigger this + factory.setRateControlFactory(new RateControl.Factory() {}); + + sc.addConnectionFactory(factory); + ServletContextHandler context = + new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath(MYAPP); + context.addServlet(new ServletHolder(grpcServlet), "/*"); + server.setHandler(context); + + try { + server.start(); + } catch (Exception e) { + throw new AssertionError(e); + } + + port = sc.getLocalPort(); + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + NettyChannelBuilder builder = + (NettyChannelBuilder) ManagedChannelBuilder.forAddress(HOST, port) + .usePlaintext() + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + InternalNettyChannelBuilder.setStatsEnabled(builder, false); + builder.intercept(createCensusStatsClientInterceptor()); + return builder; + } +} diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java new file mode 100644 index 00000000000..7941afc9b4d --- /dev/null +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java @@ -0,0 +1,249 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import io.grpc.InternalChannelz; +import io.grpc.InternalInstrumented; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.FakeClock; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerTransportListener; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NegotiationType; +import io.grpc.netty.NettyChannelBuilder; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.eclipse.jetty.http2.parser.RateControl; +import org.eclipse.jetty.http2.server.HTTP2CServerConnectionFactory; +import org.eclipse.jetty.server.HttpConfiguration; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.Ignore; +import org.junit.Test; + + +public class JettyTransportTest extends AbstractTransportTest { + private static final String MYAPP = "/service"; + + private final FakeClock fakeClock = new FakeClock(); + private Server jettyServer; + private int port; + + + @Override + protected InternalServer newServer(List streamTracerFactories) { + return new InternalServer() { + final InternalServer delegate = + new ServletServerBuilder().buildTransportServers(streamTracerFactories); + + @Override + public void start(ServerListener listener) throws IOException { + delegate.start(listener); + ScheduledExecutorService scheduler = fakeClock.getScheduledExecutorService(); + ServerTransportListener serverTransportListener = + listener.transportCreated(new ServletServerBuilder.ServerTransportImpl(scheduler)); + ServletAdapter adapter = + new ServletAdapter(serverTransportListener, streamTracerFactories, + Integer.MAX_VALUE); + GrpcServlet grpcServlet = new GrpcServlet(adapter); + + jettyServer = new Server(0); + ServerConnector sc = (ServerConnector) jettyServer.getConnectors()[0]; + HttpConfiguration httpConfiguration = new HttpConfiguration(); + + // Must be set for several tests to pass, so that the request handling can begin before + // content arrives. + httpConfiguration.setDelayDispatchUntilContent(false); + + HTTP2CServerConnectionFactory factory = + new HTTP2CServerConnectionFactory(httpConfiguration); + factory.setRateControlFactory(new RateControl.Factory() { + }); + sc.addConnectionFactory(factory); + ServletContextHandler context = + new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath(MYAPP); + context.addServlet(new ServletHolder(grpcServlet), "/*"); + jettyServer.setHandler(context); + + try { + jettyServer.start(); + } catch (Exception e) { + throw new AssertionError(e); + } + + port = sc.getLocalPort(); + } + + @Override + public void shutdown() { + delegate.shutdown(); + } + + @Override + public SocketAddress getListenSocketAddress() { + return delegate.getListenSocketAddress(); + } + + @Override + public InternalInstrumented getListenSocketStats() { + return delegate.getListenSocketStats(); + } + + @Override + public List getListenSocketAddresses() { + return delegate.getListenSocketAddresses(); + } + + @Nullable + @Override + public List> getListenSocketStatsList() { + return delegate.getListenSocketStatsList(); + } + }; + } + + @Override + protected InternalServer newServer(int port, + List streamTracerFactories) { + return newServer(streamTracerFactories); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder + // Although specified here, address is ignored because we never call build. + .forAddress("localhost", 0) + .flowControlWindow(65 * 1024) + .negotiationType(NegotiationType.PLAINTEXT); + InternalNettyChannelBuilder + .setTransportTracerFactory(nettyChannelBuilder, fakeClockTransportTracer); + ClientTransportFactory clientFactory = + InternalNettyChannelBuilder.buildTransportFactory(nettyChannelBuilder); + return clientFactory.newClientTransport( + new InetSocketAddress("localhost", port), + new ClientTransportFactory.ClientTransportOptions() + .setAuthority(testAuthority(server)) + .setEagAttributes(eagAttrs()), + transportLogger()); + } + + @Override + protected String testAuthority(InternalServer server) { + return "localhost:" + port; + } + + @Override + protected void advanceClock(long offset, TimeUnit unit) { + fakeClock.forwardTime(offset, unit); + } + + @Override + protected long fakeCurrentTimeNanos() { + return fakeClock.getTicker().read(); + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverAlreadyListening() { + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void openStreamPreventsTermination() { + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void shutdownNowKillsServerStream() { + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverNotListening() { + } + + // FIXME + @Override + @Ignore("Servlet flow control not implemented yet") + @Test + public void flowControlPushBack() { + } + + // FIXME + @Override + @Ignore("Jetty is broken on client RST_STREAM") + @Test + public void shutdownNowKillsClientStream() { + } + + @Override + @Ignore("Server side sockets are managed by the servlet container") + @Test + public void socketStats() { + } + + @Override + @Ignore("serverTransportListener will not terminate") + @Test + public void clientStartAndStopOnceConnected() { + } + + @Override + @Ignore("clientStreamTracer1.getInboundTrailers() is not null; listeners.poll() doesn't apply") + @Test + public void serverCancel() { + } + + @Override + @Ignore("This doesn't apply: Ensure that for a closed ServerStream, interactions are noops") + @Test + public void interactionsAfterServerStreamCloseAreNoops() { + } + + @Override + @Ignore("listeners.poll() doesn't apply") + @Test + public void interactionsAfterClientStreamCancelAreNoops() { + } + + @Override + @Ignore("assertNull(serverStatus.getCause()) isn't true") + @Test + public void clientCancel() { + } + + @Override + @Ignore("regression since bumping grpc v1.46 to v1.53") + @Test + public void messageProducerOnlyProducesRequestedMessages() {} +} diff --git a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java new file mode 100644 index 00000000000..4f4e37fda87 --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java @@ -0,0 +1,282 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.servlet.ServletServerStream.toHexString; +import static java.util.logging.Level.FINE; +import static java.util.logging.Level.FINEST; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.InternalLogId; +import io.grpc.servlet.ServletServerStream.ServletTransportState; +import java.io.IOException; +import java.time.Duration; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; +import java.util.function.BiFunction; +import java.util.function.BooleanSupplier; +import java.util.logging.Logger; +import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.servlet.AsyncContext; +import javax.servlet.ServletOutputStream; + +/** Handles write actions from the container thread and the application thread. */ +final class AsyncServletOutputStreamWriter { + + /** + * Memory boundary for write actions. + * + *

    +   * WriteState curState = writeState.get();  // mark a boundary
    +   * doSomething();  // do something within the boundary
    +   * boolean successful = writeState.compareAndSet(curState, newState); // try to mark a boundary
    +   * if (successful) {
    +   *   // state has not changed since
    +   *   return;
    +   * } else {
    +   *   // state is changed by another thread while doSomething(), need recompute
    +   * }
    +   * 
    + * + *

    There are two threads, the container thread (calling {@code onWritePossible()}) and the + * application thread (calling {@code runOrBuffer()}) that read and update the + * writeState. Only onWritePossible() may turn {@code readyAndDrained} from false to true, and + * only runOrBuffer() may turn it from true to false. + */ + private final AtomicReference writeState = new AtomicReference<>(WriteState.DEFAULT); + + private final Log log; + private final BiFunction writeAction; + private final ActionItem flushAction; + private final ActionItem completeAction; + private final BooleanSupplier isReady; + + /** + * New write actions will be buffered into this queue if the servlet output stream is not ready or + * the queue is not drained. + */ + // SPSC queue would do + private final Queue writeChain = new ConcurrentLinkedQueue<>(); + // for a theoretical race condition that onWritePossible() is called immediately after isReady() + // returns false and before writeState.compareAndSet() + @Nullable + private volatile Thread parkingThread; + + AsyncServletOutputStreamWriter( + AsyncContext asyncContext, + ServletTransportState transportState, + InternalLogId logId) throws IOException { + Logger logger = Logger.getLogger(AsyncServletOutputStreamWriter.class.getName()); + this.log = new Log() { + @Override + public void fine(String str, Object... params) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, "[" + logId + "]" + str, params); + } + } + + @Override + public void finest(String str, Object... params) { + if (logger.isLoggable(FINEST)) { + logger.log(FINEST, "[" + logId + "] " + str, params); + } + } + }; + + ServletOutputStream outputStream = asyncContext.getResponse().getOutputStream(); + this.writeAction = (byte[] bytes, Integer numBytes) -> () -> { + outputStream.write(bytes, 0, numBytes); + transportState.runOnTransportThread(() -> transportState.onSentBytes(numBytes)); + log.finest("outbound data: length={0}, bytes={1}", numBytes, toHexString(bytes, numBytes)); + }; + this.flushAction = () -> { + log.finest("flushBuffer"); + asyncContext.getResponse().flushBuffer(); + }; + this.completeAction = () -> { + log.fine("call is completing"); + transportState.runOnTransportThread( + () -> { + transportState.complete(); + asyncContext.complete(); + log.fine("call completed"); + }); + }; + this.isReady = () -> outputStream.isReady(); + } + + /** + * Constructor without java.util.logging and javax.servlet.* dependency, so that Lincheck can run. + * + * @param writeAction Provides an {@link ActionItem} to write given bytes with specified length. + * @param isReady Indicates whether the writer can write bytes at the moment (asynchronously). + */ + @VisibleForTesting + AsyncServletOutputStreamWriter( + BiFunction writeAction, + ActionItem flushAction, + ActionItem completeAction, + BooleanSupplier isReady, + Log log) { + this.writeAction = writeAction; + this.flushAction = flushAction; + this.completeAction = completeAction; + this.isReady = isReady; + this.log = log; + } + + /** Called from application thread. */ + void writeBytes(byte[] bytes, int numBytes) throws IOException { + runOrBuffer(writeAction.apply(bytes, numBytes)); + } + + /** Called from application thread. */ + void flush() throws IOException { + runOrBuffer(flushAction); + } + + /** Called from application thread. */ + void complete() { + try { + runOrBuffer(completeAction); + } catch (IOException ignore) { + // actually completeAction does not throw IOException + } + } + + /** Called from the container thread {@link javax.servlet.WriteListener#onWritePossible()}. */ + void onWritePossible() throws IOException { + log.finest("onWritePossible: ENTRY. The servlet output stream becomes ready"); + assureReadyAndDrainedTurnsFalse(); + while (isReady.getAsBoolean()) { + WriteState curState = writeState.get(); + + ActionItem actionItem = writeChain.poll(); + if (actionItem != null) { + actionItem.run(); + continue; + } + + if (writeState.compareAndSet(curState, curState.withReadyAndDrained(true))) { + // state has not changed since. + log.finest( + "onWritePossible: EXIT. All data available now is sent out and the servlet output" + + " stream is still ready"); + return; + } + // else, state changed by another thread (runOrBuffer()), need to drain the writeChain + // again + } + log.finest("onWritePossible: EXIT. The servlet output stream becomes not ready"); + } + + private void assureReadyAndDrainedTurnsFalse() { + // readyAndDrained should have been set to false already. + // Just in case due to a race condition readyAndDrained is still true at this moment and is + // being set to false by runOrBuffer() concurrently. + while (writeState.get().readyAndDrained) { + parkingThread = Thread.currentThread(); + LockSupport.parkNanos(Duration.ofMinutes(1).toNanos()); // should return immediately + } + parkingThread = null; + } + + /** + * Either execute the write action directly, or buffer the action and let the container thread + * drain it. + * + *

    Called from application thread. + */ + private void runOrBuffer(ActionItem actionItem) throws IOException { + WriteState curState = writeState.get(); + if (curState.readyAndDrained) { // write to the outputStream directly + actionItem.run(); + if (actionItem == completeAction) { + return; + } + if (!isReady.getAsBoolean()) { + boolean successful = + writeState.compareAndSet(curState, curState.withReadyAndDrained(false)); + LockSupport.unpark(parkingThread); + checkState(successful, "Bug: curState is unexpectedly changed by another thread"); + log.finest("the servlet output stream becomes not ready"); + } + } else { // buffer to the writeChain + writeChain.offer(actionItem); + if (!writeState.compareAndSet(curState, curState.withReadyAndDrained(false))) { + checkState( + writeState.get().readyAndDrained, + "Bug: onWritePossible() should have changed readyAndDrained to true, but not"); + ActionItem lastItem = writeChain.poll(); + if (lastItem != null) { + checkState(lastItem == actionItem, "Bug: lastItem != actionItem"); + runOrBuffer(lastItem); + } + } // state has not changed since + } + } + + /** Write actions, e.g. writeBytes, flush, complete. */ + @FunctionalInterface + @VisibleForTesting + interface ActionItem { + void run() throws IOException; + } + + @VisibleForTesting // Lincheck test can not run with java.util.logging dependency. + interface Log { + default void fine(String str, Object...params) {} + + default void finest(String str, Object...params) {} + } + + private static final class WriteState { + + static final WriteState DEFAULT = new WriteState(false); + + /** + * The servlet output stream is ready and the writeChain is empty. + * + *

    readyAndDrained turns from false to true when: + * {@code onWritePossible()} exits while currently there is no more data to write, but the last + * check of {@link javax.servlet.ServletOutputStream#isReady()} is true. + * + *

    readyAndDrained turns from true to false when: + * {@code runOrBuffer()} exits while either the action item is written directly to the + * servlet output stream and the check of {@link javax.servlet.ServletOutputStream#isReady()} + * right after that returns false, or the action item is buffered into the writeChain. + */ + final boolean readyAndDrained; + + WriteState(boolean readyAndDrained) { + this.readyAndDrained = readyAndDrained; + } + + /** + * Only {@code onWritePossible()} can set readyAndDrained to true, and only {@code + * runOrBuffer()} can set it to false. + */ + @CheckReturnValue + WriteState withReadyAndDrained(boolean readyAndDrained) { + return new WriteState(readyAndDrained); + } + } +} diff --git a/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java new file mode 100644 index 00000000000..a73b1fdfe6d --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java @@ -0,0 +1,80 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.BindableService; +import io.grpc.ExperimentalApi; +import java.io.IOException; +import java.util.List; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * A simple servlet backed by a gRPC server. Must set {@code asyncSupported} to true. The {@code + * /contextRoot/urlPattern} must match the gRPC services' path, which is + * "/full-service-name/short-method-name". + * + *

    The API is experimental. The authors would like to know more about the real usecases. Users + * are welcome to provide feedback by commenting on + * the tracking issue. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/5066") +public class GrpcServlet extends HttpServlet { + private static final long serialVersionUID = 1L; + + private final ServletAdapter servletAdapter; + + @VisibleForTesting + GrpcServlet(ServletAdapter servletAdapter) { + this.servletAdapter = servletAdapter; + } + + /** + * Instantiate the servlet serving the given list of gRPC services. ServerInterceptors can be + * added on each gRPC service by {@link + * io.grpc.ServerInterceptors#intercept(BindableService, io.grpc.ServerInterceptor...)} + */ + public GrpcServlet(List bindableServices) { + this(loadServices(bindableServices)); + } + + private static ServletAdapter loadServices(List bindableServices) { + ServletServerBuilder serverBuilder = new ServletServerBuilder(); + bindableServices.forEach(serverBuilder::addService); + return serverBuilder.buildServletAdapter(); + } + + @Override + protected final void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException { + servletAdapter.doGet(request, response); + } + + @Override + protected final void doPost(HttpServletRequest request, HttpServletResponse response) + throws IOException { + servletAdapter.doPost(request, response); + } + + @Override + public void destroy() { + servletAdapter.destroy(); + super.destroy(); + } +} diff --git a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java new file mode 100644 index 00000000000..5a567916f99 --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java @@ -0,0 +1,333 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; +import static java.util.logging.Level.FINE; +import static java.util.logging.Level.FINEST; + +import com.google.common.io.BaseEncoding; +import io.grpc.Attributes; +import io.grpc.ExperimentalApi; +import io.grpc.Grpc; +import io.grpc.InternalLogId; +import io.grpc.InternalMetadata; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ReadableBuffers; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * An adapter that transforms {@link HttpServletRequest} into gRPC request and lets a gRPC server + * process it, and transforms the gRPC response into {@link HttpServletResponse}. An adapter can be + * instantiated by {@link ServletServerBuilder#buildServletAdapter()}. + * + *

    In a servlet, calling {@link #doPost(HttpServletRequest, HttpServletResponse)} inside {@link + * javax.servlet.http.HttpServlet#doPost(HttpServletRequest, HttpServletResponse)} makes the servlet + * backed by the gRPC server associated with the adapter. The servlet must support Asynchronous + * Processing and must be deployed to a container that supports servlet 4.0 and enables HTTP/2. + * + *

    The API is experimental. The authors would like to know more about the real usecases. Users + * are welcome to provide feedback by commenting on + * the tracking issue. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/5066") +public final class ServletAdapter { + + static final Logger logger = Logger.getLogger(ServletAdapter.class.getName()); + + private final ServerTransportListener transportListener; + private final List streamTracerFactories; + private final int maxInboundMessageSize; + private final Attributes attributes; + + ServletAdapter( + ServerTransportListener transportListener, + List streamTracerFactories, + int maxInboundMessageSize) { + this.transportListener = transportListener; + this.streamTracerFactories = streamTracerFactories; + this.maxInboundMessageSize = maxInboundMessageSize; + attributes = transportListener.transportReady(Attributes.EMPTY); + } + + /** + * Call this method inside {@link javax.servlet.http.HttpServlet#doGet(HttpServletRequest, + * HttpServletResponse)} to serve gRPC GET request. + * + *

    This method is currently not implemented. + * + *

    Note that in rare case gRPC client sends GET requests. + * + *

    Do not modify {@code req} and {@code resp} before or after calling this method. However, + * calling {@code resp.setBufferSize()} before invocation is allowed. + */ + public void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { + resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "GET method not supported"); + } + + /** + * Call this method inside {@link javax.servlet.http.HttpServlet#doPost(HttpServletRequest, + * HttpServletResponse)} to serve gRPC POST request. + * + *

    Do not modify {@code req} and {@code resp} before or after calling this method. However, + * calling {@code resp.setBufferSize()} before invocation is allowed. + */ + public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { + checkArgument(req.isAsyncSupported(), "servlet does not support asynchronous operation"); + checkArgument(ServletAdapter.isGrpc(req), "the request is not a gRPC request"); + + InternalLogId logId = InternalLogId.allocate(ServletAdapter.class, null); + logger.log(FINE, "[{0}] RPC started", logId); + + AsyncContext asyncCtx = req.startAsync(req, resp); + + String method = req.getRequestURI().substring(1); // remove the leading "/" + Metadata headers = getHeaders(req); + + if (logger.isLoggable(FINEST)) { + logger.log(FINEST, "[{0}] method: {1}", new Object[] {logId, method}); + logger.log(FINEST, "[{0}] headers: {1}", new Object[] {logId, headers}); + } + + Long timeoutNanos = headers.get(TIMEOUT_KEY); + if (timeoutNanos == null) { + timeoutNanos = 0L; + } + asyncCtx.setTimeout(TimeUnit.NANOSECONDS.toMillis(timeoutNanos)); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext(streamTracerFactories, method, headers); + + ServletServerStream stream = new ServletServerStream( + asyncCtx, + statsTraceCtx, + maxInboundMessageSize, + attributes.toBuilder() + .set( + Grpc.TRANSPORT_ATTR_REMOTE_ADDR, + new InetSocketAddress(req.getRemoteHost(), req.getRemotePort())) + .set( + Grpc.TRANSPORT_ATTR_LOCAL_ADDR, + new InetSocketAddress(req.getLocalAddr(), req.getLocalPort())) + .build(), + getAuthority(req), + logId); + + transportListener.streamCreated(stream, method, headers); + stream.transportState().runOnTransportThread(stream.transportState()::onStreamAllocated); + + asyncCtx.getRequest().getInputStream() + .setReadListener(new GrpcReadListener(stream, asyncCtx, logId)); + asyncCtx.addListener(new GrpcAsyncListener(stream, logId)); + } + + // This method must use Enumeration and its members, since that is the only way to read headers + // from the servlet api. + @SuppressWarnings("JdkObsolete") + private static Metadata getHeaders(HttpServletRequest req) { + Enumeration headerNames = req.getHeaderNames(); + checkNotNull( + headerNames, "Servlet container does not allow HttpServletRequest.getHeaderNames()"); + List byteArrays = new ArrayList<>(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + Enumeration values = req.getHeaders(headerName); + if (values == null) { + continue; + } + while (values.hasMoreElements()) { + String value = values.nextElement(); + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + byteArrays.add(headerName.getBytes(StandardCharsets.US_ASCII)); + byteArrays.add(BaseEncoding.base64().decode(value)); + } else { + byteArrays.add(headerName.getBytes(StandardCharsets.US_ASCII)); + byteArrays.add(value.getBytes(StandardCharsets.US_ASCII)); + } + } + } + return InternalMetadata.newMetadata(byteArrays.toArray(new byte[][]{})); + } + + // This method must use HttpRequest#getRequestURL or HttpUtils#getRequestURL, both of which + // can only return StringBuffer instances + @SuppressWarnings("JdkObsolete") + private static String getAuthority(HttpServletRequest req) { + try { + return new URI(req.getRequestURL().toString()).getAuthority(); + } catch (URISyntaxException e) { + logger.log(FINE, "Error getting authority from the request URL {0}", req.getRequestURL()); + return req.getServerName() + ":" + req.getServerPort(); + } + } + + /** + * Call this method when the adapter is no longer needed. The gRPC server will be terminated. + */ + public void destroy() { + transportListener.transportTerminated(); + } + + private static final class GrpcAsyncListener implements AsyncListener { + final InternalLogId logId; + final ServletServerStream stream; + + GrpcAsyncListener(ServletServerStream stream, InternalLogId logId) { + this.stream = stream; + this.logId = logId; + } + + @Override + public void onComplete(AsyncEvent event) {} + + @Override + public void onTimeout(AsyncEvent event) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, String.format("[{%s}] Timeout: ", logId), event.getThrowable()); + } + // If the resp is not committed, cancel() to avoid being redirected to an error page. + // Else, the container will send RST_STREAM in the end. + if (!event.getAsyncContext().getResponse().isCommitted()) { + stream.cancel(Status.DEADLINE_EXCEEDED); + } else { + stream.transportState().runOnTransportThread( + () -> stream.transportState().transportReportStatus(Status.DEADLINE_EXCEEDED)); + } + } + + @Override + public void onError(AsyncEvent event) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, String.format("[{%s}] Error: ", logId), event.getThrowable()); + } + + // If the resp is not committed, cancel() to avoid being redirected to an error page. + // Else, the container will send RST_STREAM at the end. + if (!event.getAsyncContext().getResponse().isCommitted()) { + stream.cancel(Status.fromThrowable(event.getThrowable())); + } else { + stream.transportState().runOnTransportThread( + () -> stream.transportState().transportReportStatus( + Status.fromThrowable(event.getThrowable()))); + } + } + + @Override + public void onStartAsync(AsyncEvent event) {} + } + + private static final class GrpcReadListener implements ReadListener { + final ServletServerStream stream; + final AsyncContext asyncCtx; + final ServletInputStream input; + final InternalLogId logId; + + GrpcReadListener( + ServletServerStream stream, + AsyncContext asyncCtx, + InternalLogId logId) throws IOException { + this.stream = stream; + this.asyncCtx = asyncCtx; + input = asyncCtx.getRequest().getInputStream(); + this.logId = logId; + } + + final byte[] buffer = new byte[4 * 1024]; + + @Override + public void onDataAvailable() throws IOException { + logger.log(FINEST, "[{0}] onDataAvailable: ENTRY", logId); + + while (input.isReady()) { + int length = input.read(buffer); + if (length == -1) { + logger.log(FINEST, "[{0}] inbound data: read end of stream", logId); + return; + } else { + if (logger.isLoggable(FINEST)) { + logger.log( + FINEST, + "[{0}] inbound data: length = {1}, bytes = {2}", + new Object[] {logId, length, ServletServerStream.toHexString(buffer, length)}); + } + + byte[] copy = Arrays.copyOf(buffer, length); + stream.transportState().runOnTransportThread( + () -> stream.transportState().inboundDataReceived(ReadableBuffers.wrap(copy), false)); + } + } + + logger.log(FINEST, "[{0}] onDataAvailable: EXIT", logId); + } + + @Override + public void onAllDataRead() { + logger.log(FINE, "[{0}] onAllDataRead", logId); + stream.transportState().runOnTransportThread(() -> + stream.transportState().inboundDataReceived(ReadableBuffers.empty(), true)); + } + + @Override + public void onError(Throwable t) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, String.format("[{%s}] Error: ", logId), t); + } + // If the resp is not committed, cancel() to avoid being redirected to an error page. + // Else, the container will send RST_STREAM at the end. + if (!asyncCtx.getResponse().isCommitted()) { + stream.cancel(Status.fromThrowable(t)); + } else { + stream.transportState().runOnTransportThread( + () -> stream.transportState() + .transportReportStatus(Status.fromThrowable(t))); + } + } + } + + /** + * Checks whether an incoming {@code HttpServletRequest} may come from a gRPC client. + * + * @return true if the request comes from a gRPC client + */ + public static boolean isGrpc(HttpServletRequest request) { + return request.getContentType() != null + && request.getContentType().contains(GrpcUtil.CONTENT_TYPE_GRPC); + } +} diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java new file mode 100644 index 00000000000..3e852ea3c09 --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java @@ -0,0 +1,268 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.Attributes; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerBuilder; +import io.grpc.Internal; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalInstrumented; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ServerImplBuilder; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.SharedResourceHolder; +import java.io.File; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Builder to build a gRPC server that can run as a servlet. This is for advanced custom settings. + * Normally, users should consider extending the out-of-box {@link GrpcServlet} directly instead. + * + *

    The API is experimental. The authors would like to know more about the real usecases. Users + * are welcome to provide feedback by commenting on + * the tracking issue. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/5066") +@NotThreadSafe +public final class ServletServerBuilder extends ForwardingServerBuilder { + List streamTracerFactories; + int maxInboundMessageSize = DEFAULT_MAX_MESSAGE_SIZE; + + private final ServerImplBuilder serverImplBuilder; + + private ScheduledExecutorService scheduler; + private boolean internalCaller; + private boolean usingCustomScheduler; + private InternalServerImpl internalServer; + + public ServletServerBuilder() { + serverImplBuilder = new ServerImplBuilder(this::buildTransportServers); + } + + /** + * Builds a gRPC server that can run as a servlet. + * + *

    The returned server will not be started or bound to a port. + * + *

    Users should not call this method directly. Instead users should call + * {@link #buildServletAdapter()} which internally will call {@code build()} and {@code start()} + * appropriately. + * + * @throws IllegalStateException if this method is called by users directly + */ + @Override + public Server build() { + checkState(internalCaller, "build() method should not be called directly by an application"); + return super.build(); + } + + /** + * Creates a {@link ServletAdapter}. + */ + public ServletAdapter buildServletAdapter() { + return new ServletAdapter(buildAndStart(), streamTracerFactories, maxInboundMessageSize); + } + + private ServerTransportListener buildAndStart() { + Server server; + try { + internalCaller = true; + server = build().start(); + } catch (IOException e) { + // actually this should never happen + throw new RuntimeException(e); + } finally { + internalCaller = false; + } + + if (!usingCustomScheduler) { + scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); + } + + // Create only one "transport" for all requests because it has no knowledge of which request is + // associated with which client socket. This "transport" does not do socket connection, the + // container does. + ServerTransportImpl serverTransport = new ServerTransportImpl(scheduler); + ServerTransportListener delegate = + internalServer.serverListener.transportCreated(serverTransport); + return new ServerTransportListener() { + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) { + delegate.streamCreated(stream, method, headers); + } + + @Override + public Attributes transportReady(Attributes attributes) { + return delegate.transportReady(attributes); + } + + @Override + public void transportTerminated() { + server.shutdown(); + delegate.transportTerminated(); + if (!usingCustomScheduler) { + SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, scheduler); + } + } + }; + } + + @VisibleForTesting + InternalServer buildTransportServers( + List streamTracerFactories) { + checkNotNull(streamTracerFactories, "streamTracerFactories"); + this.streamTracerFactories = streamTracerFactories; + internalServer = new InternalServerImpl(); + return internalServer; + } + + @Internal + @Override + protected ServerBuilder delegate() { + return serverImplBuilder; + } + + /** + * Throws {@code UnsupportedOperationException}. TLS should be configured by the servlet + * container. + */ + @Override + public ServletServerBuilder useTransportSecurity(File certChain, File privateKey) { + throw new UnsupportedOperationException("TLS should be configured by the servlet container"); + } + + @Override + public ServletServerBuilder maxInboundMessageSize(int bytes) { + checkArgument(bytes >= 0, "bytes must be >= 0"); + maxInboundMessageSize = bytes; + return this; + } + + /** + * Provides a custom scheduled executor service to the server builder. + * + * @return this + */ + public ServletServerBuilder scheduledExecutorService(ScheduledExecutorService scheduler) { + this.scheduler = checkNotNull(scheduler, "scheduler"); + usingCustomScheduler = true; + return this; + } + + private static final class InternalServerImpl implements InternalServer { + + ServerListener serverListener; + + InternalServerImpl() {} + + @Override + public void start(ServerListener listener) { + serverListener = listener; + } + + @Override + public void shutdown() { + if (serverListener != null) { + serverListener.serverShutdown(); + } + } + + @Override + public SocketAddress getListenSocketAddress() { + return new SocketAddress() { + @Override + public String toString() { + return "ServletServer"; + } + }; + } + + @Override + public InternalInstrumented getListenSocketStats() { + // sockets are managed by the servlet container, grpc is ignorant of that + return null; + } + + @Override + public List getListenSocketAddresses() { + return Collections.emptyList(); + } + + @Nullable + @Override + public List> getListenSocketStatsList() { + return null; + } + } + + @VisibleForTesting + static final class ServerTransportImpl implements ServerTransport { + + private final InternalLogId logId = InternalLogId.allocate(ServerTransportImpl.class, null); + private final ScheduledExecutorService scheduler; + + ServerTransportImpl(ScheduledExecutorService scheduler) { + this.scheduler = checkNotNull(scheduler, "scheduler"); + } + + @Override + public void shutdown() {} + + @Override + public void shutdownNow(Status reason) {} + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return scheduler; + } + + @Override + public ListenableFuture getStats() { + // does not support instrumentation + return null; + } + + @Override + public InternalLogId getLogId() { + return logId; + } + } +} diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java new file mode 100644 index 00000000000..0415eea942e --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java @@ -0,0 +1,336 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_GRPC; +import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_KEY; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.logging.Level.FINE; +import static java.util.logging.Level.FINEST; +import static java.util.logging.Level.WARNING; + +import com.google.common.io.BaseEncoding; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Attributes; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.internal.AbstractServerStream; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.SerializingExecutor; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TransportFrameUtil; +import io.grpc.internal.TransportTracer; +import io.grpc.internal.WritableBuffer; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import javax.servlet.AsyncContext; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletResponse; + +final class ServletServerStream extends AbstractServerStream { + + private static final Logger logger = Logger.getLogger(ServletServerStream.class.getName()); + + private final ServletTransportState transportState; + private final Sink sink = new Sink(); + private final AsyncContext asyncCtx; + private final HttpServletResponse resp; + private final Attributes attributes; + private final String authority; + private final InternalLogId logId; + private final AsyncServletOutputStreamWriter writer; + + ServletServerStream( + AsyncContext asyncCtx, + StatsTraceContext statsTraceCtx, + int maxInboundMessageSize, + Attributes attributes, + String authority, + InternalLogId logId) throws IOException { + super(ByteArrayWritableBuffer::new, statsTraceCtx); + transportState = + new ServletTransportState(maxInboundMessageSize, statsTraceCtx, new TransportTracer()); + this.attributes = attributes; + this.authority = authority; + this.logId = logId; + this.asyncCtx = asyncCtx; + this.resp = (HttpServletResponse) asyncCtx.getResponse(); + this.writer = new AsyncServletOutputStreamWriter( + asyncCtx, transportState, logId); + resp.getOutputStream().setWriteListener(new GrpcWriteListener()); + } + + @Override + protected ServletTransportState transportState() { + return transportState; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public String getAuthority() { + return authority; + } + + @Override + public int streamId() { + return -1; + } + + @Override + protected Sink abstractServerStreamSink() { + return sink; + } + + private void writeHeadersToServletResponse(Metadata metadata) { + // Discard any application supplied duplicates of the reserved headers + metadata.discardAll(CONTENT_TYPE_KEY); + metadata.discardAll(GrpcUtil.TE_HEADER); + metadata.discardAll(GrpcUtil.USER_AGENT_KEY); + + if (logger.isLoggable(FINE)) { + logger.log(FINE, "[{0}] writeHeaders {1}", new Object[] {logId, metadata}); + } + + resp.setStatus(HttpServletResponse.SC_OK); + resp.setContentType(CONTENT_TYPE_GRPC); + + byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(metadata); + for (int i = 0; i < serializedHeaders.length; i += 2) { + resp.addHeader( + new String(serializedHeaders[i], StandardCharsets.US_ASCII), + new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII)); + } + } + + final class ServletTransportState extends TransportState { + + private final SerializingExecutor transportThreadExecutor = + new SerializingExecutor(MoreExecutors.directExecutor()); + + private ServletTransportState( + int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer) { + super(maxMessageSize, statsTraceCtx, transportTracer); + } + + @Override + public void runOnTransportThread(Runnable r) { + transportThreadExecutor.execute(r); + } + + @Override + public void bytesRead(int numBytes) { + // no-op + // no flow control yet + } + + @Override + public void deframeFailed(Throwable cause) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, String.format("[{%s}] Exception processing message", logId), cause); + } + cancel(Status.fromThrowable(cause)); + } + } + + private static final class ByteArrayWritableBuffer implements WritableBuffer { + + private final int capacity; + final byte[] bytes; + private int index; + + ByteArrayWritableBuffer(int capacityHint) { + this.bytes = new byte[min(1024 * 1024, max(4096, capacityHint))]; + this.capacity = bytes.length; + } + + @Override + public void write(byte[] src, int srcIndex, int length) { + System.arraycopy(src, srcIndex, bytes, index, length); + index += length; + } + + @Override + public void write(byte b) { + bytes[index++] = b; + } + + @Override + public int writableBytes() { + return capacity - index; + } + + @Override + public int readableBytes() { + return index; + } + + @Override + public void release() {} + } + + private final class GrpcWriteListener implements WriteListener { + + @Override + public void onError(Throwable t) { + if (logger.isLoggable(FINE)) { + logger.log(FINE, String.format("[{%s}] Error: ", logId), t); + } + + // If the resp is not committed, cancel() to avoid being redirected to an error page. + // Else, the container will send RST_STREAM at the end. + if (!resp.isCommitted()) { + cancel(Status.fromThrowable(t)); + } else { + transportState.runOnTransportThread( + () -> transportState.transportReportStatus(Status.fromThrowable(t))); + } + } + + @Override + public void onWritePossible() throws IOException { + writer.onWritePossible(); + } + } + + private final class Sink implements AbstractServerStream.Sink { + final TrailerSupplier trailerSupplier = new TrailerSupplier(); + + @Override + public void writeHeaders(Metadata headers) { + writeHeadersToServletResponse(headers); + resp.setTrailerFields(trailerSupplier); + try { + writer.flush(); + } catch (IOException e) { + logger.log(WARNING, String.format("[{%s}] Exception when flushBuffer", logId), e); + cancel(Status.fromThrowable(e)); + } + } + + @Override + public void writeFrame(@Nullable WritableBuffer frame, boolean flush, int numMessages) { + if (frame == null && !flush) { + return; + } + + if (logger.isLoggable(FINEST)) { + logger.log( + FINEST, + "[{0}] writeFrame: numBytes = {1}, flush = {2}, numMessages = {3}", + new Object[]{logId, frame == null ? 0 : frame.readableBytes(), flush, numMessages}); + } + + try { + if (frame != null) { + int numBytes = frame.readableBytes(); + if (numBytes > 0) { + onSendingBytes(numBytes); + } + writer.writeBytes(((ByteArrayWritableBuffer) frame).bytes, frame.readableBytes()); + } + + if (flush) { + writer.flush(); + } + } catch (IOException e) { + logger.log(WARNING, String.format("[{%s}] Exception writing message", logId), e); + cancel(Status.fromThrowable(e)); + } + } + + @Override + public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { + if (logger.isLoggable(FINE)) { + logger.log( + FINE, + "[{0}] writeTrailers: {1}, headersSent = {2}, status = {3}", + new Object[] {logId, trailers, headersSent, status}); + } + if (!headersSent) { + writeHeadersToServletResponse(trailers); + } else { + byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(trailers); + for (int i = 0; i < serializedHeaders.length; i += 2) { + String key = new String(serializedHeaders[i], StandardCharsets.US_ASCII); + String newValue = new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII); + trailerSupplier.get().computeIfPresent(key, (k, v) -> v + "," + newValue); + trailerSupplier.get().putIfAbsent(key, newValue); + } + } + + writer.complete(); + } + + @Override + public void cancel(Status status) { + if (resp.isCommitted() && Code.DEADLINE_EXCEEDED == status.getCode()) { + return; // let the servlet timeout, the container will sent RST_STREAM automatically + } + transportState.runOnTransportThread(() -> transportState.transportReportStatus(status)); + // There is no way to RST_STREAM with CANCEL code, so write trailers instead + close(Status.CANCELLED.withCause(status.asRuntimeException()), new Metadata()); + CountDownLatch countDownLatch = new CountDownLatch(1); + transportState.runOnTransportThread(() -> { + asyncCtx.complete(); + countDownLatch.countDown(); + }); + try { + countDownLatch.await(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } + + private static final class TrailerSupplier implements Supplier> { + final Map trailers = Collections.synchronizedMap(new HashMap<>()); + + TrailerSupplier() {} + + @Override + public Map get() { + return trailers; + } + } + + static String toHexString(byte[] bytes, int length) { + String hex = BaseEncoding.base16().encode(bytes, 0, min(length, 64)); + if (length > 80) { + hex += "..."; + } + if (length > 64) { + int offset = max(64, length - 16); + hex += BaseEncoding.base16().encode(bytes, offset, length - offset); + } + return hex; + } +} diff --git a/servlet/src/main/java/io/grpc/servlet/package-info.java b/servlet/src/main/java/io/grpc/servlet/package-info.java new file mode 100644 index 00000000000..13d521fdde5 --- /dev/null +++ b/servlet/src/main/java/io/grpc/servlet/package-info.java @@ -0,0 +1,26 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * API that implements gRPC server as a servlet. The API requires that the application container + * supports Servlet 4.0 and enables HTTP/2. + * + *

    The API is experimental. The authors would like to know more about the real usecases. Users + * are welcome to provide feedback by commenting on + * the tracking issue. + */ +@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/5066") +package io.grpc.servlet; diff --git a/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java b/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java new file mode 100644 index 00000000000..61da2bf4c69 --- /dev/null +++ b/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java @@ -0,0 +1,174 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static com.google.common.truth.Truth.assertWithMessage; +import static org.jetbrains.kotlinx.lincheck.strategy.managed.ManagedStrategyGuaranteeKt.forClasses; + +import io.grpc.servlet.AsyncServletOutputStreamWriter.ActionItem; +import io.grpc.servlet.AsyncServletOutputStreamWriter.Log; +import java.io.IOException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import org.jetbrains.kotlinx.lincheck.LinChecker; +import org.jetbrains.kotlinx.lincheck.annotations.OpGroupConfig; +import org.jetbrains.kotlinx.lincheck.annotations.Operation; +import org.jetbrains.kotlinx.lincheck.annotations.Param; +import org.jetbrains.kotlinx.lincheck.paramgen.BooleanGen; +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingCTest; +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions; +import org.jetbrains.kotlinx.lincheck.verifier.VerifierState; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test concurrency correctness of {@link AsyncServletOutputStreamWriter} using model checking with + * Lincheck. + * + *

    This test should only call AsyncServletOutputStreamWriter's API surface and not rely on any + * implementation detail such as whether it's using a lock-free approach or not. + * + *

    The test executes two threads concurrently, one for write and flush, and the other for + * onWritePossible up to {@link #OPERATIONS_PER_THREAD} operations on each thread. Lincheck will + * test all possibly interleaves (on context switch) between the two threads, and then verify the + * operations are linearizable in each interleave scenario. + */ +@ModelCheckingCTest +@OpGroupConfig(name = "update", nonParallel = true) +@OpGroupConfig(name = "write", nonParallel = true) +@Param(name = "keepReady", gen = BooleanGen.class) +@RunWith(JUnit4.class) +public class AsyncServletOutputStreamWriterConcurrencyTest extends VerifierState { + private static final int OPERATIONS_PER_THREAD = 6; + + private final AsyncServletOutputStreamWriter writer; + private final boolean[] keepReadyArray = new boolean[OPERATIONS_PER_THREAD]; + + private volatile boolean isReady; + // when isReadyReturnedFalse, writer.onWritePossible() will be called. + private volatile boolean isReadyReturnedFalse; + private int producerIndex; + private int consumerIndex; + private int bytesWritten; + + /** Public no-args constructor. */ + public AsyncServletOutputStreamWriterConcurrencyTest() { + BiFunction writeAction = + (bytes, numBytes) -> () -> { + assertWithMessage("write should only be called while isReady() is true") + .that(isReady) + .isTrue(); + // The byte to be written must equal to consumerIndex, otherwise execution order is wrong + assertWithMessage("write in wrong order").that(bytes[0]).isEqualTo((byte) consumerIndex); + bytesWritten++; + writeOrFlush(); + }; + + ActionItem flushAction = () -> { + assertWithMessage("flush must only be called while isReady() is true").that(isReady).isTrue(); + writeOrFlush(); + }; + + writer = new AsyncServletOutputStreamWriter( + writeAction, + flushAction, + () -> { }, + this::isReady, + new Log() {}); + } + + private void writeOrFlush() { + boolean keepReady = keepReadyArray[consumerIndex]; + if (!keepReady) { + isReady = false; + } + consumerIndex++; + } + + private boolean isReady() { + if (!isReady) { + assertWithMessage("isReady() already returned false, onWritePossible() will be invoked") + .that(isReadyReturnedFalse).isFalse(); + isReadyReturnedFalse = true; + } + return isReady; + } + + /** + * Writes a single byte with value equal to {@link #producerIndex}. + * + * @param keepReady when the byte is written: + * the ServletOutputStream should remain ready if keepReady == true; + * the ServletOutputStream should become unready if keepReady == false. + */ + // @com.google.errorprone.annotations.Keep + @Operation(group = "write") + public void write(@Param(name = "keepReady") boolean keepReady) throws IOException { + keepReadyArray[producerIndex] = keepReady; + writer.writeBytes(new byte[]{(byte) producerIndex}, 1); + producerIndex++; + } + + /** + * Flushes the writer. + * + * @param keepReady when flushing: + * the ServletOutputStream should remain ready if keepReady == true; + * the ServletOutputStream should become unready if keepReady == false. + */ + // @com.google.errorprone.annotations.Keep // called by lincheck reflectively + @Operation(group = "write") + public void flush(@Param(name = "keepReady") boolean keepReady) throws IOException { + keepReadyArray[producerIndex] = keepReady; + writer.flush(); + producerIndex++; + } + + /** If the writer is not ready, let it turn ready and call writer.onWritePossible(). */ + // @com.google.errorprone.annotations.Keep // called by lincheck reflectively + @Operation(group = "update") + public void maybeOnWritePossible() throws IOException { + if (isReadyReturnedFalse) { + isReadyReturnedFalse = false; + isReady = true; + writer.onWritePossible(); + } + } + + @Override + protected Object extractState() { + return bytesWritten; + } + + @Test + public void linCheck() { + ModelCheckingOptions options = new ModelCheckingOptions() + .actorsBefore(0) + .threads(2) + .actorsPerThread(OPERATIONS_PER_THREAD) + .actorsAfter(0) + .addGuarantee( + forClasses( + ConcurrentLinkedQueue.class.getName(), + AtomicReference.class.getName()) + .allMethods() + .treatAsAtomic()); + LinChecker.check(AsyncServletOutputStreamWriterConcurrencyTest.class, options); + } +} diff --git a/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java new file mode 100644 index 00000000000..d571cfd45d5 --- /dev/null +++ b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import java.util.Enumeration; +import java.util.StringTokenizer; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import javax.servlet.AsyncContext; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test for {@link ServletServerBuilder}. */ +@RunWith(JUnit4.class) +public class ServletServerBuilderTest { + + @Test + public void scheduledExecutorService() throws Exception { + ScheduledExecutorService scheduler = mock(ScheduledExecutorService.class); + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + ServletInputStream inputStream = mock(ServletInputStream.class); + ServletOutputStream outputStream = mock(ServletOutputStream.class); + ScheduledFuture future = mock(ScheduledFuture.class); + + doReturn(future).when(scheduler).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + doReturn(true).when(request).isAsyncSupported(); + doReturn(asyncContext).when(request).startAsync(request, response); + doReturn("application/grpc").when(request).getContentType(); + doReturn("/hello/world").when(request).getRequestURI(); + @SuppressWarnings({"JdkObsolete", "unchecked"}) // Required by servlet API signatures. + // StringTokenizer is actually Enumeration + Enumeration headerNames = + (Enumeration) ((Enumeration) new StringTokenizer("grpc-timeout")); + @SuppressWarnings({"JdkObsolete", "unchecked"}) + Enumeration headers = + (Enumeration) ((Enumeration) new StringTokenizer("1m")); + doReturn(headerNames).when(request).getHeaderNames(); + doReturn(headers).when(request).getHeaders("grpc-timeout"); + doReturn(new StringBuffer("localhost:8080")).when(request).getRequestURL(); + doReturn(inputStream).when(request).getInputStream(); + doReturn("1.1.1.1").when(request).getLocalAddr(); + doReturn(8080).when(request).getLocalPort(); + doReturn("remote").when(request).getRemoteHost(); + doReturn(80).when(request).getRemotePort(); + doReturn(outputStream).when(response).getOutputStream(); + doReturn(request).when(asyncContext).getRequest(); + doReturn(response).when(asyncContext).getResponse(); + + ServletServerBuilder serverBuilder = + new ServletServerBuilder().scheduledExecutorService(scheduler); + ServletAdapter servletAdapter = serverBuilder.buildServletAdapter(); + servletAdapter.doPost(request, response); + + verify(asyncContext).setTimeout(1); + + // The following just verifies that scheduler is populated to the transport. + // It doesn't matter what tasks (such as handshake timeout and request deadline) are actually + // scheduled. + verify(scheduler, timeout(5000).atLeastOnce()) + .schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + } +} diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java new file mode 100644 index 00000000000..f28a7419286 --- /dev/null +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.testing.integration.AbstractInteropTest; +import java.io.File; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.apache.coyote.http2.Http2Protocol; +import org.apache.tomcat.util.http.fileupload.FileUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Interop test for Tomcat server and Netty client. + */ +public class TomcatInteropTest extends AbstractInteropTest { + + private static final String HOST = "localhost"; + private static final String MYAPP = "/grpc.testing.TestService"; + private int port; + private Tomcat server; + + @After + @Override + public void tearDown() { + super.tearDown(); + try { + server.stop(); + } catch (LifecycleException e) { + throw new AssertionError(e); + } + } + + @AfterClass + public static void cleanUp() throws Exception { + FileUtils.deleteDirectory(new File("tomcat.0")); + } + + @Override + protected ServerBuilder getServerBuilder() { + return new ServletServerBuilder().maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + } + + @Override + protected void startServer(ServerBuilder builer) { + server = new Tomcat(); + server.setPort(0); + Context ctx = server.addContext(MYAPP, new File("build/tmp").getAbsolutePath()); + Tomcat + .addServlet( + ctx, "io.grpc.servlet.TomcatInteropTest", + new GrpcServlet(((ServletServerBuilder) builer).buildServletAdapter())) + .setAsyncSupported(true); + ctx.addServletMappingDecoded("/*", "io.grpc.servlet.TomcatInteropTest"); + + // Explicitly disable safeguards against malicious clients, as some unit tests trigger these + Http2Protocol http2Protocol = new Http2Protocol(); + http2Protocol.setOverheadCountFactor(0); + http2Protocol.setOverheadWindowUpdateThreshold(0); + http2Protocol.setOverheadContinuationThreshold(0); + http2Protocol.setOverheadDataThreshold(0); + + server.getConnector().addUpgradeProtocol(http2Protocol); + try { + server.start(); + } catch (LifecycleException e) { + throw new RuntimeException(e); + } + + port = server.getConnector().getLocalPort(); + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + NettyChannelBuilder builder = + (NettyChannelBuilder) ManagedChannelBuilder.forAddress(HOST, port) + .usePlaintext() + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + InternalNettyChannelBuilder.setStatsEnabled(builder, false); + builder.intercept(createCensusStatsClientInterceptor()); + return builder; + } + + @Override + protected boolean metricsExpected() { + return false; // otherwise re-test will not work + } + + // FIXME + @Override + @Ignore("Tomcat is broken on client GOAWAY") + @Test + public void gracefulShutdown() {} + + // FIXME + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void specialStatusMessage() {} + + // FIXME + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void unimplementedMethod() {} + + // FIXME + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void statusCodeAndMessage() {} + + // FIXME + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void emptyStream() {} +} diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java new file mode 100644 index 00000000000..43c69e13fdd --- /dev/null +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java @@ -0,0 +1,270 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalInstrumented; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.FakeClock; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerTransportListener; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NegotiationType; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.servlet.ServletServerBuilder.ServerTransportImpl; +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.apache.coyote.http2.Http2Protocol; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Transport test for Tomcat server and Netty client. + */ +public class TomcatTransportTest extends AbstractTransportTest { + private static final String MYAPP = "/service"; + + private final FakeClock fakeClock = new FakeClock(); + + private Tomcat tomcatServer; + private int port; + + @After + @Override + public void tearDown() throws InterruptedException { + super.tearDown(); + try { + tomcatServer.stop(); + } catch (LifecycleException e) { + throw new AssertionError(e); + } + } + + @Override + protected InternalServer newServer(List streamTracerFactories) { + return new InternalServer() { + final InternalServer delegate = + new ServletServerBuilder().buildTransportServers(streamTracerFactories); + + @Override + public void start(ServerListener listener) throws IOException { + delegate.start(listener); + ScheduledExecutorService scheduler = fakeClock.getScheduledExecutorService(); + ServerTransportListener serverTransportListener = + listener.transportCreated(new ServerTransportImpl(scheduler)); + ServletAdapter adapter = + new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + GrpcServlet grpcServlet = new GrpcServlet(adapter); + + tomcatServer = new Tomcat(); + tomcatServer.setPort(0); + Context ctx = tomcatServer.addContext(MYAPP, new File("build/tmp").getAbsolutePath()); + Tomcat.addServlet(ctx, "TomcatTransportTest", grpcServlet) + .setAsyncSupported(true); + ctx.addServletMappingDecoded("/*", "TomcatTransportTest"); + tomcatServer.getConnector().addUpgradeProtocol(new Http2Protocol()); + try { + tomcatServer.start(); + } catch (LifecycleException e) { + throw new RuntimeException(e); + } + + port = tomcatServer.getConnector().getLocalPort(); + } + + @Override + public void shutdown() { + delegate.shutdown(); + } + + @Override + public SocketAddress getListenSocketAddress() { + return delegate.getListenSocketAddress(); + } + + @Override + public InternalInstrumented getListenSocketStats() { + return delegate.getListenSocketStats(); + } + + @Override + public List getListenSocketAddresses() { + return delegate.getListenSocketAddresses(); + } + + @Nullable + @Override + public List> getListenSocketStatsList() { + return delegate.getListenSocketStatsList(); + } + }; + } + + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + return newServer(streamTracerFactories); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder + // Although specified here, address is ignored because we never call build. + .forAddress("localhost", 0) + .flowControlWindow(65 * 1024) + .negotiationType(NegotiationType.PLAINTEXT); + InternalNettyChannelBuilder + .setTransportTracerFactory(nettyChannelBuilder, fakeClockTransportTracer); + ClientTransportFactory clientFactory = + InternalNettyChannelBuilder.buildTransportFactory(nettyChannelBuilder); + return clientFactory.newClientTransport( + new InetSocketAddress("localhost", port), + new ClientTransportFactory.ClientTransportOptions() + .setAuthority(testAuthority(server)) + .setEagAttributes(eagAttrs()), + transportLogger()); + } + + @Override + protected String testAuthority(InternalServer server) { + return "localhost:" + port; + } + + @Override + protected void advanceClock(long offset, TimeUnit unit) { + fakeClock.forwardNanos(unit.toNanos(offset)); + } + + @Override + protected long fakeCurrentTimeNanos() { + return fakeClock.getTicker().read(); + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverAlreadyListening() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void openStreamPreventsTermination() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void shutdownNowKillsServerStream() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverNotListening() {} + + // FIXME + @Override + @Ignore("Tomcat is broken on client GOAWAY") + @Test + public void newStream_duringShutdown() {} + + // FIXME + @Override + @Ignore("Tomcat is broken on client GOAWAY") + @Test + public void ping_duringShutdown() {} + + // FIXME + @Override + @Ignore("Tomcat is broken on client RST_STREAM") + @Test + public void frameAfterRstStreamShouldNotBreakClientChannel() {} + + // FIXME + @Override + @Ignore("Tomcat is broken on client RST_STREAM") + @Test + public void shutdownNowKillsClientStream() {} + + // FIXME + @Override + @Ignore("Servlet flow control not implemented yet") + @Test + public void flowControlPushBack() {} + + @Override + @Ignore("Server side sockets are managed by the servlet container") + @Test + public void socketStats() {} + + @Override + @Ignore("serverTransportListener will not terminate") + @Test + public void clientStartAndStopOnceConnected() {} + + @Override + @Ignore("clientStreamTracer1.getInboundTrailers() is not null; listeners.poll() doesn't apply") + @Test + public void serverCancel() {} + + @Override + @Ignore("This doesn't apply: Ensure that for a closed ServerStream, interactions are noops") + @Test + public void interactionsAfterServerStreamCloseAreNoops() {} + + @Override + @Ignore("listeners.poll() doesn't apply") + @Test + public void interactionsAfterClientStreamCancelAreNoops() {} + + @Override + @Ignore("assertNull(serverStatus.getCause()) isn't true") + @Test + public void clientCancel() {} + + @Override + @Ignore("Tomcat does not support trailers only") + @Test + public void earlyServerClose_noServerHeaders() {} + + @Override + @Ignore("Tomcat does not support trailers only") + @Test + public void earlyServerClose_serverFailure() {} + + @Override + @Ignore("Tomcat does not support trailers only") + @Test + public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() {} + + @Override + @Ignore("regression since bumping grpc v1.46 to v1.53") + @Test + public void messageProducerOnlyProducesRequestedMessages() {} +} diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowInteropTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowInteropTest.java new file mode 100644 index 00000000000..600400b14b8 --- /dev/null +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowInteropTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static io.undertow.servlet.Servlets.defaultContainer; +import static io.undertow.servlet.Servlets.deployment; +import static io.undertow.servlet.Servlets.servlet; + +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.testing.integration.AbstractInteropTest; +import io.undertow.Handlers; +import io.undertow.Undertow; +import io.undertow.UndertowOptions; +import io.undertow.server.HttpHandler; +import io.undertow.server.handlers.PathHandler; +import io.undertow.servlet.api.DeploymentInfo; +import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.InstanceFactory; +import io.undertow.servlet.util.ImmediateInstanceHandle; +import java.net.InetSocketAddress; +import javax.servlet.Servlet; +import javax.servlet.ServletException; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Interop test for Undertow server and Netty client. + */ +public class UndertowInteropTest extends AbstractInteropTest { + private static final String HOST = "localhost"; + private static final String MYAPP = "/grpc.testing.TestService"; + private int port; + private Undertow server; + private DeploymentManager manager; + + @After + @Override + public void tearDown() { + super.tearDown(); + if (server != null) { + server.stop(); + } + if (manager != null) { + try { + manager.stop(); + } catch (ServletException e) { + throw new AssertionError("failed to stop container", e); + } + } + } + + @Override + protected ServletServerBuilder getServerBuilder() { + return new ServletServerBuilder().maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + } + + @Override + protected void startServer(ServerBuilder builder) { + GrpcServlet grpcServlet = + new GrpcServlet(((ServletServerBuilder) builder).buildServletAdapter()); + InstanceFactory instanceFactory = + () -> new ImmediateInstanceHandle<>(grpcServlet); + DeploymentInfo servletBuilder = + deployment() + .setClassLoader(UndertowInteropTest.class.getClassLoader()) + .setContextPath(MYAPP) + .setDeploymentName("UndertowInteropTest.war") + .addServlets( + servlet("InteropTestServlet", GrpcServlet.class, instanceFactory) + .addMapping("/*") + .setAsyncSupported(true)); + + manager = defaultContainer().addDeployment(servletBuilder); + manager.deploy(); + + HttpHandler servletHandler; + try { + servletHandler = manager.start(); + } catch (ServletException e) { + throw new RuntimeException(e); + } + PathHandler path = Handlers.path(Handlers.redirect(MYAPP)) + .addPrefixPath("/", servletHandler); // for unimplementedService test + server = Undertow.builder() + .setServerOption(UndertowOptions.ENABLE_HTTP2, true) + .setServerOption(UndertowOptions.SHUTDOWN_TIMEOUT, 5000 /* 5 sec */) + .addHttpListener(0, HOST) + .setHandler(path) + .build(); + server.start(); + port = ((InetSocketAddress) server.getListenerInfo().get(0).getAddress()).getPort(); + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + NettyChannelBuilder builder = (NettyChannelBuilder) ManagedChannelBuilder + .forAddress(HOST, port) + .usePlaintext() + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + InternalNettyChannelBuilder.setStatsEnabled(builder, false); + builder.intercept(createCensusStatsClientInterceptor()); + return builder; + } + + // FIXME + @Override + @Ignore("Undertow is broken on client GOAWAY") + @Test + public void gracefulShutdown() {} +} diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java new file mode 100644 index 00000000000..9d894b5e3f2 --- /dev/null +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java @@ -0,0 +1,304 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static io.undertow.servlet.Servlets.defaultContainer; +import static io.undertow.servlet.Servlets.deployment; +import static io.undertow.servlet.Servlets.servlet; + +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalInstrumented; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.FakeClock; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerTransportListener; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NegotiationType; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.servlet.ServletServerBuilder.ServerTransportImpl; +import io.undertow.Handlers; +import io.undertow.Undertow; +import io.undertow.UndertowOptions; +import io.undertow.server.HttpHandler; +import io.undertow.server.handlers.PathHandler; +import io.undertow.servlet.api.DeploymentInfo; +import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.InstanceFactory; +import io.undertow.servlet.util.ImmediateInstanceHandle; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import javax.servlet.Servlet; +import javax.servlet.ServletException; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Transport test for Undertow server and Netty client. + */ +public class UndertowTransportTest extends AbstractTransportTest { + + private static final String HOST = "localhost"; + private static final String MYAPP = "/service"; + + private final FakeClock fakeClock = new FakeClock(); + + private Undertow undertowServer; + private DeploymentManager manager; + private int port; + + @After + @Override + public void tearDown() throws InterruptedException { + super.tearDown(); + if (undertowServer != null) { + undertowServer.stop(); + } + if (manager != null) { + try { + manager.stop(); + } catch (ServletException e) { + throw new AssertionError("failed to stop container", e); + } + } + } + + @Override + protected InternalServer newServer(List + streamTracerFactories) { + return new InternalServer() { + final InternalServer delegate = + new ServletServerBuilder().buildTransportServers(streamTracerFactories); + + @Override + public void start(ServerListener listener) throws IOException { + delegate.start(listener); + ScheduledExecutorService scheduler = fakeClock.getScheduledExecutorService(); + ServerTransportListener serverTransportListener = + listener.transportCreated(new ServerTransportImpl(scheduler)); + ServletAdapter adapter = + new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + GrpcServlet grpcServlet = new GrpcServlet(adapter); + InstanceFactory instanceFactory = + () -> new ImmediateInstanceHandle<>(grpcServlet); + DeploymentInfo servletBuilder = + deployment() + .setClassLoader(UndertowInteropTest.class.getClassLoader()) + .setContextPath(MYAPP) + .setDeploymentName("UndertowTransportTest.war") + .addServlets( + servlet("TransportTestServlet", GrpcServlet.class, instanceFactory) + .addMapping("/*") + .setAsyncSupported(true)); + + manager = defaultContainer().addDeployment(servletBuilder); + manager.deploy(); + + HttpHandler servletHandler; + try { + servletHandler = manager.start(); + } catch (ServletException e) { + throw new RuntimeException(e); + } + PathHandler path = + Handlers.path(Handlers.redirect(MYAPP)) + .addPrefixPath("/", servletHandler); // for unimplementedService test + undertowServer = + Undertow.builder() + .setServerOption(UndertowOptions.ENABLE_HTTP2, true) + .setServerOption(UndertowOptions.SHUTDOWN_TIMEOUT, 5000 /* 5 sec */) + .addHttpListener(0, HOST) + .setHandler(path) + .build(); + undertowServer.start(); + port = ((InetSocketAddress) undertowServer.getListenerInfo().get(0).getAddress()).getPort(); + } + + @Override + public void shutdown() { + delegate.shutdown(); + } + + @Override + public SocketAddress getListenSocketAddress() { + return delegate.getListenSocketAddress(); + } + + @Override + public InternalInstrumented getListenSocketStats() { + return delegate.getListenSocketStats(); + } + + @Override + public List getListenSocketAddresses() { + return delegate.getListenSocketAddresses(); + } + + @Nullable + @Override + public List> getListenSocketStatsList() { + return delegate.getListenSocketStatsList(); + } + }; + } + + @Override + protected InternalServer newServer(int port, + List streamTracerFactories) { + return newServer(streamTracerFactories); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder + // Although specified here, address is ignored because we never call build. + .forAddress("localhost", 0) + .flowControlWindow(65 * 1024) + .negotiationType(NegotiationType.PLAINTEXT); + InternalNettyChannelBuilder + .setTransportTracerFactory(nettyChannelBuilder, fakeClockTransportTracer); + ClientTransportFactory clientFactory = + InternalNettyChannelBuilder.buildTransportFactory(nettyChannelBuilder); + return clientFactory.newClientTransport( + new InetSocketAddress("localhost", port), + new ClientTransportFactory.ClientTransportOptions() + .setAuthority(testAuthority(server)) + .setEagAttributes(eagAttrs()), + transportLogger()); + } + + @Override + protected String testAuthority(InternalServer server) { + return "localhost:" + port; + } + + @Override + protected void advanceClock(long offset, TimeUnit unit) { + fakeClock.forwardNanos(unit.toNanos(offset)); + } + + @Override + protected long fakeCurrentTimeNanos() { + return fakeClock.getTicker().read(); + } + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverAlreadyListening() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void openStreamPreventsTermination() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void shutdownNowKillsServerStream() {} + + @Override + @Ignore("Skip the test, server lifecycle is managed by the container") + @Test + public void serverNotListening() {} + + @Override + @Ignore("Skip the test, can not set HTTP/2 SETTINGS_MAX_HEADER_LIST_SIZE") + @Test + public void serverChecksInboundMetadataSize() {} + + // FIXME + @Override + @Ignore("Undertow is broken on client GOAWAY") + @Test + public void newStream_duringShutdown() {} + + // FIXME + @Override + @Ignore("Undertow is broken on client GOAWAY") + @Test + public void ping_duringShutdown() {} + + // FIXME + @Override + @Ignore("Undertow is broken on client RST_STREAM") + @Test + public void frameAfterRstStreamShouldNotBreakClientChannel() {} + + // FIXME + @Override + @Ignore("Undertow is broken on client RST_STREAM") + @Test + public void shutdownNowKillsClientStream() {} + + // FIXME: https://github.com/grpc/grpc-java/issues/8925 + @Override + @Ignore("flaky") + @Test + public void clientCancelFromWithinMessageRead() {} + + // FIXME + @Override + @Ignore("Servlet flow control not implemented yet") + @Test + public void flowControlPushBack() {} + + @Override + @Ignore("Server side sockets are managed by the servlet container") + @Test + public void socketStats() {} + + @Override + @Ignore("serverTransportListener will not terminate") + @Test + public void clientStartAndStopOnceConnected() {} + + @Override + @Ignore("clientStreamTracer1.getInboundTrailers() is not null; listeners.poll() doesn't apply") + @Test + public void serverCancel() {} + + @Override + @Ignore("This doesn't apply: Ensure that for a closed ServerStream, interactions are noops") + @Test + public void interactionsAfterServerStreamCloseAreNoops() {} + + @Override + @Ignore("listeners.poll() doesn't apply") + @Test + public void interactionsAfterClientStreamCancelAreNoops() {} + + + @Override + @Ignore("assertNull(serverStatus.getCause()) isn't true") + @Test + public void clientCancel() {} + + @Override + @Ignore("regression since bumping grpc v1.46 to v1.53") + @Test + public void messageProducerOnlyProducesRequestedMessages() {} +} diff --git a/settings.gradle b/settings.gradle index cd50337d5c2..92db19a8839 100644 --- a/settings.gradle +++ b/settings.gradle @@ -4,13 +4,14 @@ pluginManagement { id "com.android.library" version "4.2.0" id "com.github.johnrengelman.shadow" version "7.1.2" id "com.github.kt3k.coveralls" version "2.12.0" - id "com.google.osdetector" version "1.7.0" - id "com.google.protobuf" version "0.8.18" + id "com.google.cloud.tools.jib" version "3.3.1" + id "com.google.osdetector" version "1.7.1" + id "com.google.protobuf" version "0.9.1" id "digital.wup.android-maven-publish" version "3.6.3" id "me.champeau.gradle.japicmp" version "0.3.0" - id "me.champeau.jmh" version "0.6.6" - id "net.ltgt.errorprone" version "2.0.2" - id "ru.vyarus.animalsniffer" version "1.5.4" + id "me.champeau.jmh" version "0.6.8" + id "net.ltgt.errorprone" version "3.0.1" + id "ru.vyarus.animalsniffer" version "1.6.0" } resolutionStrategy { eachPlugin { @@ -37,6 +38,7 @@ include ":grpc-protobuf" include ":grpc-protobuf-lite" include ":grpc-netty" include ":grpc-netty-shaded" +include ":grpc-googleapis" include ":grpc-grpclb" include ":grpc-testing" include ":grpc-testing-proto" @@ -46,11 +48,14 @@ include ":grpc-all" include ":grpc-alts" include ":grpc-benchmarks" include ":grpc-services" +include ":grpc-servlet" +include ":grpc-servlet-jakarta" include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" include ":grpc-authz" -include ":grpc-observability" +include ":grpc-gcp-observability" +include ":grpc-istio-interop-testing" project(':grpc-api').projectDir = "$rootDir/api" as File project(':grpc-core').projectDir = "$rootDir/core" as File @@ -63,6 +68,7 @@ project(':grpc-protobuf').projectDir = "$rootDir/protobuf" as File project(':grpc-protobuf-lite').projectDir = "$rootDir/protobuf-lite" as File project(':grpc-netty').projectDir = "$rootDir/netty" as File project(':grpc-netty-shaded').projectDir = "$rootDir/netty/shaded" as File +project(':grpc-googleapis').projectDir = "$rootDir/googleapis" as File project(':grpc-grpclb').projectDir = "$rootDir/grpclb" as File project(':grpc-testing').projectDir = "$rootDir/testing" as File project(':grpc-testing-proto').projectDir = "$rootDir/testing-proto" as File @@ -72,11 +78,14 @@ project(':grpc-all').projectDir = "$rootDir/all" as File project(':grpc-alts').projectDir = "$rootDir/alts" as File project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File project(':grpc-services').projectDir = "$rootDir/services" as File +project(':grpc-servlet').projectDir = "$rootDir/servlet" as File +project(':grpc-servlet-jakarta').projectDir = "$rootDir/servlet/jakarta" as File project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File project(':grpc-authz').projectDir = "$rootDir/authz" as File -project(':grpc-observability').projectDir = "$rootDir/observability" as File +project(':grpc-gcp-observability').projectDir = "$rootDir/gcp-observability" as File +project(':grpc-istio-interop-testing').projectDir = "$rootDir/istio-interop-testing" as File if (settings.hasProperty('skipCodegen') && skipCodegen.toBoolean()) { println '*** Skipping the build of codegen and compilation of proto files because skipCodegen=true' diff --git a/stub/build.gradle b/stub/build.gradle index 2b5a6a4edb6..16a9ca2d995 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -10,13 +10,13 @@ description = "gRPC: Stub" dependencies { api project(':grpc-api'), libraries.guava - implementation libraries.errorprone + implementation libraries.errorprone.annotations testImplementation libraries.truth, project(':grpc-testing') - signature "org.codehaus.mojo.signature:java17:1.0@signature" - signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" + signature libraries.signature.java + signature libraries.signature.android } -javadoc { +tasks.named("javadoc").configure { exclude 'io/grpc/stub/Internal*' } diff --git a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java index 3c090357473..c6f912cb3a7 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java @@ -35,7 +35,7 @@ public abstract class AbstractAsyncStub> extends AbstractStub { protected AbstractAsyncStub(Channel channel, CallOptions callOptions) { - super(channel, callOptions); + super(channel, callOptions); } /** diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 04ed83f083a..6986a285ae2 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -19,8 +19,10 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.CallOptions; @@ -39,6 +41,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.locks.LockSupport; import java.util.logging.Level; import java.util.logging.Logger; @@ -53,6 +56,11 @@ public final class ClientCalls { private static final Logger logger = Logger.getLogger(ClientCalls.class.getName()); + @VisibleForTesting + static boolean rejectRunnableOnExecutor = + !Strings.isNullOrEmpty(System.getenv("GRPC_CLIENT_CALL_REJECT_RUNNABLE")) + && Boolean.parseBoolean(System.getenv("GRPC_CLIENT_CALL_REJECT_RUNNABLE")); + // Prevent instantiation private ClientCalls() {} @@ -153,6 +161,7 @@ public static RespT blockingUnaryCall( // Now wait for onClose() to be called, so interceptors can clean up } } + executor.shutdown(); return getUnchecked(responseFuture); } catch (RuntimeException e) { // Something very bad happened. All bets are off; it may be dangerous to wait for onClose(). @@ -500,6 +509,7 @@ void onStart() { private static final class UnaryStreamToFuture extends StartableListener { private final GrpcFuture responseFuture; private RespT value; + private boolean isValueReceived = false; // Non private to avoid synthetic class UnaryStreamToFuture(GrpcFuture responseFuture) { @@ -512,17 +522,18 @@ public void onHeaders(Metadata headers) { @Override public void onMessage(RespT value) { - if (this.value != null) { + if (this.isValueReceived) { throw Status.INTERNAL.withDescription("More than one value received for unary call") .asRuntimeException(); } this.value = value; + this.isValueReceived = true; } @Override public void onClose(Status status, Metadata trailers) { if (status.isOk()) { - if (value == null) { + if (!isValueReceived) { // No value received so mark the future as an error responseFuture.setException( Status.INTERNAL.withDescription("No value received for unary call") @@ -626,6 +637,9 @@ private Object waitForNext() { // Now wait for onClose() to be called, so interceptors can clean up } } + if (next == this || next instanceof StatusRuntimeException) { + threadless.shutdown(); + } return next; } } finally { @@ -712,7 +726,10 @@ private static final class ThreadlessExecutor extends ConcurrentLinkedQueue extends StreamObserver { /** - * Called by the runtime priot to the start of a call to provide a reference to the + * Called by the runtime prior to the start of a call to provide a reference to the * {@link ClientCallStreamObserver} for the outbound stream. This can be used to listen to * onReady events, disable auto inbound flow and perform other advanced functions. * diff --git a/stub/src/main/java/io/grpc/stub/annotations/RpcMethod.java b/stub/src/main/java/io/grpc/stub/annotations/RpcMethod.java index fbf46baed87..3615eddfa79 100644 --- a/stub/src/main/java/io/grpc/stub/annotations/RpcMethod.java +++ b/stub/src/main/java/io/grpc/stub/annotations/RpcMethod.java @@ -25,7 +25,7 @@ /** * {@link RpcMethod} contains a limited subset of information about the RPC to assist * - * Java Annotation Processors. + * Java Annotation Processors. * *

    * This annotation is used by the gRPC stub compiler to annotate {@link MethodDescriptor} diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index c394fc09de6..d5cf572a8b9 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -58,6 +59,8 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -97,10 +100,12 @@ public class ClientCallsTest { private ArgumentCaptor> methodDescriptorCaptor; @Captor private ArgumentCaptor callOptionsCaptor; + private boolean originalRejectRunnableOnExecutor; @Before public void setUp() { MockitoAnnotations.initMocks(this); + originalRejectRunnableOnExecutor = ClientCalls.rejectRunnableOnExecutor; } @After @@ -111,6 +116,7 @@ public void tearDown() { if (channel != null) { channel.shutdownNow(); } + ClientCalls.rejectRunnableOnExecutor = originalRejectRunnableOnExecutor; } @Test @@ -217,6 +223,49 @@ class NoopUnaryMethod implements UnaryMethod { assertTrue("context not cancelled", methodImpl.observer.isCancelled()); } + @Test + public void blockingUnaryCall2_rejectExecutionOnClose() throws Exception { + Integer req = 2; + + class NoopUnaryMethod implements UnaryMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopUnaryMethod methodImpl = new NoopUnaryMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + try { + ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + assertNotNull("callOptionsExecutor should not be null", interceptor.savedExecutor); + ClientCalls.rejectRunnableOnExecutor = true; + try { + interceptor.savedExecutor.execute(() -> { }); + fail(); + } catch (Exception ex) { + assertTrue(ex instanceof RejectedExecutionException); + } + } + @Test public void blockingUnaryCall_HasBlockingStubType() { NoopClientCall call = new NoopClientCall() { @@ -858,6 +907,7 @@ class NoopServerStreamingMethod implements ServerStreamingMethod ClientCall interceptCall( @@ -867,6 +917,7 @@ public ClientCall interceptCall( super.start(new SimpleForwardingClientCallListener(listener) { @Override public void onClose(Status status, Metadata trailers) { onCloseCalled = true; + savedExecutor = callOptions.getExecutor(); super.onClose(status, trailers); } }, headers); diff --git a/testing-proto/build.gradle b/testing-proto/build.gradle index f12fe9250cc..168c059e66c 100644 --- a/testing-proto/build.gradle +++ b/testing-proto/build.gradle @@ -3,6 +3,7 @@ plugins { id "maven-publish" id "com.google.protobuf" + id "ru.vyarus.animalsniffer" } description = "gRPC: Testing Protos" @@ -10,9 +11,10 @@ description = "gRPC: Testing Protos" dependencies { api project(':grpc-protobuf'), project(':grpc-stub') - compileOnly libraries.javax_annotation + compileOnly libraries.javax.annotation testImplementation libraries.truth - testRuntimeOnly libraries.javax_annotation + testRuntimeOnly libraries.javax.annotation + signature libraries.signature.java } configureProtoCompilation() diff --git a/testing/build.gradle b/testing/build.gradle index 0879eca502a..35da61f11fc 100644 --- a/testing/build.gradle +++ b/testing/build.gradle @@ -3,6 +3,7 @@ plugins { id "maven-publish" id "me.champeau.gradle.japicmp" + id "ru.vyarus.animalsniffer" } description = "gRPC: Testing" @@ -14,21 +15,24 @@ dependencies { project(':grpc-stub'), libraries.junit // Only io.grpc.internal.testing.StatsTestUtils depends on opencensus_api, for internal use. - compileOnly libraries.opencensus_api + compileOnly libraries.opencensus.api runtimeOnly project(":grpc-context") // Pull in newer version than census-api - testImplementation (libraries.mockito) { + testImplementation (libraries.mockito.core) { // prefer our own versions instead of mockito's dependency exclude group: 'org.hamcrest', module: 'hamcrest-core' } testImplementation project(':grpc-testing-proto'), project(':grpc-core').sourceSets.test.output + + signature libraries.signature.java + signature libraries.signature.android } -javadoc { exclude 'io/grpc/internal/**' } +tasks.named("javadoc").configure { exclude 'io/grpc/internal/**' } -jacocoTestReport { +tasks.named("jacocoTestReport").configure { classDirectories.from = sourceSets.main.output.collect { fileTree(dir: it, exclude: [ diff --git a/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java b/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java index 444537962e8..1d0061f6120 100644 --- a/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java +++ b/testing/src/main/java/io/grpc/internal/testing/TestStreamTracer.java @@ -18,6 +18,7 @@ import io.grpc.Status; import io.grpc.StreamTracer; +import java.util.Locale; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -180,6 +181,7 @@ public void outboundMessageSent( int seqNo, long optionalWireSize, long optionalUncompressedSize) { outboundEvents.add( String.format( + Locale.US, "outboundMessageSent(%d, %d, %d)", seqNo, optionalWireSize, optionalUncompressedSize)); } @@ -189,6 +191,7 @@ public void inboundMessageRead( int seqNo, long optionalWireSize, long optionalUncompressedSize) { inboundEvents.add( String.format( + Locale.US, "inboundMessageRead(%d, %d, %d)", seqNo, optionalWireSize, optionalUncompressedSize)); } diff --git a/testing/src/main/java/io/grpc/testing/GrpcCleanupRule.java b/testing/src/main/java/io/grpc/testing/GrpcCleanupRule.java index 47a3416d4d3..f518dcb9e52 100644 --- a/testing/src/main/java/io/grpc/testing/GrpcCleanupRule.java +++ b/testing/src/main/java/io/grpc/testing/GrpcCleanupRule.java @@ -22,24 +22,23 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Ticker; +import com.google.common.collect.Lists; import io.grpc.ExperimentalApi; import io.grpc.ManagedChannel; import io.grpc.Server; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; -import org.junit.rules.TestRule; +import org.junit.rules.ExternalResource; import org.junit.runner.Description; -import org.junit.runners.model.MultipleFailureException; import org.junit.runners.model.Statement; /** - * A JUnit {@link TestRule} that can register gRPC resources and manages its automatic release at - * the end of the test. If any of the resources registered to the rule can not be successfully - * released, the test will fail. + * A JUnit {@link ExternalResource} that can register gRPC resources and manages its automatic + * release at the end of the test. If any of the resources registered to the rule can not be + * successfully released, the test will fail. * *

    Example usage: *

    {@code @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
    @@ -73,13 +72,13 @@
      */
     @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2488")
     @NotThreadSafe
    -public final class GrpcCleanupRule implements TestRule {
    +public final class GrpcCleanupRule extends ExternalResource {
     
       private final List resources = new ArrayList<>();
       private long timeoutNanos = TimeUnit.SECONDS.toNanos(10L);
       private Stopwatch stopwatch = Stopwatch.createUnstarted();
     
    -  private Throwable firstException;
    +  private boolean abruptShutdown;
     
       /**
        * Sets a positive total time limit for the automatic resource cleanup. If any of the resources
    @@ -144,71 +143,70 @@ void register(Resource resource) {
         resources.add(resource);
       }
     
    +  // The class extends ExternalResource so it can be used in JUnit 5. But JUnit 5 will only call
    +  // before() and after(), thus code cannot assume this method will be called.
       @Override
       public Statement apply(final Statement base, Description description) {
    -    return new Statement() {
    +    return super.apply(new Statement() {
           @Override
           public void evaluate() throws Throwable {
    -        firstException = null;
    +        abruptShutdown = false;
             try {
               base.evaluate();
             } catch (Throwable t) {
    -          firstException = t;
    -
    -          try {
    -            teardown();
    -          } catch (Throwable t2) {
    -            throw new MultipleFailureException(Arrays.asList(t, t2));
    -          }
    -
    +          abruptShutdown = true;
               throw t;
             }
    -
    -        teardown();
    -        if (firstException != null) {
    -          throw firstException;
    -        }
           }
    -    };
    +    }, description);
       }
     
       /**
        * Releases all the registered resources.
        */
    -  private void teardown() {
    +  @Override
    +  protected void after() {
         stopwatch.reset();
         stopwatch.start();
     
    -    if (firstException == null) {
    +    InterruptedException interrupted = null;
    +    if (!abruptShutdown) {
    +      for (Resource resource : Lists.reverse(resources)) {
    +        resource.cleanUp();
    +      }
    +
           for (int i = resources.size() - 1; i >= 0; i--) {
    -        resources.get(i).cleanUp();
    +        try {
    +          boolean released = resources.get(i).awaitReleased(
    +              timeoutNanos - stopwatch.elapsed(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS);
    +          if (released) {
    +            resources.remove(i);
    +          }
    +        } catch (InterruptedException e) {
    +          Thread.currentThread().interrupt();
    +          interrupted = e;
    +          break;
    +        }
           }
         }
     
    -    for (int i = resources.size() - 1; i >= 0; i--) {
    -      if (firstException != null) {
    -        resources.get(i).forceCleanUp();
    -        continue;
    +    if (!resources.isEmpty()) {
    +      for (Resource resource : Lists.reverse(resources)) {
    +        resource.forceCleanUp();
           }
     
           try {
    -        boolean released = resources.get(i).awaitReleased(
    -            timeoutNanos - stopwatch.elapsed(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS);
    -        if (!released) {
    -          firstException = new AssertionError(
    -              "Resource " + resources.get(i) + " can not be released in time at the end of test");
    +        if (interrupted != null) {
    +          throw new AssertionError(
    +              "Thread interrupted before resources gracefully released", interrupted);
    +        } else if (!abruptShutdown) {
    +          throw new AssertionError(
    +            "Resources could not be released in time at the end of test: " + resources);
             }
    -      } catch (InterruptedException e) {
    -        Thread.currentThread().interrupt();
    -        firstException = e;
    -      }
    -
    -      if (firstException != null) {
    -        resources.get(i).forceCleanUp();
    +      } finally {
    +        resources.clear();
           }
         }
    -
    -    resources.clear();
       }
     
       @VisibleForTesting
    diff --git a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java
    index 3be53ae8b8b..a5a6783d53f 100644
    --- a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java
    +++ b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java
    @@ -248,13 +248,13 @@ public void multiResource_awaitReleasedFails() throws Throwable {
     
         inOrder.verify(resource3).awaitReleased(anyLong(), any(TimeUnit.class));
         inOrder.verify(resource2).awaitReleased(anyLong(), any(TimeUnit.class));
    +    inOrder.verify(resource1).awaitReleased(anyLong(), any(TimeUnit.class));
         inOrder.verify(resource2).forceCleanUp();
    -    inOrder.verify(resource1).forceCleanUp();
     
         inOrder.verifyNoMoreInteractions();
     
         verify(resource3, never()).forceCleanUp();
    -    verify(resource1, never()).awaitReleased(anyLong(), any(TimeUnit.class));
    +    verify(resource1, never()).forceCleanUp();
       }
     
       @Test
    @@ -280,7 +280,7 @@ public void multiResource_awaitReleasedInterrupted() throws Throwable {
         boolean cleanupFailed = false;
         try {
           grpcCleanup.apply(statement, null /* description*/).evaluate();
    -    } catch (InterruptedException e) {
    +    } catch (Throwable e) {
           cleanupFailed = true;
         }
     
    @@ -381,7 +381,8 @@ public void multiResource_timeoutCalculation_customTimeout() throws Throwable {
       @Test
       public void baseTestFailsThenCleanupFails() throws Throwable {
         // setup
    -    Exception baseTestFailure = new Exception();
    +    Exception baseTestFailure = new Exception("base test failure");
    +    Exception cleanupFailure = new RuntimeException("force cleanup failed");
     
         Statement statement = mock(Statement.class);
         doThrow(baseTestFailure).when(statement).evaluate();
    @@ -389,7 +390,7 @@ public void baseTestFailsThenCleanupFails() throws Throwable {
         Resource resource1 = mock(Resource.class);
         Resource resource2 = mock(Resource.class);
         Resource resource3 = mock(Resource.class);
    -    doThrow(new RuntimeException()).when(resource2).forceCleanUp();
    +    doThrow(cleanupFailure).when(resource2).forceCleanUp();
     
         InOrder inOrder = inOrder(statement, resource1, resource2, resource3);
         GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
    @@ -407,8 +408,14 @@ public void baseTestFailsThenCleanupFails() throws Throwable {
         }
     
         // verify
    -    assertThat(failure).isInstanceOf(MultipleFailureException.class);
    -    assertSame(baseTestFailure, ((MultipleFailureException) failure).getFailures().get(0));
    +    if (failure instanceof MultipleFailureException) {
    +      // JUnit 4.13+
    +      assertThat(((MultipleFailureException) failure).getFailures())
    +          .containsExactly(baseTestFailure, cleanupFailure);
    +    } else {
    +      // JUnit 4.12. Suffers from https://github.com/junit-team/junit4/issues/1334
    +      assertThat(failure).isSameInstanceAs(cleanupFailure);
    +    }
     
         inOrder.verify(statement).evaluate();
         inOrder.verify(resource3).forceCleanUp();
    diff --git a/xds/BUILD.bazel b/xds/BUILD.bazel
    new file mode 100644
    index 00000000000..e62b183f9e8
    --- /dev/null
    +++ b/xds/BUILD.bazel
    @@ -0,0 +1,166 @@
    +load("//:java_grpc_library.bzl", "java_grpc_library")
    +
    +# Mirrors the dependencies included in the artifact on Maven Central for usage
    +# with maven_install's override_targets. Should only be used as a dep for
    +# pre-compiled binaries on Maven Central.
    +java_library(
    +    name = "xds_maven",
    +    visibility = ["//visibility:public"],
    +    exports = [
    +        ":orca",
    +        ":xds",
    +    ],
    +)
    +
    +java_library(
    +    name = "xds",
    +    srcs = glob(
    +        [
    +            "src/main/java/**/*.java",
    +            "third_party/zero-allocation-hashing/main/java/**/*.java",
    +        ],
    +        exclude = ["src/main/java/io/grpc/xds/orca/**"],
    +    ),
    +    resources = glob([
    +        "src/main/resources/**",
    +    ]),
    +    visibility = ["//visibility:public"],
    +    deps = [
    +        ":envoy_service_discovery_v2_java_grpc",
    +        ":envoy_service_discovery_v3_java_grpc",
    +        ":envoy_service_load_stats_v2_java_grpc",
    +        ":envoy_service_load_stats_v3_java_grpc",
    +        ":envoy_service_status_v3_java_grpc",
    +        ":xds_protos_java",
    +        "//:auto_value_annotations",
    +        "//alts",
    +        "//api",
    +        "//context",
    +        "//core:internal",
    +        "//core:util",
    +        "//netty",
    +        "//stub",
    +        "@com_google_code_findbugs_jsr305//jar",
    +        "@com_google_code_gson_gson//jar",
    +        "@com_google_errorprone_error_prone_annotations//jar",
    +        "@com_google_googleapis//google/rpc:rpc_java_proto",
    +        "@com_google_guava_guava//jar",
    +        "@com_google_protobuf//:protobuf_java",
    +        "@com_google_protobuf//:protobuf_java_util",
    +        "@com_google_re2j_re2j//jar",
    +        "@io_netty_netty_buffer//jar",
    +        "@io_netty_netty_codec//jar",
    +        "@io_netty_netty_common//jar",
    +        "@io_netty_netty_handler//jar",
    +        "@io_netty_netty_transport//jar",
    +    ],
    +)
    +
    +java_proto_library(
    +    name = "xds_protos_java",
    +    deps = [
    +        "@com_github_cncf_udpa//udpa/type/v1:pkg",
    +        "@com_github_cncf_xds//xds/data/orca/v3:pkg",
    +        "@com_github_cncf_xds//xds/service/orca/v3:pkg",
    +        "@com_github_cncf_xds//xds/type/v3:pkg",
    +        "@envoy_api//envoy/admin/v3:pkg",
    +        "@envoy_api//envoy/api/v2:pkg",
    +        "@envoy_api//envoy/api/v2/core:pkg",
    +        "@envoy_api//envoy/api/v2/endpoint:pkg",
    +        "@envoy_api//envoy/config/cluster/aggregate/v2alpha:pkg",
    +        "@envoy_api//envoy/config/cluster/v3:pkg",
    +        "@envoy_api//envoy/config/core/v3:pkg",
    +        "@envoy_api//envoy/config/endpoint/v3:pkg",
    +        "@envoy_api//envoy/config/filter/http/fault/v2:pkg",
    +        "@envoy_api//envoy/config/filter/http/router/v2:pkg",
    +        "@envoy_api//envoy/config/filter/network/http_connection_manager/v2:pkg",
    +        "@envoy_api//envoy/config/listener/v3:pkg",
    +        "@envoy_api//envoy/config/rbac/v3:pkg",
    +        "@envoy_api//envoy/config/route/v3:pkg",
    +        "@envoy_api//envoy/extensions/clusters/aggregate/v3:pkg",
    +        "@envoy_api//envoy/extensions/filters/common/fault/v3:pkg",
    +        "@envoy_api//envoy/extensions/filters/http/fault/v3:pkg",
    +        "@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg",
    +        "@envoy_api//envoy/extensions/filters/http/router/v3:pkg",
    +        "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg",
    +        "@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg",
    +        "@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg",
    +        "@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg",
    +        "@envoy_api//envoy/extensions/load_balancing_policies/wrr_locality/v3:pkg",
    +        "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg",
    +        "@envoy_api//envoy/service/discovery/v2:pkg",
    +        "@envoy_api//envoy/service/discovery/v3:pkg",
    +        "@envoy_api//envoy/service/load_stats/v2:pkg",
    +        "@envoy_api//envoy/service/load_stats/v3:pkg",
    +        "@envoy_api//envoy/service/status/v3:pkg",
    +        "@envoy_api//envoy/type/matcher/v3:pkg",
    +        "@envoy_api//envoy/type/v3:pkg",
    +    ],
    +)
    +
    +java_grpc_library(
    +    name = "envoy_service_discovery_v2_java_grpc",
    +    srcs = ["@envoy_api//envoy/service/discovery/v2:pkg"],
    +    deps = [":xds_protos_java"],
    +)
    +
    +java_grpc_library(
    +    name = "envoy_service_discovery_v3_java_grpc",
    +    srcs = ["@envoy_api//envoy/service/discovery/v3:pkg"],
    +    deps = [":xds_protos_java"],
    +)
    +
    +java_grpc_library(
    +    name = "envoy_service_load_stats_v2_java_grpc",
    +    srcs = ["@envoy_api//envoy/service/load_stats/v2:pkg"],
    +    deps = [":xds_protos_java"],
    +)
    +
    +java_grpc_library(
    +    name = "envoy_service_load_stats_v3_java_grpc",
    +    srcs = ["@envoy_api//envoy/service/load_stats/v3:pkg"],
    +    deps = [":xds_protos_java"],
    +)
    +
    +java_grpc_library(
    +    name = "envoy_service_status_v3_java_grpc",
    +    srcs = ["@envoy_api//envoy/service/status/v3:pkg"],
    +    deps = [":xds_protos_java"],
    +)
    +
    +java_library(
    +    name = "orca",
    +    srcs = glob([
    +        "src/main/java/io/grpc/xds/orca/*.java",
    +    ]),
    +    visibility = ["//visibility:public"],
    +    deps = [
    +        ":orca_protos_java",
    +        ":xds_service_orca_v3_java_grpc",
    +        "//api",
    +        "//context",
    +        "//core:internal",
    +        "//core:util",
    +        "//protobuf",
    +        "//services:metrics",
    +        "//services:metrics_internal",
    +        "//stub",
    +        "@com_google_code_findbugs_jsr305//jar",
    +        "@com_google_guava_guava//jar",
    +        "@com_google_protobuf//:protobuf_java_util",
    +    ],
    +)
    +
    +java_proto_library(
    +    name = "orca_protos_java",
    +    deps = [
    +        "@com_github_cncf_xds//xds/data/orca/v3:pkg",
    +        "@com_github_cncf_xds//xds/service/orca/v3:pkg",
    +    ],
    +)
    +
    +java_grpc_library(
    +    name = "xds_service_orca_v3_java_grpc",
    +    srcs = ["@com_github_cncf_xds//xds/service/orca/v3:pkg"],
    +    deps = [":orca_protos_java"],
    +)
    diff --git a/xds/build.gradle b/xds/build.gradle
    index 60bd4c7c60d..764dc530d97 100644
    --- a/xds/build.gradle
    +++ b/xds/build.gradle
    @@ -11,85 +11,135 @@ plugins {
     
     description = "gRPC: XDS plugin"
     
    -[compileJava].each() {
    -    it.options.compilerArgs += [
    -        // valueOf(int) in RoutingPriority has been deprecated
    -        "-Xlint:-deprecation",
    -        // only has AutoValue annotation processor
    -        "-Xlint:-processing",
    -    ]
    -    appendToProperty(
    -            it.options.errorprone.excludedPaths,
    -            ".*/build/generated/sources/annotationProcessor/java/.*",
    -            "|")
    +evaluationDependsOn(project(':grpc-core').path)
    +
    +sourceSets {
    +    thirdparty {
    +        java {
    +            srcDir "${projectDir}/third_party/zero-allocation-hashing/main/java"
    +        }
    +        proto {
    +            srcDir 'third_party/envoy/src/main/proto'
    +            srcDir 'third_party/protoc-gen-validate/src/main/proto'
    +            srcDir 'third_party/xds/src/main/proto'
    +            srcDir 'third_party/googleapis/src/main/proto'
    +            srcDir 'third_party/istio/src/main/proto'
    +        }
    +    }
    +    test {
    +        java {
    +            srcDir "${projectDir}/third_party/zero-allocation-hashing/test/java"
    +        }
    +    }
     }
     
    -evaluationDependsOn(project(':grpc-core').path)
    +configurations {
    +    pomDeps {
    +        extendsFrom configurations.thirdpartyRuntimeClasspath, configurations.shadow
    +    }
    +}
     
     dependencies {
    -    implementation project(':grpc-protobuf'),
    +    thirdpartyCompileOnly libraries.javax.annotation
    +    thirdpartyImplementation project(':grpc-protobuf'),
                 project(':grpc-stub'),
    +            libraries.opencensus.proto
    +    implementation sourceSets.thirdparty.output
    +    implementation project(':grpc-stub'),
                 project(':grpc-core'),
                 project(':grpc-services'),
                 project(':grpc-auth'),
                 project(path: ':grpc-alts', configuration: 'shadow'),
                 libraries.gson,
                 libraries.re2j,
    -            libraries.bouncycastle,
    -            libraries.autovalue_annotation,
    -            libraries.opencensus_proto,
    -            libraries.protobuf_util
    +            libraries.auto.value.annotations,
    +            libraries.protobuf.java.util
         def nettyDependency = implementation project(':grpc-netty')
     
         testImplementation project(':grpc-rls')
         testImplementation project(':grpc-core').sourceSets.test.output
     
    -    annotationProcessor libraries.autovalue
    -    compileOnly libraries.javax_annotation,
    -            // At runtime use the epoll included in grpc-netty-shaded
    -            libraries.netty_epoll
    +    annotationProcessor libraries.auto.value
    +    // At runtime use the epoll included in grpc-netty-shaded
    +    compileOnly libraries.netty.transport.epoll
     
         testImplementation project(':grpc-testing'),
    -            project(':grpc-testing-proto'),
    -            libraries.netty_epoll
    -    testImplementation (libraries.guava_testlib) {
    +            project(':grpc-testing-proto')
    +    testImplementation (libraries.netty.transport.epoll) {
    +        artifact {
    +            classifier = "linux-x86_64"
    +        }
    +    }
    +    testImplementation (libraries.guava.testlib) {
             exclude group: 'junit', module: 'junit'
         }
     
         shadow configurations.implementation.getDependencies().minus([nettyDependency])
         shadow project(path: ':grpc-netty-shaded', configuration: 'shadow')
     
    -    signature "org.codehaus.mojo.signature:java17:1.0@signature"
    -    testRuntimeOnly libraries.netty_tcnative
    -}
    -
    -sourceSets {
    -    main {
    -        java {
    -            srcDir "${projectDir}/third_party/zero-allocation-hashing/main/java"
    +    signature libraries.signature.java
    +    testRuntimeOnly libraries.netty.tcnative,
    +            libraries.netty.tcnative.classes
    +    testRuntimeOnly (libraries.netty.tcnative) {
    +        artifact {
    +            classifier = "linux-x86_64"
             }
    -        proto {
    -            srcDir 'third_party/envoy/src/main/proto'
    -            srcDir 'third_party/protoc-gen-validate/src/main/proto'
    -            srcDir 'third_party/xds/src/main/proto'
    -            srcDir 'third_party/googleapis/src/main/proto'
    -            srcDir 'third_party/istio/src/main/proto'
    +    }
    +    testRuntimeOnly (libraries.netty.tcnative) {
    +        artifact {
    +            classifier = "linux-aarch_64"
             }
         }
    -    test {
    -        java {
    -            srcDir "${projectDir}/third_party/zero-allocation-hashing/test/java"
    +    testRuntimeOnly (libraries.netty.tcnative) {
    +        artifact {
    +            classifier = "osx-x86_64"
    +        }
    +    }
    +    testRuntimeOnly (libraries.netty.tcnative) {
    +        artifact {
    +            classifier = "osx-aarch_64"
    +        }
    +    }
    +    testRuntimeOnly (libraries.netty.tcnative) {
    +        artifact {
    +            classifier = "windows-x86_64"
             }
         }
     }
     
     configureProtoCompilation()
     
    -jar {
    +tasks.named("compileThirdpartyJava").configure {
    +    options.errorprone.enabled = false
    +    options.compilerArgs += [
    +        // valueOf(int) in RoutingPriority has been deprecated
    +        "-Xlint:-deprecation",
    +    ]
    +}
    +
    +tasks.named("checkstyleThirdparty").configure {
    +    enabled = false
    +}
    +
    +tasks.named("compileJava").configure {
    +    it.options.compilerArgs += [
    +        // TODO: remove
    +        "-Xlint:-deprecation",
    +        // only has AutoValue annotation processor
    +        "-Xlint:-processing",
    +    ]
    +    appendToProperty(
    +            it.options.errorprone.excludedPaths,
    +            ".*/build/generated/sources/annotationProcessor/java/.*",
    +            "|")
    +}
    +
    +tasks.named("jar").configure {
         archiveClassifier = 'original'
    +    from sourceSets.thirdparty.output
     }
     
    -javadoc {
    +tasks.named("javadoc").configure {
         // Exclusions here should generally also be relocated
         exclude 'com/github/udpa/**'
         exclude 'com/github/xds/**'
    @@ -109,7 +159,7 @@ javadoc {
     }
     
     def prefixName = 'io.grpc.xds'
    -shadowJar {
    +tasks.named("shadowJar").configure {
         archiveClassifier = null
         dependencies {
             include(project(':grpc-xds'))
    @@ -130,7 +180,9 @@ shadowJar {
         exclude "**/*.proto"
     }
     
    -task checkPackageLeakage(dependsOn: shadowJar) {
    +def checkPackageLeakage = tasks.register("checkPackageLeakage") {
    +    inputs.files(shadowJar).withNormalizer(CompileClasspathNormalizer)
    +    outputs.file("${buildDir}/tmp/${name}") // Fake output for UP-TO-DATE checking
         doLast {
             def jarEntryPrefixName = prefixName.replaceAll('\\.', '/')
             shadowJar.outputs.getFiles().each { jar ->
    @@ -153,11 +205,11 @@ task checkPackageLeakage(dependsOn: shadowJar) {
         }
     }
     
    -test {
    +tasks.named("test").configure {
         dependsOn checkPackageLeakage
     }
     
    -jacocoTestReport {
    +tasks.named("jacocoTestReport").configure {
         classDirectories.from = sourceSets.main.output.collect {
             fileTree(dir: it,
             exclude: [ // Exclusions here should generally also be relocated
    @@ -182,7 +234,11 @@ publishing {
     
                 pom.withXml {
                     def dependenciesNode = new Node(null, 'dependencies')
    -                project.configurations.shadow.allDependencies.each { dep ->
    +                project.configurations.pomDeps.allDependencies.each { dep ->
    +                    if (dep.group == null && dep.name == 'unspecified') {
    +                        // Ignore the thirdparty self-dependency
    +                        return;
    +                    }
                         def dependencyNode = dependenciesNode.appendNode('dependency')
                         dependencyNode.appendNode('groupId', dep.group)
                         dependencyNode.appendNode('artifactId', dep.name)
    diff --git a/xds/src/generated/main/grpc/com/google/security/meshca/v1/MeshCertificateServiceGrpc.java b/xds/src/generated/main/grpc/com/google/security/meshca/v1/MeshCertificateServiceGrpc.java
    deleted file mode 100644
    index 939188ace73..00000000000
    --- a/xds/src/generated/main/grpc/com/google/security/meshca/v1/MeshCertificateServiceGrpc.java
    +++ /dev/null
    @@ -1,307 +0,0 @@
    -package com.google.security.meshca.v1;
    -
    -import static io.grpc.MethodDescriptor.generateFullMethodName;
    -
    -/**
    - * 
    - * Service for managing certificates issued by the CSM CA.
    - * 
    - */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: security/proto/providers/google/meshca.proto") -@io.grpc.stub.annotations.GrpcGenerated -public final class MeshCertificateServiceGrpc { - - private MeshCertificateServiceGrpc() {} - - public static final String SERVICE_NAME = "google.security.meshca.v1.MeshCertificateService"; - - // Static method descriptors that strictly reflect the proto. - private static volatile io.grpc.MethodDescriptor getCreateCertificateMethod; - - @io.grpc.stub.annotations.RpcMethod( - fullMethodName = SERVICE_NAME + '/' + "CreateCertificate", - requestType = com.google.security.meshca.v1.MeshCertificateRequest.class, - responseType = com.google.security.meshca.v1.MeshCertificateResponse.class, - methodType = io.grpc.MethodDescriptor.MethodType.UNARY) - public static io.grpc.MethodDescriptor getCreateCertificateMethod() { - io.grpc.MethodDescriptor getCreateCertificateMethod; - if ((getCreateCertificateMethod = MeshCertificateServiceGrpc.getCreateCertificateMethod) == null) { - synchronized (MeshCertificateServiceGrpc.class) { - if ((getCreateCertificateMethod = MeshCertificateServiceGrpc.getCreateCertificateMethod) == null) { - MeshCertificateServiceGrpc.getCreateCertificateMethod = getCreateCertificateMethod = - io.grpc.MethodDescriptor.newBuilder() - .setType(io.grpc.MethodDescriptor.MethodType.UNARY) - .setFullMethodName(generateFullMethodName(SERVICE_NAME, "CreateCertificate")) - .setSampledToLocalTracing(true) - .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( - com.google.security.meshca.v1.MeshCertificateRequest.getDefaultInstance())) - .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( - com.google.security.meshca.v1.MeshCertificateResponse.getDefaultInstance())) - .setSchemaDescriptor(new MeshCertificateServiceMethodDescriptorSupplier("CreateCertificate")) - .build(); - } - } - } - return getCreateCertificateMethod; - } - - /** - * Creates a new async stub that supports all call types for the service - */ - public static MeshCertificateServiceStub newStub(io.grpc.Channel channel) { - io.grpc.stub.AbstractStub.StubFactory factory = - new io.grpc.stub.AbstractStub.StubFactory() { - @java.lang.Override - public MeshCertificateServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceStub(channel, callOptions); - } - }; - return MeshCertificateServiceStub.newStub(factory, channel); - } - - /** - * Creates a new blocking-style stub that supports unary and streaming output calls on the service - */ - public static MeshCertificateServiceBlockingStub newBlockingStub( - io.grpc.Channel channel) { - io.grpc.stub.AbstractStub.StubFactory factory = - new io.grpc.stub.AbstractStub.StubFactory() { - @java.lang.Override - public MeshCertificateServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceBlockingStub(channel, callOptions); - } - }; - return MeshCertificateServiceBlockingStub.newStub(factory, channel); - } - - /** - * Creates a new ListenableFuture-style stub that supports unary calls on the service - */ - public static MeshCertificateServiceFutureStub newFutureStub( - io.grpc.Channel channel) { - io.grpc.stub.AbstractStub.StubFactory factory = - new io.grpc.stub.AbstractStub.StubFactory() { - @java.lang.Override - public MeshCertificateServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceFutureStub(channel, callOptions); - } - }; - return MeshCertificateServiceFutureStub.newStub(factory, channel); - } - - /** - *
    -   * Service for managing certificates issued by the CSM CA.
    -   * 
    - */ - public static abstract class MeshCertificateServiceImplBase implements io.grpc.BindableService { - - /** - *
    -     * Using provided CSR, returns a signed certificate that represents a GCP
    -     * service account identity.
    -     * 
    - */ - public void createCertificate(com.google.security.meshca.v1.MeshCertificateRequest request, - io.grpc.stub.StreamObserver responseObserver) { - io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getCreateCertificateMethod(), responseObserver); - } - - @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { - return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) - .addMethod( - getCreateCertificateMethod(), - io.grpc.stub.ServerCalls.asyncUnaryCall( - new MethodHandlers< - com.google.security.meshca.v1.MeshCertificateRequest, - com.google.security.meshca.v1.MeshCertificateResponse>( - this, METHODID_CREATE_CERTIFICATE))) - .build(); - } - } - - /** - *
    -   * Service for managing certificates issued by the CSM CA.
    -   * 
    - */ - public static final class MeshCertificateServiceStub extends io.grpc.stub.AbstractAsyncStub { - private MeshCertificateServiceStub( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - super(channel, callOptions); - } - - @java.lang.Override - protected MeshCertificateServiceStub build( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceStub(channel, callOptions); - } - - /** - *
    -     * Using provided CSR, returns a signed certificate that represents a GCP
    -     * service account identity.
    -     * 
    - */ - public void createCertificate(com.google.security.meshca.v1.MeshCertificateRequest request, - io.grpc.stub.StreamObserver responseObserver) { - io.grpc.stub.ClientCalls.asyncUnaryCall( - getChannel().newCall(getCreateCertificateMethod(), getCallOptions()), request, responseObserver); - } - } - - /** - *
    -   * Service for managing certificates issued by the CSM CA.
    -   * 
    - */ - public static final class MeshCertificateServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { - private MeshCertificateServiceBlockingStub( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - super(channel, callOptions); - } - - @java.lang.Override - protected MeshCertificateServiceBlockingStub build( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceBlockingStub(channel, callOptions); - } - - /** - *
    -     * Using provided CSR, returns a signed certificate that represents a GCP
    -     * service account identity.
    -     * 
    - */ - public com.google.security.meshca.v1.MeshCertificateResponse createCertificate(com.google.security.meshca.v1.MeshCertificateRequest request) { - return io.grpc.stub.ClientCalls.blockingUnaryCall( - getChannel(), getCreateCertificateMethod(), getCallOptions(), request); - } - } - - /** - *
    -   * Service for managing certificates issued by the CSM CA.
    -   * 
    - */ - public static final class MeshCertificateServiceFutureStub extends io.grpc.stub.AbstractFutureStub { - private MeshCertificateServiceFutureStub( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - super(channel, callOptions); - } - - @java.lang.Override - protected MeshCertificateServiceFutureStub build( - io.grpc.Channel channel, io.grpc.CallOptions callOptions) { - return new MeshCertificateServiceFutureStub(channel, callOptions); - } - - /** - *
    -     * Using provided CSR, returns a signed certificate that represents a GCP
    -     * service account identity.
    -     * 
    - */ - public com.google.common.util.concurrent.ListenableFuture createCertificate( - com.google.security.meshca.v1.MeshCertificateRequest request) { - return io.grpc.stub.ClientCalls.futureUnaryCall( - getChannel().newCall(getCreateCertificateMethod(), getCallOptions()), request); - } - } - - private static final int METHODID_CREATE_CERTIFICATE = 0; - - private static final class MethodHandlers implements - io.grpc.stub.ServerCalls.UnaryMethod, - io.grpc.stub.ServerCalls.ServerStreamingMethod, - io.grpc.stub.ServerCalls.ClientStreamingMethod, - io.grpc.stub.ServerCalls.BidiStreamingMethod { - private final MeshCertificateServiceImplBase serviceImpl; - private final int methodId; - - MethodHandlers(MeshCertificateServiceImplBase serviceImpl, int methodId) { - this.serviceImpl = serviceImpl; - this.methodId = methodId; - } - - @java.lang.Override - @java.lang.SuppressWarnings("unchecked") - public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { - switch (methodId) { - case METHODID_CREATE_CERTIFICATE: - serviceImpl.createCertificate((com.google.security.meshca.v1.MeshCertificateRequest) request, - (io.grpc.stub.StreamObserver) responseObserver); - break; - default: - throw new AssertionError(); - } - } - - @java.lang.Override - @java.lang.SuppressWarnings("unchecked") - public io.grpc.stub.StreamObserver invoke( - io.grpc.stub.StreamObserver responseObserver) { - switch (methodId) { - default: - throw new AssertionError(); - } - } - } - - private static abstract class MeshCertificateServiceBaseDescriptorSupplier - implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { - MeshCertificateServiceBaseDescriptorSupplier() {} - - @java.lang.Override - public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { - return com.google.security.meshca.v1.MeshCaProto.getDescriptor(); - } - - @java.lang.Override - public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { - return getFileDescriptor().findServiceByName("MeshCertificateService"); - } - } - - private static final class MeshCertificateServiceFileDescriptorSupplier - extends MeshCertificateServiceBaseDescriptorSupplier { - MeshCertificateServiceFileDescriptorSupplier() {} - } - - private static final class MeshCertificateServiceMethodDescriptorSupplier - extends MeshCertificateServiceBaseDescriptorSupplier - implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { - private final String methodName; - - MeshCertificateServiceMethodDescriptorSupplier(String methodName) { - this.methodName = methodName; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { - return getServiceDescriptor().findMethodByName(methodName); - } - } - - private static volatile io.grpc.ServiceDescriptor serviceDescriptor; - - public static io.grpc.ServiceDescriptor getServiceDescriptor() { - io.grpc.ServiceDescriptor result = serviceDescriptor; - if (result == null) { - synchronized (MeshCertificateServiceGrpc.class) { - result = serviceDescriptor; - if (result == null) { - serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) - .setSchemaDescriptor(new MeshCertificateServiceFileDescriptorSupplier()) - .addMethod(getCreateCertificateMethod()) - .build(); - } - } - } - return result; - } -} diff --git a/xds/src/generated/main/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java b/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ClusterDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ClusterDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ClusterDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ClusterDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/EndpointDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/EndpointDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/EndpointDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/EndpointDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ListenerDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ListenerDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ListenerDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ListenerDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/RouteDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/RouteDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/RouteDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/RouteDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ScopedRoutesDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ScopedRoutesDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/ScopedRoutesDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/ScopedRoutesDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/VirtualHostDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/VirtualHostDiscoveryServiceGrpc.java similarity index 100% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/api/v2/VirtualHostDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/api/v2/VirtualHostDiscoveryServiceGrpc.java diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java similarity index 97% rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java index d66b7a40a39..971a9faf5cc 100644 --- a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v2/AggregatedDiscoveryServiceGrpc.java @@ -4,7 +4,7 @@ /** *
    - * See https://github.com/lyft/envoy-api#apis for a description of the role of
    + * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
      * ADS and how it is intended to be used by a management server. ADS requests
      * have the same structure as their singleton xDS counterparts, but can
      * multiplex many resource types on a single stream. The type_url in the
    @@ -131,7 +131,7 @@ public AggregatedDiscoveryServiceFutureStub newStub(io.grpc.Channel channel, io.
     
       /**
        * 
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -180,7 +180,7 @@ public io.grpc.stub.StreamObserver
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -222,7 +222,7 @@ public io.grpc.stub.StreamObserver
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -245,7 +245,7 @@ protected AggregatedDiscoveryServiceBlockingStub build(
     
       /**
        * 
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v2/SecretDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v2/SecretDiscoveryServiceGrpc.java
    similarity index 100%
    rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v2/SecretDiscoveryServiceGrpc.java
    rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v2/SecretDiscoveryServiceGrpc.java
    diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java
    similarity index 97%
    rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java
    rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java
    index 21f9f537d58..0cb84f8d277 100644
    --- a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java
    +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java
    @@ -4,7 +4,7 @@
     
     /**
      * 
    - * See https://github.com/lyft/envoy-api#apis for a description of the role of
    + * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
      * ADS and how it is intended to be used by a management server. ADS requests
      * have the same structure as their singleton xDS counterparts, but can
      * multiplex many resource types on a single stream. The type_url in the
    @@ -131,7 +131,7 @@ public AggregatedDiscoveryServiceFutureStub newStub(io.grpc.Channel channel, io.
     
       /**
        * 
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -180,7 +180,7 @@ public io.grpc.stub.StreamObserver
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -222,7 +222,7 @@ public io.grpc.stub.StreamObserver
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    @@ -245,7 +245,7 @@ protected AggregatedDiscoveryServiceBlockingStub build(
     
       /**
        * 
    -   * See https://github.com/lyft/envoy-api#apis for a description of the role of
    +   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
        * ADS and how it is intended to be used by a management server. ADS requests
        * have the same structure as their singleton xDS counterparts, but can
        * multiplex many resource types on a single stream. The type_url in the
    diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/load_stats/v2/LoadReportingServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v2/LoadReportingServiceGrpc.java
    similarity index 100%
    rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/load_stats/v2/LoadReportingServiceGrpc.java
    rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v2/LoadReportingServiceGrpc.java
    diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java
    similarity index 100%
    rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java
    rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java
    diff --git a/xds/src/generated/main/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java
    similarity index 100%
    rename from xds/src/generated/main/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java
    rename to xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java
    diff --git a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java
    index 61780c60a55..b5384616925 100644
    --- a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java
    +++ b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java
    @@ -36,16 +36,22 @@
     import io.grpc.SynchronizationContext;
     import io.grpc.SynchronizationContext.ScheduledHandle;
     import io.grpc.internal.BackoffPolicy;
    +import io.grpc.stub.ClientCallStreamObserver;
    +import io.grpc.stub.ClientResponseObserver;
     import io.grpc.stub.StreamObserver;
     import io.grpc.xds.Bootstrapper.ServerInfo;
    -import io.grpc.xds.ClientXdsClient.XdsChannelFactory;
     import io.grpc.xds.EnvoyProtoData.Node;
     import io.grpc.xds.XdsClient.ResourceStore;
     import io.grpc.xds.XdsClient.XdsResponseHandler;
    +import io.grpc.xds.XdsClientImpl.XdsChannelFactory;
     import io.grpc.xds.XdsLogger.XdsLogLevel;
     import java.util.Collection;
     import java.util.Collections;
    +import java.util.HashMap;
    +import java.util.HashSet;
     import java.util.List;
    +import java.util.Map;
    +import java.util.Set;
     import java.util.concurrent.ScheduledExecutorService;
     import java.util.concurrent.TimeUnit;
     import javax.annotation.Nullable;
    @@ -56,22 +62,7 @@
      */
     final class AbstractXdsClient {
     
    -  private static final String ADS_TYPE_URL_LDS_V2 = "type.googleapis.com/envoy.api.v2.Listener";
    -  private static final String ADS_TYPE_URL_LDS =
    -      "type.googleapis.com/envoy.config.listener.v3.Listener";
    -  private static final String ADS_TYPE_URL_RDS_V2 =
    -      "type.googleapis.com/envoy.api.v2.RouteConfiguration";
    -  private static final String ADS_TYPE_URL_RDS =
    -      "type.googleapis.com/envoy.config.route.v3.RouteConfiguration";
    -  @VisibleForTesting
    -  static final String ADS_TYPE_URL_CDS_V2 = "type.googleapis.com/envoy.api.v2.Cluster";
    -  private static final String ADS_TYPE_URL_CDS =
    -      "type.googleapis.com/envoy.config.cluster.v3.Cluster";
    -  private static final String ADS_TYPE_URL_EDS_V2 =
    -      "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment";
    -  private static final String ADS_TYPE_URL_EDS =
    -      "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment";
    -
    +  public static final String CLOSED_BY_SERVER = "Closed by server";
       private final SynchronizationContext syncContext;
       private final InternalLogId logId;
       private final XdsLogger logger;
    @@ -84,14 +75,12 @@ final class AbstractXdsClient {
       private final BackoffPolicy.Provider backoffPolicyProvider;
       private final Stopwatch stopwatch;
       private final Node bootstrapNode;
    +  private final XdsClient.TimerLaunch timerLaunch;
     
       // Last successfully applied version_info for each resource type. Starts with empty string.
       // A version_info is used to update management server with client's most recent knowledge of
       // resources.
    -  private String ldsVersion = "";
    -  private String rdsVersion = "";
    -  private String cdsVersion = "";
    -  private String edsVersion = "";
    +  private final Map, String> versions = new HashMap<>();
     
       private boolean shutdown;
       @Nullable
    @@ -114,7 +103,8 @@ final class AbstractXdsClient {
           timeService,
           SynchronizationContext syncContext,
           BackoffPolicy.Provider backoffPolicyProvider,
    -      Supplier stopwatchSupplier) {
    +      Supplier stopwatchSupplier,
    +      XdsClient.TimerLaunch timerLaunch) {
         this.serverInfo = checkNotNull(serverInfo, "serverInfo");
         this.channel = checkNotNull(xdsChannelFactory, "xdsChannelFactory").create(serverInfo);
         this.xdsResponseHandler = checkNotNull(xdsResponseHandler, "xdsResponseHandler");
    @@ -124,6 +114,7 @@ final class AbstractXdsClient {
         this.timeService = checkNotNull(timeService, "timeService");
         this.syncContext = checkNotNull(syncContext, "syncContext");
         this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
    +    this.timerLaunch  = checkNotNull(timerLaunch, "timerLaunch");
         stopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get();
         logId = InternalLogId.allocate("xds-client", serverInfo.target());
         logger = XdsLogger.withLogId(logId);
    @@ -162,16 +153,16 @@ public String toString() {
        * Updates the resource subscription for the given resource type.
        */
       // Must be synchronized.
    -  void adjustResourceSubscription(ResourceType type) {
    +  void adjustResourceSubscription(XdsResourceType resourceType) {
         if (isInBackoff()) {
           return;
         }
         if (adsStream == null) {
           startRpcStream();
         }
    -    Collection resources = resourceStore.getSubscribedResources(serverInfo, type);
    +    Collection resources = resourceStore.getSubscribedResources(serverInfo, resourceType);
         if (resources != null) {
    -      adsStream.sendDiscoveryRequest(type, resources);
    +      adsStream.sendDiscoveryRequest(resourceType, resources);
         }
       }
     
    @@ -180,26 +171,10 @@ void adjustResourceSubscription(ResourceType type) {
        * and sends an ACK request to the management server.
        */
       // Must be synchronized.
    -  void ackResponse(ResourceType type, String versionInfo, String nonce) {
    -    switch (type) {
    -      case LDS:
    -        ldsVersion = versionInfo;
    -        break;
    -      case RDS:
    -        rdsVersion = versionInfo;
    -        break;
    -      case CDS:
    -        cdsVersion = versionInfo;
    -        break;
    -      case EDS:
    -        edsVersion = versionInfo;
    -        break;
    -      case UNKNOWN:
    -      default:
    -        throw new AssertionError("Unknown resource type: " + type);
    -    }
    +  void ackResponse(XdsResourceType type, String versionInfo, String nonce) {
    +    versions.put(type, versionInfo);
         logger.log(XdsLogLevel.INFO, "Sending ACK for {0} update, nonce: {1}, current version: {2}",
    -        type, nonce, versionInfo);
    +        type.typeName(), nonce, versionInfo);
         Collection resources = resourceStore.getSubscribedResources(serverInfo, type);
         if (resources == null) {
           resources = Collections.emptyList();
    @@ -212,10 +187,10 @@ void ackResponse(ResourceType type, String versionInfo, String nonce) {
        * accepted version) to the management server.
        */
       // Must be synchronized.
    -  void nackResponse(ResourceType type, String nonce, String errorDetail) {
    -    String versionInfo = getCurrentVersion(type);
    +  void nackResponse(XdsResourceType type, String nonce, String errorDetail) {
    +    String versionInfo = versions.getOrDefault(type, "");
         logger.log(XdsLogLevel.INFO, "Sending NACK for {0} update, nonce: {1}, current version: {2}",
    -        type, nonce, versionInfo);
    +        type.typeName(), nonce, versionInfo);
         Collection resources = resourceStore.getSubscribedResources(serverInfo, type);
         if (resources == null) {
           resources = Collections.emptyList();
    @@ -231,6 +206,27 @@ boolean isInBackoff() {
         return rpcRetryTimer != null && rpcRetryTimer.isPending();
       }
     
    +  boolean isReady() {
    +    return adsStream != null && adsStream.isReady();
    +  }
    +
    +  /**
    +   * Starts a timer for each requested resource that hasn't been responded to and
    +   * has been waiting for the channel to get ready.
    +   */
    +  void readyHandler() {
    +    if (!isReady()) {
    +      return;
    +    }
    +
    +    if (isInBackoff()) {
    +      rpcRetryTimer.cancel();
    +      rpcRetryTimer = null;
    +    }
    +
    +    timerLaunch.startSubscriberTimersIfNeeded(serverInfo);
    +  }
    +
       /**
        * Establishes the RPC connection by creating a new RPC stream on the given channel for
        * xDS protocol communication.
    @@ -238,11 +234,7 @@ boolean isInBackoff() {
       // Must be synchronized.
       private void startRpcStream() {
         checkState(adsStream == null, "Previous adsStream has not been cleared yet");
    -    if (serverInfo.useProtocolV3()) {
    -      adsStream = new AdsStreamV3();
    -    } else {
    -      adsStream = new AdsStreamV2();
    -    }
    +    adsStream = new AdsStreamV3();
         Context prevContext = context.attach();
         try {
           adsStream.start();
    @@ -253,30 +245,6 @@ private void startRpcStream() {
         stopwatch.reset().start();
       }
     
    -  /** Returns the latest accepted version of the given resource type. */
    -  // Must be synchronized.
    -  String getCurrentVersion(ResourceType type) {
    -    String version;
    -    switch (type) {
    -      case LDS:
    -        version = ldsVersion;
    -        break;
    -      case RDS:
    -        version = rdsVersion;
    -        break;
    -      case CDS:
    -        version = cdsVersion;
    -        break;
    -      case EDS:
    -        version = edsVersion;
    -        break;
    -      case UNKNOWN:
    -      default:
    -        throw new AssertionError("Unknown resource type: " + type);
    -    }
    -    return version;
    -  }
    -
       @VisibleForTesting
       final class RpcRetryTask implements Runnable {
         @Override
    @@ -285,10 +253,9 @@ public void run() {
             return;
           }
           startRpcStream();
    -      for (ResourceType type : ResourceType.values()) {
    -        if (type == ResourceType.UNKNOWN) {
    -          continue;
    -        }
    +      Set> subscribedResourceTypes =
    +          new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values());
    +      for (XdsResourceType type : subscribedResourceTypes) {
             Collection resources = resourceStore.getSubscribedResources(serverInfo, type);
             if (resources != null) {
               adsStream.sendDiscoveryRequest(type, resources);
    @@ -298,151 +265,56 @@ public void run() {
         }
       }
     
    -  enum ResourceType {
    -    UNKNOWN, LDS, RDS, CDS, EDS;
    -
    -    String typeUrl() {
    -      switch (this) {
    -        case LDS:
    -          return ADS_TYPE_URL_LDS;
    -        case RDS:
    -          return ADS_TYPE_URL_RDS;
    -        case CDS:
    -          return ADS_TYPE_URL_CDS;
    -        case EDS:
    -          return ADS_TYPE_URL_EDS;
    -        case UNKNOWN:
    -        default:
    -          throw new AssertionError("Unknown or missing case in enum switch: " + this);
    -      }
    -    }
    -
    -    String typeUrlV2() {
    -      switch (this) {
    -        case LDS:
    -          return ADS_TYPE_URL_LDS_V2;
    -        case RDS:
    -          return ADS_TYPE_URL_RDS_V2;
    -        case CDS:
    -          return ADS_TYPE_URL_CDS_V2;
    -        case EDS:
    -          return ADS_TYPE_URL_EDS_V2;
    -        case UNKNOWN:
    -        default:
    -          throw new AssertionError("Unknown or missing case in enum switch: " + this);
    -      }
    -    }
    -
    -    @VisibleForTesting
    -    static ResourceType fromTypeUrl(String typeUrl) {
    -      switch (typeUrl) {
    -        case ADS_TYPE_URL_LDS:
    -          // fall trough
    -        case ADS_TYPE_URL_LDS_V2:
    -          return LDS;
    -        case ADS_TYPE_URL_RDS:
    -          // fall through
    -        case ADS_TYPE_URL_RDS_V2:
    -          return RDS;
    -        case ADS_TYPE_URL_CDS:
    -          // fall through
    -        case ADS_TYPE_URL_CDS_V2:
    -          return CDS;
    -        case ADS_TYPE_URL_EDS:
    -          // fall through
    -        case ADS_TYPE_URL_EDS_V2:
    -          return EDS;
    -        default:
    -          return UNKNOWN;
    -      }
    -    }
    +  @VisibleForTesting
    +  @Nullable
    +  XdsResourceType fromTypeUrl(String typeUrl) {
    +    return resourceStore.getSubscribedResourceTypesWithTypeUrl().get(typeUrl);
       }
     
       private abstract class AbstractAdsStream {
         private boolean responseReceived;
         private boolean closed;
    -
         // Response nonce for the most recently received discovery responses of each resource type.
         // Client initiated requests start response nonce with empty string.
    -    // A nonce is used to indicate the specific DiscoveryResponse each DiscoveryRequest
    -    // corresponds to.
    -    // A nonce becomes stale following a newer nonce being presented to the client in a
    -    // DiscoveryResponse.
    -    private String ldsRespNonce = "";
    -    private String rdsRespNonce = "";
    -    private String cdsRespNonce = "";
    -    private String edsRespNonce = "";
    +    // Nonce in each response is echoed back in the following ACK/NACK request. It is
    +    // used for management server to identify which response the client is ACKing/NACking.
    +    // To avoid confusion, client-initiated requests will always use the nonce in
    +    // most recently received responses of each resource type.
    +    private final Map, String> respNonces = new HashMap<>();
     
         abstract void start();
     
         abstract void sendError(Exception error);
     
    +    abstract boolean isReady();
    +
         /**
          * Sends a discovery request with the given {@code versionInfo}, {@code nonce} and
          * {@code errorDetail}. Used for reacting to a specific discovery response. For
          * client-initiated discovery requests, use {@link
    -     * #sendDiscoveryRequest(ResourceType, Collection)}.
    +     * #sendDiscoveryRequest(XdsResourceType, Collection)}.
          */
    -    abstract void sendDiscoveryRequest(ResourceType type, String versionInfo,
    +    abstract void sendDiscoveryRequest(XdsResourceType type, String version,
             Collection resources, String nonce, @Nullable String errorDetail);
     
         /**
          * Sends a client-initiated discovery request.
          */
    -    final void sendDiscoveryRequest(ResourceType type, Collection resources) {
    -      String nonce;
    -      switch (type) {
    -        case LDS:
    -          nonce = ldsRespNonce;
    -          break;
    -        case RDS:
    -          nonce = rdsRespNonce;
    -          break;
    -        case CDS:
    -          nonce = cdsRespNonce;
    -          break;
    -        case EDS:
    -          nonce = edsRespNonce;
    -          break;
    -        case UNKNOWN:
    -        default:
    -          throw new AssertionError("Unknown resource type: " + type);
    -      }
    +    final void sendDiscoveryRequest(XdsResourceType type, Collection resources) {
           logger.log(XdsLogLevel.INFO, "Sending {0} request for resources: {1}", type, resources);
    -      sendDiscoveryRequest(type, getCurrentVersion(type), resources, nonce, null);
    +      sendDiscoveryRequest(type, versions.getOrDefault(type, ""), resources,
    +          respNonces.getOrDefault(type, ""), null);
         }
     
    -    final void handleRpcResponse(
    -        ResourceType type, String versionInfo, List resources, String nonce) {
    +    final void handleRpcResponse(XdsResourceType type, String versionInfo, List resources,
    +                                 String nonce) {
    +      checkNotNull(type, "type");
           if (closed) {
             return;
           }
           responseReceived = true;
    -      // Nonce in each response is echoed back in the following ACK/NACK request. It is
    -      // used for management server to identify which response the client is ACKing/NACking.
    -      // To avoid confusion, client-initiated requests will always use the nonce in
    -      // most recently received responses of each resource type.
    -      switch (type) {
    -        case LDS:
    -          ldsRespNonce = nonce;
    -          xdsResponseHandler.handleLdsResponse(serverInfo, versionInfo, resources, nonce);
    -          break;
    -        case RDS:
    -          rdsRespNonce = nonce;
    -          xdsResponseHandler.handleRdsResponse(serverInfo, versionInfo, resources, nonce);
    -          break;
    -        case CDS:
    -          cdsRespNonce = nonce;
    -          xdsResponseHandler.handleCdsResponse(serverInfo, versionInfo, resources, nonce);
    -          break;
    -        case EDS:
    -          edsRespNonce = nonce;
    -          xdsResponseHandler.handleEdsResponse(serverInfo, versionInfo, resources, nonce);
    -          break;
    -        case UNKNOWN:
    -        default:
    -          logger.log(XdsLogLevel.WARNING, "Ignore an unknown type of DiscoveryResponse");
    -      }
    +      respNonces.put(type, nonce);
    +      xdsResponseHandler.handleResourceResponse(type, serverInfo, versionInfo, resources, nonce);
         }
     
         final void handleRpcError(Throwable t) {
    @@ -450,34 +322,33 @@ final void handleRpcError(Throwable t) {
         }
     
         final void handleRpcCompleted() {
    -      handleRpcStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server"));
    +      handleRpcStreamClosed(Status.UNAVAILABLE.withDescription(CLOSED_BY_SERVER));
         }
     
         private void handleRpcStreamClosed(Status error) {
    -      checkArgument(!error.isOk(), "unexpected OK status");
           if (closed) {
             return;
           }
    +
    +      checkArgument(!error.isOk(), "unexpected OK status");
    +      String errorMsg = error.getDescription() != null
    +          && error.getDescription().equals(CLOSED_BY_SERVER)
    +              ? "ADS stream closed with status {0}: {1}. Cause: {2}"
    +              : "ADS stream failed with status {0}: {1}. Cause: {2}";
           logger.log(
    -          XdsLogLevel.ERROR,
    -          "ADS stream closed with status {0}: {1}. Cause: {2}",
    -          error.getCode(), error.getDescription(), error.getCause());
    +          XdsLogLevel.ERROR, errorMsg, error.getCode(), error.getDescription(), error.getCause());
           closed = true;
           xdsResponseHandler.handleStreamClosed(error);
           cleanUp();
    +
           if (responseReceived || retryBackoffPolicy == null) {
             // Reset the backoff sequence if had received a response, or backoff sequence
             // has never been initialized.
             retryBackoffPolicy = backoffPolicyProvider.get();
           }
    -      long delayNanos = 0;
    -      if (!responseReceived) {
    -        delayNanos =
    -            Math.max(
    -                0,
    -                retryBackoffPolicy.nextBackoffNanos()
    -                    - stopwatch.elapsed(TimeUnit.NANOSECONDS));
    -      }
    +      long delayNanos = Math.max(
    +          0,
    +          retryBackoffPolicy.nextBackoffNanos() - stopwatch.elapsed(TimeUnit.NANOSECONDS));
           logger.log(XdsLogLevel.INFO, "Retry ADS stream in {0} ns", delayNanos);
           rpcRetryTimer = syncContext.schedule(
               new RpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timeService);
    @@ -499,107 +370,44 @@ private void cleanUp() {
         }
       }
     
    -  private final class AdsStreamV2 extends AbstractAdsStream {
    -    private StreamObserver requestWriter;
    -
    -    @Override
    -    void start() {
    -      io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc
    -          .AggregatedDiscoveryServiceStub stub =
    -          io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.newStub(channel);
    -      StreamObserver responseReaderV2 =
    -          new StreamObserver() {
    -            @Override
    -            public void onNext(final io.envoyproxy.envoy.api.v2.DiscoveryResponse response) {
    -              syncContext.execute(new Runnable() {
    -                @Override
    -                public void run() {
    -                  ResourceType type = ResourceType.fromTypeUrl(response.getTypeUrl());
    -                  if (logger.isLoggable(XdsLogLevel.DEBUG)) {
    -                    logger.log(
    -                        XdsLogLevel.DEBUG, "Received {0} response:\n{1}", type,
    -                        MessagePrinter.print(response));
    -                  }
    -                  handleRpcResponse(type, response.getVersionInfo(), response.getResourcesList(),
    -                      response.getNonce());
    -                }
    -              });
    -            }
    -
    -            @Override
    -            public void onError(final Throwable t) {
    -              syncContext.execute(new Runnable() {
    -                @Override
    -                public void run() {
    -                  handleRpcError(t);
    -                }
    -              });
    -            }
    -
    -            @Override
    -            public void onCompleted() {
    -              syncContext.execute(new Runnable() {
    -                @Override
    -                public void run() {
    -                  handleRpcCompleted();
    -                }
    -              });
    -            }
    -          };
    -      requestWriter = stub.withWaitForReady().streamAggregatedResources(responseReaderV2);
    -    }
    -
    -    @Override
    -    void sendDiscoveryRequest(ResourceType type, String versionInfo, Collection resources,
    -        String nonce, @Nullable String errorDetail) {
    -      checkState(requestWriter != null, "ADS stream has not been started");
    -      io.envoyproxy.envoy.api.v2.DiscoveryRequest.Builder builder =
    -          io.envoyproxy.envoy.api.v2.DiscoveryRequest.newBuilder()
    -              .setVersionInfo(versionInfo)
    -              .setNode(bootstrapNode.toEnvoyProtoNodeV2())
    -              .addAllResourceNames(resources)
    -              .setTypeUrl(type.typeUrlV2())
    -              .setResponseNonce(nonce);
    -      if (errorDetail != null) {
    -        com.google.rpc.Status error =
    -            com.google.rpc.Status.newBuilder()
    -                .setCode(Code.INVALID_ARGUMENT_VALUE)  // FIXME(chengyuanzhang): use correct code
    -                .setMessage(errorDetail)
    -                .build();
    -        builder.setErrorDetail(error);
    -      }
    -      io.envoyproxy.envoy.api.v2.DiscoveryRequest request = builder.build();
    -      requestWriter.onNext(request);
    -      if (logger.isLoggable(XdsLogLevel.DEBUG)) {
    -        logger.log(XdsLogLevel.DEBUG, "Sent DiscoveryRequest\n{0}", MessagePrinter.print(request));
    -      }
    -    }
    +  private final class AdsStreamV3 extends AbstractAdsStream {
    +    private StreamObserver requestWriter;
     
         @Override
    -    void sendError(Exception error) {
    -      requestWriter.onError(error);
    +    public boolean isReady() {
    +      return requestWriter != null && ((ClientCallStreamObserver) requestWriter).isReady();
         }
    -  }
    -
    -  private final class AdsStreamV3 extends AbstractAdsStream {
    -    private StreamObserver requestWriter;
     
         @Override
         void start() {
           AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceStub stub =
               AggregatedDiscoveryServiceGrpc.newStub(channel);
    -      StreamObserver responseReader = new StreamObserver() {
    +      StreamObserver responseReader =
    +          new ClientResponseObserver() {
    +
    +        @Override
    +        public void beforeStart(ClientCallStreamObserver requestStream) {
    +          requestStream.setOnReadyHandler(AbstractXdsClient.this::readyHandler);
    +        }
    +
             @Override
             public void onNext(final DiscoveryResponse response) {
               syncContext.execute(new Runnable() {
                 @Override
                 public void run() {
    -              ResourceType type = ResourceType.fromTypeUrl(response.getTypeUrl());
    +              XdsResourceType type = fromTypeUrl(response.getTypeUrl());
                   if (logger.isLoggable(XdsLogLevel.DEBUG)) {
                     logger.log(
                         XdsLogLevel.DEBUG, "Received {0} response:\n{1}", type,
                         MessagePrinter.print(response));
                   }
    +              if (type == null) {
    +                logger.log(
    +                    XdsLogLevel.WARNING,
    +                    "Ignore an unknown type of DiscoveryResponse: {0}",
    +                    response.getTypeUrl());
    +                return;
    +              }
                   handleRpcResponse(type, response.getVersionInfo(), response.getResourcesList(),
                       response.getNonce());
                 }
    @@ -626,12 +434,13 @@ public void run() {
               });
             }
           };
    -      requestWriter = stub.withWaitForReady().streamAggregatedResources(responseReader);
    +      requestWriter = stub.streamAggregatedResources(responseReader);
         }
     
         @Override
    -    void sendDiscoveryRequest(ResourceType type, String versionInfo, Collection resources,
    -        String nonce, @Nullable String errorDetail) {
    +    void sendDiscoveryRequest(XdsResourceType type, String versionInfo,
    +                              Collection resources, String nonce,
    +                              @Nullable String errorDetail) {
           checkState(requestWriter != null, "ADS stream has not been started");
           DiscoveryRequest.Builder builder =
               DiscoveryRequest.newBuilder()
    diff --git a/xds/src/main/java/io/grpc/xds/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/Bootstrapper.java
    index 57cdc6f324b..df997c088d5 100644
    --- a/xds/src/main/java/io/grpc/xds/Bootstrapper.java
    +++ b/xds/src/main/java/io/grpc/xds/Bootstrapper.java
    @@ -60,12 +60,20 @@ abstract static class ServerInfo {
     
         abstract ChannelCredentials channelCredentials();
     
    -    abstract boolean useProtocolV3();
    +    abstract boolean ignoreResourceDeletion();
     
         @VisibleForTesting
         static ServerInfo create(
    -        String target, ChannelCredentials channelCredentials, boolean useProtocolV3) {
    -      return new AutoValue_Bootstrapper_ServerInfo(target, channelCredentials, useProtocolV3);
    +        String target, ChannelCredentials channelCredentials) {
    +      return new AutoValue_Bootstrapper_ServerInfo(target, channelCredentials, false);
    +    }
    +
    +    @VisibleForTesting
    +    static ServerInfo create(
    +        String target, ChannelCredentials channelCredentials,
    +        boolean ignoreResourceDeletion) {
    +      return new AutoValue_Bootstrapper_ServerInfo(target, channelCredentials,
    +          ignoreResourceDeletion);
         }
       }
     
    diff --git a/xds/src/main/java/io/grpc/xds/BootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/BootstrapperImpl.java
    index 23b14357096..6d0e78a2c4a 100644
    --- a/xds/src/main/java/io/grpc/xds/BootstrapperImpl.java
    +++ b/xds/src/main/java/io/grpc/xds/BootstrapperImpl.java
    @@ -21,11 +21,7 @@
     import com.google.common.collect.ImmutableList;
     import com.google.common.collect.ImmutableMap;
     import io.grpc.ChannelCredentials;
    -import io.grpc.InsecureChannelCredentials;
    -import io.grpc.Internal;
     import io.grpc.InternalLogId;
    -import io.grpc.TlsChannelCredentials;
    -import io.grpc.alts.GoogleDefaultChannelCredentials;
     import io.grpc.internal.GrpcUtil;
     import io.grpc.internal.GrpcUtil.GrpcBuildVersion;
     import io.grpc.internal.JsonParser;
    @@ -44,8 +40,7 @@
     /**
      * A {@link Bootstrapper} implementation that reads xDS configurations from local file system.
      */
    -@Internal
    -public class BootstrapperImpl extends Bootstrapper {
    +class BootstrapperImpl extends Bootstrapper {
     
       private static final String BOOTSTRAP_PATH_SYS_ENV_VAR = "GRPC_XDS_BOOTSTRAP";
       @VisibleForTesting
    @@ -59,14 +54,21 @@ public class BootstrapperImpl extends Bootstrapper {
       private static final String BOOTSTRAP_CONFIG_SYS_PROPERTY = "io.grpc.xds.bootstrapConfig";
       @VisibleForTesting
       static String bootstrapConfigFromSysProp = System.getProperty(BOOTSTRAP_CONFIG_SYS_PROPERTY);
    -  private static final String XDS_V3_SERVER_FEATURE = "xds_v3";
    +
    +  // Feature-gating environment variables.
    +  static boolean enableFederation =
    +      !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION"))
    +          && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION"));
    +
    +  // Client features.
       @VisibleForTesting
       static final String CLIENT_FEATURE_DISABLE_OVERPROVISIONING =
           "envoy.lb.does_not_support_overprovisioning";
       @VisibleForTesting
    -  static boolean enableFederation =
    -      !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION"))
    -          && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION"));
    +  static final String CLIENT_FEATURE_RESOURCE_IN_SOTW = "xds.config.resource-in-sotw";
    +
    +  // Server features.
    +  private static final String SERVER_FEATURE_IGNORE_RESOURCE_DELETION = "ignore_resource_deletion";
     
       private final XdsLogger logger;
       private FileReader reader = LocalFileReader.INSTANCE;
    @@ -179,6 +181,7 @@ BootstrapInfo bootstrap(Map rawData) throws XdsInitializationExceptio
         nodeBuilder.setUserAgentName(buildVersion.getUserAgent());
         nodeBuilder.setUserAgentVersion(buildVersion.getImplementationVersion());
         nodeBuilder.addClientFeatures(CLIENT_FEATURE_DISABLE_OVERPROVISIONING);
    +    nodeBuilder.addClientFeatures(CLIENT_FEATURE_RESOURCE_IN_SOTW);
         builder.node(nodeBuilder.build());
     
         Map certProvidersBlob = JsonUtil.getObject(rawData, "certificate_providers");
    @@ -247,7 +250,7 @@ BootstrapInfo bootstrap(Map rawData) throws XdsInitializationExceptio
             authorityInfoMapBuilder.put(
                 authorityName, AuthorityInfo.create(clientListnerTemplate, authorityServers));
           }
    -      builder.authorities(authorityInfoMapBuilder.build());
    +      builder.authorities(authorityInfoMapBuilder.buildOrThrow());
         }
     
         return builder.build();
    @@ -277,13 +280,14 @@ private static List parseServerInfos(List rawServerConfigs, XdsLo
                 "Server " + serverUri + ": no supported channel credentials found");
           }
     
    -      boolean useProtocolV3 = false;
    +      boolean ignoreResourceDeletion = false;
           List serverFeatures = JsonUtil.getListOfStrings(serverConfig, "server_features");
           if (serverFeatures != null) {
             logger.log(XdsLogLevel.INFO, "Server features: {0}", serverFeatures);
    -        useProtocolV3 = serverFeatures.contains(XDS_V3_SERVER_FEATURE);
    +        ignoreResourceDeletion = serverFeatures.contains(SERVER_FEATURE_IGNORE_RESOURCE_DELETION);
           }
    -      servers.add(ServerInfo.create(serverUri, channelCredentials, useProtocolV3));
    +      servers.add(
    +          ServerInfo.create(serverUri, channelCredentials, ignoreResourceDeletion));
         }
         return servers.build();
       }
    @@ -326,14 +330,15 @@ private static ChannelCredentials parseChannelCredentials(List> j
             throw new XdsInitializationException(
                 "Invalid bootstrap: server " + serverUri + " with 'channel_creds' type unspecified");
           }
    -      switch (type) {
    -        case "google_default":
    -          return GoogleDefaultChannelCredentials.create();
    -        case "insecure":
    -          return InsecureChannelCredentials.create();
    -        case "tls":
    -          return TlsChannelCredentials.create();
    -        default:
    +      XdsCredentialsProvider provider =  XdsCredentialsRegistry.getDefaultRegistry()
    +          .getProvider(type);
    +      if (provider != null) {
    +        Map config = JsonUtil.getObject(channelCreds, "config");
    +        if (config == null) {
    +          config = ImmutableMap.of();
    +        }
    +
    +        return provider.newChannelCredentials(config);
           }
         }
         return null;
    diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java
    index 7a346e01871..0db0f59eaa2 100644
    --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java
    +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java
    @@ -25,19 +25,19 @@
     import io.grpc.LoadBalancer;
     import io.grpc.LoadBalancerProvider;
     import io.grpc.LoadBalancerRegistry;
    +import io.grpc.NameResolver;
     import io.grpc.Status;
     import io.grpc.SynchronizationContext;
     import io.grpc.internal.ObjectPool;
    +import io.grpc.internal.ServiceConfigUtil;
    +import io.grpc.internal.ServiceConfigUtil.LbConfig;
     import io.grpc.internal.ServiceConfigUtil.PolicySelection;
     import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig;
     import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig;
     import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism;
    -import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig;
    -import io.grpc.xds.RingHashLoadBalancer.RingHashConfig;
    -import io.grpc.xds.XdsClient.CdsResourceWatcher;
    -import io.grpc.xds.XdsClient.CdsUpdate;
    -import io.grpc.xds.XdsClient.CdsUpdate.ClusterType;
    -import io.grpc.xds.XdsClient.CdsUpdate.LbPolicy;
    +import io.grpc.xds.XdsClient.ResourceWatcher;
    +import io.grpc.xds.XdsClusterResource.CdsUpdate;
    +import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType;
     import io.grpc.xds.XdsLogger.XdsLogLevel;
     import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
     import java.util.ArrayDeque;
    @@ -79,9 +79,9 @@ final class CdsLoadBalancer2 extends LoadBalancer {
       }
     
       @Override
    -  public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +  public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
         if (this.resolvedAddresses != null) {
    -      return;
    +      return true;
         }
         logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
         this.resolvedAddresses = resolvedAddresses;
    @@ -91,6 +91,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
         logger.log(XdsLogLevel.INFO, "Config: {0}", config);
         cdsLbState = new CdsLbState(config.name);
         cdsLbState.start();
    +    return true;
       }
     
       @Override
    @@ -159,7 +160,7 @@ private void handleClusterDiscovered() {
                   instance = DiscoveryMechanism.forEds(
                       clusterState.name, clusterState.result.edsServiceName(),
                       clusterState.result.lrsServerInfo(), clusterState.result.maxConcurrentRequests(),
    -                  clusterState.result.upstreamTlsContext());
    +                  clusterState.result.upstreamTlsContext(), clusterState.result.outlierDetection());
                 } else {  // logical DNS
                   instance = DiscoveryMechanism.forLogicalDns(
                       clusterState.name, clusterState.result.dnsHostName(),
    @@ -185,22 +186,27 @@ private void handleClusterDiscovered() {
             helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(unavailable));
             return;
           }
    -      LoadBalancerProvider lbProvider = null;
    -      Object lbConfig = null;
    -      if (root.result.lbPolicy() == LbPolicy.RING_HASH) {
    -        lbProvider = lbRegistry.getProvider("ring_hash_experimental");
    -        lbConfig = new RingHashConfig(root.result.minRingSize(), root.result.maxRingSize());
    -      }
    -      if (root.result.lbPolicy() == LbPolicy.LEAST_REQUEST) {
    -        lbProvider = lbRegistry.getProvider("least_request_experimental");
    -        lbConfig = new LeastRequestConfig(root.result.choiceCount());
    -      }
    +
    +      // The LB policy config is provided in service_config.proto/JSON format. It is unwrapped
    +      // to determine the name of the policy in the load balancer registry.
    +      LbConfig unwrappedLbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(
    +          root.result.lbPolicyConfig());
    +      LoadBalancerProvider lbProvider = lbRegistry.getProvider(unwrappedLbConfig.getPolicyName());
           if (lbProvider == null) {
    -        lbProvider = lbRegistry.getProvider("round_robin");
    -        lbConfig = null;
    +        throw NameResolver.ConfigOrError.fromError(Status.UNAVAILABLE.withDescription(
    +                "No provider available for LB: " + unwrappedLbConfig.getPolicyName())).getError()
    +            .asRuntimeException();
           }
    +      NameResolver.ConfigOrError configOrError = lbProvider.parseLoadBalancingPolicyConfig(
    +          unwrappedLbConfig.getRawConfigValue());
    +      if (configOrError.getError() != null) {
    +        throw configOrError.getError().augmentDescription("Unable to parse the LB config")
    +            .asRuntimeException();
    +      }
    +
           ClusterResolverConfig config = new ClusterResolverConfig(
    -          Collections.unmodifiableList(instances), new PolicySelection(lbProvider, lbConfig));
    +          Collections.unmodifiableList(instances),
    +          new PolicySelection(lbProvider, configOrError.getConfig()));
           if (childLb == null) {
             childLb = lbRegistry.getProvider(CLUSTER_RESOLVER_POLICY_NAME).newLoadBalancer(helper);
           }
    @@ -216,7 +222,7 @@ private void handleClusterDiscoveryError(Status error) {
           }
         }
     
    -    private final class ClusterState implements CdsResourceWatcher {
    +    private final class ClusterState implements ResourceWatcher {
           private final String name;
           @Nullable
           private Map childClusterStates;
    @@ -232,12 +238,12 @@ private ClusterState(String name) {
           }
     
           private void start() {
    -        xdsClient.watchCdsResource(name, this);
    +        xdsClient.watchXdsResource(XdsClusterResource.getInstance(), name, this);
           }
     
           void shutdown() {
             shutdown = true;
    -        xdsClient.cancelCdsResourceWatch(name, this);
    +        xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), name, this);
             if (childClusterStates != null) {  // recursively shut down all descendants
               for (ClusterState state : childClusterStates.values()) {
                 state.shutdown();
    @@ -246,7 +252,12 @@ void shutdown() {
           }
     
           @Override
    -      public void onError(final Status error) {
    +      public void onError(Status error) {
    +        Status status = Status.UNAVAILABLE
    +            .withDescription(
    +                String.format("Unable to load CDS %s. xDS server returned: %s: %s",
    +                  name, error.getCode(), error.getDescription()))
    +            .withCause(error.getCause());
             syncContext.execute(new Runnable() {
               @Override
               public void run() {
    @@ -255,7 +266,7 @@ public void run() {
                 }
                 // All watchers should receive the same error, so we only propagate it once.
                 if (ClusterState.this == root) {
    -              handleClusterDiscoveryError(error);
    +              handleClusterDiscoveryError(status);
                 }
               }
             });
    @@ -290,6 +301,7 @@ public void run() {
                 if (shutdown) {
                   return;
                 }
    +
                 logger.log(XdsLogLevel.DEBUG, "Received cluster update {0}", update);
                 discovered = true;
                 result = update;
    diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java
    index d2e3de10efa..01bd2ab27f6 100644
    --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java
    +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java
    @@ -75,7 +75,7 @@ static ConfigOrError parseLoadBalancingConfigPolicy(Map rawLoadBalanc
           return ConfigOrError.fromConfig(new CdsConfig(cluster));
         } catch (RuntimeException e) {
           return ConfigOrError.fromError(
    -          Status.fromThrowable(e).withDescription(
    +          Status.UNAVAILABLE.withCause(e).withDescription(
                   "Failed to parse CDS LB config: " + rawLoadBalancingPolicyConfig));
         }
       }
    diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java
    deleted file mode 100644
    index 95b3bbfcfdf..00000000000
    --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java
    +++ /dev/null
    @@ -1,2628 +0,0 @@
    -/*
    - * Copyright 2020 The gRPC Authors
    - *
    - * Licensed under the Apache License, Version 2.0 (the "License");
    - * you may not use this file except in compliance with the License.
    - * You may obtain a copy of the License at
    - *
    - *     http://www.apache.org/licenses/LICENSE-2.0
    - *
    - * Unless required by applicable law or agreed to in writing, software
    - * distributed under the License is distributed on an "AS IS" BASIS,
    - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    - * See the License for the specific language governing permissions and
    - * limitations under the License.
    - */
    -
    -package io.grpc.xds;
    -
    -import static com.google.common.base.Preconditions.checkArgument;
    -import static com.google.common.base.Preconditions.checkNotNull;
    -import static io.grpc.xds.Bootstrapper.XDSTP_SCHEME;
    -
    -import com.github.udpa.udpa.type.v1.TypedStruct;
    -import com.google.common.annotations.VisibleForTesting;
    -import com.google.common.base.Joiner;
    -import com.google.common.base.Splitter;
    -import com.google.common.base.Stopwatch;
    -import com.google.common.base.Strings;
    -import com.google.common.base.Supplier;
    -import com.google.common.collect.ImmutableList;
    -import com.google.common.collect.ImmutableMap;
    -import com.google.common.collect.ImmutableSet;
    -import com.google.common.util.concurrent.ListenableFuture;
    -import com.google.common.util.concurrent.SettableFuture;
    -import com.google.protobuf.Any;
    -import com.google.protobuf.Duration;
    -import com.google.protobuf.InvalidProtocolBufferException;
    -import com.google.protobuf.Message;
    -import com.google.protobuf.util.Durations;
    -import com.google.re2j.Pattern;
    -import com.google.re2j.PatternSyntaxException;
    -import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers.Thresholds;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster.CustomClusterType;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig;
    -import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig;
    -import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions;
    -import io.envoyproxy.envoy.config.core.v3.RoutingPriority;
    -import io.envoyproxy.envoy.config.core.v3.SocketAddress;
    -import io.envoyproxy.envoy.config.core.v3.SocketAddress.PortSpecifierCase;
    -import io.envoyproxy.envoy.config.core.v3.TrafficDirection;
    -import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig;
    -import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment;
    -import io.envoyproxy.envoy.config.listener.v3.Listener;
    -import io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin;
    -import io.envoyproxy.envoy.config.route.v3.RetryPolicy.RetryBackOff;
    -import io.envoyproxy.envoy.config.route.v3.RouteConfiguration;
    -import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager;
    -import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds;
    -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
    -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
    -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext;
    -import io.envoyproxy.envoy.type.v3.FractionalPercent;
    -import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType;
    -import io.grpc.ChannelCredentials;
    -import io.grpc.Context;
    -import io.grpc.EquivalentAddressGroup;
    -import io.grpc.Grpc;
    -import io.grpc.InternalLogId;
    -import io.grpc.ManagedChannel;
    -import io.grpc.Status;
    -import io.grpc.Status.Code;
    -import io.grpc.SynchronizationContext;
    -import io.grpc.SynchronizationContext.ScheduledHandle;
    -import io.grpc.internal.BackoffPolicy;
    -import io.grpc.internal.TimeProvider;
    -import io.grpc.xds.AbstractXdsClient.ResourceType;
    -import io.grpc.xds.Bootstrapper.AuthorityInfo;
    -import io.grpc.xds.Bootstrapper.ServerInfo;
    -import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig;
    -import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig;
    -import io.grpc.xds.Endpoints.DropOverload;
    -import io.grpc.xds.Endpoints.LbEndpoint;
    -import io.grpc.xds.Endpoints.LocalityLbEndpoints;
    -import io.grpc.xds.EnvoyServerProtoData.CidrRange;
    -import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType;
    -import io.grpc.xds.EnvoyServerProtoData.FilterChain;
    -import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch;
    -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
    -import io.grpc.xds.Filter.ClientInterceptorBuilder;
    -import io.grpc.xds.Filter.FilterConfig;
    -import io.grpc.xds.Filter.NamedFilterConfig;
    -import io.grpc.xds.Filter.ServerInterceptorBuilder;
    -import io.grpc.xds.LoadStatsManager2.ClusterDropStats;
    -import io.grpc.xds.LoadStatsManager2.ClusterLocalityStats;
    -import io.grpc.xds.VirtualHost.Route;
    -import io.grpc.xds.VirtualHost.Route.RouteAction;
    -import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight;
    -import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy;
    -import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy;
    -import io.grpc.xds.VirtualHost.Route.RouteMatch;
    -import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher;
    -import io.grpc.xds.XdsClient.ResourceStore;
    -import io.grpc.xds.XdsClient.XdsResponseHandler;
    -import io.grpc.xds.XdsLogger.XdsLogLevel;
    -import io.grpc.xds.internal.Matchers.FractionMatcher;
    -import io.grpc.xds.internal.Matchers.HeaderMatcher;
    -import java.net.InetSocketAddress;
    -import java.net.URI;
    -import java.net.UnknownHostException;
    -import java.util.ArrayList;
    -import java.util.Collection;
    -import java.util.Collections;
    -import java.util.EnumSet;
    -import java.util.HashMap;
    -import java.util.HashSet;
    -import java.util.LinkedHashMap;
    -import java.util.List;
    -import java.util.Locale;
    -import java.util.Map;
    -import java.util.Objects;
    -import java.util.Set;
    -import java.util.concurrent.ScheduledExecutorService;
    -import java.util.concurrent.TimeUnit;
    -import javax.annotation.Nullable;
    -
    -/**
    - * XdsClient implementation for client side usages.
    - */
    -final class ClientXdsClient extends XdsClient implements XdsResponseHandler, ResourceStore {
    -
    -  // Longest time to wait, since the subscription to some resource, for concluding its absence.
    -  @VisibleForTesting
    -  static final int INITIAL_RESOURCE_FETCH_TIMEOUT_SEC = 15;
    -  private static final String TRANSPORT_SOCKET_NAME_TLS = "envoy.transport_sockets.tls";
    -  @VisibleForTesting
    -  static final long DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE = 1024L;
    -  @VisibleForTesting
    -  static final long DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE = 8 * 1024 * 1024L;
    -  @VisibleForTesting
    -  static final int DEFAULT_LEAST_REQUEST_CHOICE_COUNT = 2;
    -  @VisibleForTesting
    -  static final long MAX_RING_HASH_LB_POLICY_RING_SIZE = 8 * 1024 * 1024L;
    -  @VisibleForTesting
    -  static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate";
    -  @VisibleForTesting
    -  static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id";
    -  @VisibleForTesting
    -  static boolean enableFaultInjection =
    -      Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION"))
    -          || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION"));
    -  @VisibleForTesting
    -  static boolean enableRetry =
    -      Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY"))
    -          || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY"));
    -  @VisibleForTesting
    -  static boolean enableRbac =
    -      Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC"))
    -          || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC"));
    -  @VisibleForTesting
    -  static boolean enableRouteLookup =
    -      !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_RLS_LB"))
    -          && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_RLS_LB"));
    -  @VisibleForTesting
    -  static boolean enableLeastRequest =
    -      !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST"))
    -          ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST"))
    -          : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest"));
    -  private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 =
    -      "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2"
    -          + ".HttpConnectionManager";
    -  static final String TYPE_URL_HTTP_CONNECTION_MANAGER =
    -      "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3"
    -          + ".HttpConnectionManager";
    -  private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT =
    -      "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext";
    -  private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT_V2 =
    -      "type.googleapis.com/envoy.api.v2.auth.UpstreamTlsContext";
    -  private static final String TYPE_URL_CLUSTER_CONFIG_V2 =
    -      "type.googleapis.com/envoy.config.cluster.aggregate.v2alpha.ClusterConfig";
    -  private static final String TYPE_URL_CLUSTER_CONFIG =
    -      "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig";
    -  private static final String TYPE_URL_TYPED_STRUCT_UDPA =
    -      "type.googleapis.com/udpa.type.v1.TypedStruct";
    -  private static final String TYPE_URL_TYPED_STRUCT =
    -      "type.googleapis.com/xds.type.v3.TypedStruct";
    -  private static final String TYPE_URL_FILTER_CONFIG =
    -      "type.googleapis.com/envoy.config.route.v3.FilterConfig";
    -  // TODO(zdapeng): need to discuss how to handle unsupported values.
    -  private static final Set SUPPORTED_RETRYABLE_CODES =
    -      Collections.unmodifiableSet(EnumSet.of(
    -          Code.CANCELLED, Code.DEADLINE_EXCEEDED, Code.INTERNAL, Code.RESOURCE_EXHAUSTED,
    -          Code.UNAVAILABLE));
    -
    -  private final SynchronizationContext syncContext = new SynchronizationContext(
    -      new Thread.UncaughtExceptionHandler() {
    -        @Override
    -        public void uncaughtException(Thread t, Throwable e) {
    -          logger.log(
    -              XdsLogLevel.ERROR,
    -              "Uncaught exception in XdsClient SynchronizationContext. Panic!",
    -              e);
    -          // TODO(chengyuanzhang): better error handling.
    -          throw new AssertionError(e);
    -        }
    -      });
    -  private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry();
    -  private final Map serverChannelMap = new HashMap<>();
    -  private final Map ldsResourceSubscribers = new HashMap<>();
    -  private final Map rdsResourceSubscribers = new HashMap<>();
    -  private final Map cdsResourceSubscribers = new HashMap<>();
    -  private final Map edsResourceSubscribers = new HashMap<>();
    -  private final LoadStatsManager2 loadStatsManager;
    -  private final Map serverLrsClientMap = new HashMap<>();
    -  private final XdsChannelFactory xdsChannelFactory;
    -  private final Bootstrapper.BootstrapInfo bootstrapInfo;
    -  private final Context context;
    -  private final ScheduledExecutorService timeService;
    -  private final BackoffPolicy.Provider backoffPolicyProvider;
    -  private final Supplier stopwatchSupplier;
    -  private final TimeProvider timeProvider;
    -  private boolean reportingLoad;
    -  private final TlsContextManager tlsContextManager;
    -  private final InternalLogId logId;
    -  private final XdsLogger logger;
    -  private volatile boolean isShutdown;
    -
    -  // TODO(zdapeng): rename to XdsClientImpl
    -  ClientXdsClient(
    -      XdsChannelFactory xdsChannelFactory,
    -      Bootstrapper.BootstrapInfo bootstrapInfo,
    -      Context context,
    -      ScheduledExecutorService timeService,
    -      BackoffPolicy.Provider backoffPolicyProvider,
    -      Supplier stopwatchSupplier,
    -      TimeProvider timeProvider,
    -      TlsContextManager tlsContextManager) {
    -    this.xdsChannelFactory = xdsChannelFactory;
    -    this.bootstrapInfo = bootstrapInfo;
    -    this.context = context;
    -    this.timeService = timeService;
    -    loadStatsManager = new LoadStatsManager2(stopwatchSupplier);
    -    this.backoffPolicyProvider = backoffPolicyProvider;
    -    this.stopwatchSupplier = stopwatchSupplier;
    -    this.timeProvider = timeProvider;
    -    this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager");
    -    logId = InternalLogId.allocate("xds-client", null);
    -    logger = XdsLogger.withLogId(logId);
    -    logger.log(XdsLogLevel.INFO, "Created");
    -  }
    -
    -  private void maybeCreateXdsChannelWithLrs(ServerInfo serverInfo) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    if (serverChannelMap.containsKey(serverInfo)) {
    -      return;
    -    }
    -    AbstractXdsClient xdsChannel = new AbstractXdsClient(
    -        xdsChannelFactory,
    -        serverInfo,
    -        bootstrapInfo.node(),
    -        this,
    -        this,
    -        context,
    -        timeService,
    -        syncContext,
    -        backoffPolicyProvider,
    -        stopwatchSupplier);
    -    LoadReportClient lrsClient = new LoadReportClient(
    -        loadStatsManager, xdsChannel.channel(), context, serverInfo.useProtocolV3(),
    -        bootstrapInfo.node(), syncContext, timeService, backoffPolicyProvider, stopwatchSupplier);
    -    serverChannelMap.put(serverInfo, xdsChannel);
    -    serverLrsClientMap.put(serverInfo, lrsClient);
    -  }
    -
    -  @Override
    -  public void handleLdsResponse(
    -      ServerInfo serverInfo, String versionInfo, List resources, String nonce) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    Map parsedResources = new HashMap<>(resources.size());
    -    Set unpackedResources = new HashSet<>(resources.size());
    -    Set invalidResources = new HashSet<>();
    -    List errors = new ArrayList<>();
    -    Set retainedRdsResources = new HashSet<>();
    -
    -    for (int i = 0; i < resources.size(); i++) {
    -      Any resource = resources.get(i);
    -
    -      // Unpack the Listener.
    -      boolean isResourceV3 = resource.getTypeUrl().equals(ResourceType.LDS.typeUrl());
    -      Listener listener;
    -      try {
    -        listener = unpackCompatibleType(resource, Listener.class, ResourceType.LDS.typeUrl(),
    -            ResourceType.LDS.typeUrlV2());
    -      } catch (InvalidProtocolBufferException e) {
    -        errors.add("LDS response Resource index " + i + " - can't decode Listener: " + e);
    -        continue;
    -      }
    -      if (!isResourceNameValid(listener.getName(), resource.getTypeUrl())) {
    -        errors.add(
    -            "Unsupported resource name: " + listener.getName() + " for type: " + ResourceType.LDS);
    -        continue;
    -      }
    -      String listenerName = canonifyResourceName(listener.getName());
    -      unpackedResources.add(listenerName);
    -
    -      // Process Listener into LdsUpdate.
    -      LdsUpdate ldsUpdate;
    -      try {
    -        if (listener.hasApiListener()) {
    -          ldsUpdate = processClientSideListener(
    -              listener, retainedRdsResources, enableFaultInjection && isResourceV3);
    -        } else {
    -          ldsUpdate = processServerSideListener(
    -              listener, retainedRdsResources, enableRbac && isResourceV3);
    -        }
    -      } catch (ResourceInvalidException e) {
    -        errors.add(
    -            "LDS response Listener '" + listenerName + "' validation error: " + e.getMessage());
    -        invalidResources.add(listenerName);
    -        continue;
    -      }
    -
    -      // LdsUpdate parsed successfully.
    -      parsedResources.put(listenerName, new ParsedResource(ldsUpdate, resource));
    -    }
    -    logger.log(XdsLogLevel.INFO,
    -        "Received LDS Response version {0} nonce {1}. Parsed resources: {2}",
    -        versionInfo, nonce, unpackedResources);
    -    handleResourceUpdate(
    -        serverInfo, ResourceType.LDS, parsedResources, invalidResources, retainedRdsResources,
    -        versionInfo, nonce, errors);
    -  }
    -
    -  private LdsUpdate processClientSideListener(
    -      Listener listener, Set rdsResources, boolean parseHttpFilter)
    -      throws ResourceInvalidException {
    -    // Unpack HttpConnectionManager from the Listener.
    -    HttpConnectionManager hcm;
    -    try {
    -      hcm = unpackCompatibleType(
    -          listener.getApiListener().getApiListener(), HttpConnectionManager.class,
    -          TYPE_URL_HTTP_CONNECTION_MANAGER, TYPE_URL_HTTP_CONNECTION_MANAGER_V2);
    -    } catch (InvalidProtocolBufferException e) {
    -      throw new ResourceInvalidException(
    -          "Could not parse HttpConnectionManager config from ApiListener", e);
    -    }
    -    return LdsUpdate.forApiListener(parseHttpConnectionManager(
    -        hcm, rdsResources, filterRegistry, parseHttpFilter, true /* isForClient */));
    -  }
    -
    -  private LdsUpdate processServerSideListener(
    -      Listener proto, Set rdsResources, boolean parseHttpFilter)
    -      throws ResourceInvalidException {
    -    Set certProviderInstances = null;
    -    if (getBootstrapInfo() != null && getBootstrapInfo().certProviders() != null) {
    -      certProviderInstances = getBootstrapInfo().certProviders().keySet();
    -    }
    -    return LdsUpdate.forTcpListener(parseServerSideListener(
    -        proto, rdsResources, tlsContextManager, filterRegistry, certProviderInstances,
    -        parseHttpFilter));
    -  }
    -
    -  @VisibleForTesting
    -  static EnvoyServerProtoData.Listener parseServerSideListener(
    -      Listener proto, Set rdsResources, TlsContextManager tlsContextManager,
    -      FilterRegistry filterRegistry, Set certProviderInstances, boolean parseHttpFilter)
    -      throws ResourceInvalidException {
    -    if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) {
    -      throw new ResourceInvalidException(
    -          "Listener " + proto.getName() + " with invalid traffic direction: "
    -              + proto.getTrafficDirection());
    -    }
    -    if (!proto.getListenerFiltersList().isEmpty()) {
    -      throw new ResourceInvalidException(
    -          "Listener " + proto.getName() + " cannot have listener_filters");
    -    }
    -    if (proto.hasUseOriginalDst()) {
    -      throw new ResourceInvalidException(
    -          "Listener " + proto.getName() + " cannot have use_original_dst set to true");
    -    }
    -
    -    String address = null;
    -    if (proto.getAddress().hasSocketAddress()) {
    -      SocketAddress socketAddress = proto.getAddress().getSocketAddress();
    -      address = socketAddress.getAddress();
    -      switch (socketAddress.getPortSpecifierCase()) {
    -        case NAMED_PORT:
    -          address = address + ":" + socketAddress.getNamedPort();
    -          break;
    -        case PORT_VALUE:
    -          address = address + ":" + socketAddress.getPortValue();
    -          break;
    -        default:
    -          // noop
    -      }
    -    }
    -
    -    ImmutableList.Builder filterChains = ImmutableList.builder();
    -    Set uniqueSet = new HashSet<>();
    -    for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) {
    -      filterChains.add(
    -          parseFilterChain(fc, rdsResources, tlsContextManager, filterRegistry, uniqueSet,
    -              certProviderInstances, parseHttpFilter));
    -    }
    -    FilterChain defaultFilterChain = null;
    -    if (proto.hasDefaultFilterChain()) {
    -      defaultFilterChain = parseFilterChain(
    -          proto.getDefaultFilterChain(), rdsResources, tlsContextManager, filterRegistry,
    -          null, certProviderInstances, parseHttpFilter);
    -    }
    -
    -    return EnvoyServerProtoData.Listener.create(
    -        proto.getName(), address, filterChains.build(), defaultFilterChain);
    -  }
    -
    -  @VisibleForTesting
    -  static FilterChain parseFilterChain(
    -      io.envoyproxy.envoy.config.listener.v3.FilterChain proto, Set rdsResources,
    -      TlsContextManager tlsContextManager, FilterRegistry filterRegistry,
    -      Set uniqueSet, Set certProviderInstances, boolean parseHttpFilters)
    -      throws ResourceInvalidException {
    -    if (proto.getFiltersCount() != 1) {
    -      throw new ResourceInvalidException("FilterChain " + proto.getName()
    -              + " should contain exact one HttpConnectionManager filter");
    -    }
    -    io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0);
    -    if (!filter.hasTypedConfig()) {
    -      throw new ResourceInvalidException(
    -          "FilterChain " + proto.getName() + " contains filter " + filter.getName()
    -              + " without typed_config");
    -    }
    -    Any any = filter.getTypedConfig();
    -    // HttpConnectionManager is the only supported network filter at the moment.
    -    if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) {
    -      throw new ResourceInvalidException(
    -          "FilterChain " + proto.getName() + " contains filter " + filter.getName()
    -              + " with unsupported typed_config type " + any.getTypeUrl());
    -    }
    -    HttpConnectionManager hcmProto;
    -    try {
    -      hcmProto = any.unpack(HttpConnectionManager.class);
    -    } catch (InvalidProtocolBufferException e) {
    -      throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter "
    -          + filter.getName() + " failed to unpack message", e);
    -    }
    -    io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager(
    -            hcmProto, rdsResources, filterRegistry, parseHttpFilters, false /* isForClient */);
    -
    -    EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null;
    -    if (proto.hasTransportSocket()) {
    -      if (!TRANSPORT_SOCKET_NAME_TLS.equals(proto.getTransportSocket().getName())) {
    -        throw new ResourceInvalidException("transport-socket with name "
    -            + proto.getTransportSocket().getName() + " not supported.");
    -      }
    -      DownstreamTlsContext downstreamTlsContextProto;
    -      try {
    -        downstreamTlsContextProto =
    -            proto.getTransportSocket().getTypedConfig().unpack(DownstreamTlsContext.class);
    -      } catch (InvalidProtocolBufferException e) {
    -        throw new ResourceInvalidException("FilterChain " + proto.getName()
    -            + " failed to unpack message", e);
    -      }
    -      downstreamTlsContext =
    -          EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
    -              validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances));
    -    }
    -
    -    FilterChainMatch filterChainMatch = parseFilterChainMatch(proto.getFilterChainMatch());
    -    checkForUniqueness(uniqueSet, filterChainMatch);
    -    return FilterChain.create(
    -        proto.getName(),
    -        filterChainMatch,
    -        httpConnectionManager,
    -        downstreamTlsContext,
    -        tlsContextManager
    -    );
    -  }
    -
    -  @VisibleForTesting
    -  static DownstreamTlsContext validateDownstreamTlsContext(
    -      DownstreamTlsContext downstreamTlsContext, Set certProviderInstances)
    -      throws ResourceInvalidException {
    -    if (downstreamTlsContext.hasCommonTlsContext()) {
    -      validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), certProviderInstances,
    -          true);
    -    } else {
    -      throw new ResourceInvalidException(
    -          "common-tls-context is required in downstream-tls-context");
    -    }
    -    if (downstreamTlsContext.hasRequireSni()) {
    -      throw new ResourceInvalidException(
    -          "downstream-tls-context with require-sni is not supported");
    -    }
    -    DownstreamTlsContext.OcspStaplePolicy ocspStaplePolicy = downstreamTlsContext
    -        .getOcspStaplePolicy();
    -    if (ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.UNRECOGNIZED
    -        && ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.LENIENT_STAPLING) {
    -      throw new ResourceInvalidException(
    -          "downstream-tls-context with ocsp_staple_policy value " + ocspStaplePolicy.name()
    -              + " is not supported");
    -    }
    -    return downstreamTlsContext;
    -  }
    -
    -  @VisibleForTesting
    -  static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
    -      validateUpstreamTlsContext(
    -      io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext,
    -      Set certProviderInstances)
    -      throws ResourceInvalidException {
    -    if (upstreamTlsContext.hasCommonTlsContext()) {
    -      validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), certProviderInstances,
    -          false);
    -    } else {
    -      throw new ResourceInvalidException("common-tls-context is required in upstream-tls-context");
    -    }
    -    return upstreamTlsContext;
    -  }
    -
    -  @VisibleForTesting
    -  static void validateCommonTlsContext(
    -      CommonTlsContext commonTlsContext, Set certProviderInstances, boolean server)
    -      throws ResourceInvalidException {
    -    if (commonTlsContext.hasCustomHandshaker()) {
    -      throw new ResourceInvalidException(
    -          "common-tls-context with custom_handshaker is not supported");
    -    }
    -    if (commonTlsContext.hasTlsParams()) {
    -      throw new ResourceInvalidException("common-tls-context with tls_params is not supported");
    -    }
    -    if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
    -      throw new ResourceInvalidException(
    -          "common-tls-context with validation_context_sds_secret_config is not supported");
    -    }
    -    if (commonTlsContext.hasValidationContextCertificateProvider()) {
    -      throw new ResourceInvalidException(
    -          "common-tls-context with validation_context_certificate_provider is not supported");
    -    }
    -    if (commonTlsContext.hasValidationContextCertificateProviderInstance()) {
    -      throw new ResourceInvalidException(
    -          "common-tls-context with validation_context_certificate_provider_instance is not"
    -              + " supported");
    -    }
    -    String certInstanceName = getIdentityCertInstanceName(commonTlsContext);
    -    if (certInstanceName == null) {
    -      if (server) {
    -        throw new ResourceInvalidException(
    -            "tls_certificate_provider_instance is required in downstream-tls-context");
    -      }
    -      if (commonTlsContext.getTlsCertificatesCount() > 0) {
    -        throw new ResourceInvalidException(
    -            "tls_certificate_provider_instance is unset");
    -      }
    -      if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
    -        throw new ResourceInvalidException(
    -            "tls_certificate_provider_instance is unset");
    -      }
    -      if (commonTlsContext.hasTlsCertificateCertificateProvider()) {
    -        throw new ResourceInvalidException(
    -            "tls_certificate_provider_instance is unset");
    -      }
    -    } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) {
    -      throw new ResourceInvalidException(
    -          "CertificateProvider instance name '" + certInstanceName
    -              + "' not defined in the bootstrap file.");
    -    }
    -    String rootCaInstanceName = getRootCertInstanceName(commonTlsContext);
    -    if (rootCaInstanceName == null) {
    -      if (!server) {
    -        throw new ResourceInvalidException(
    -            "ca_certificate_provider_instance is required in upstream-tls-context");
    -      }
    -    } else {
    -      if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) {
    -        throw new ResourceInvalidException(
    -                "ca_certificate_provider_instance name '" + rootCaInstanceName
    -                        + "' not defined in the bootstrap file.");
    -      }
    -      CertificateValidationContext certificateValidationContext = null;
    -      if (commonTlsContext.hasValidationContext()) {
    -        certificateValidationContext = commonTlsContext.getValidationContext();
    -      } else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext
    -          .getCombinedValidationContext().hasDefaultValidationContext()) {
    -        certificateValidationContext = commonTlsContext.getCombinedValidationContext()
    -            .getDefaultValidationContext();
    -      }
    -      if (certificateValidationContext != null) {
    -        if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) {
    -          throw new ResourceInvalidException(
    -              "match_subject_alt_names only allowed in upstream_tls_context");
    -        }
    -        if (certificateValidationContext.getVerifyCertificateSpkiCount() > 0) {
    -          throw new ResourceInvalidException(
    -              "verify_certificate_spki in default_validation_context is not supported");
    -        }
    -        if (certificateValidationContext.getVerifyCertificateHashCount() > 0) {
    -          throw new ResourceInvalidException(
    -              "verify_certificate_hash in default_validation_context is not supported");
    -        }
    -        if (certificateValidationContext.hasRequireSignedCertificateTimestamp()) {
    -          throw new ResourceInvalidException(
    -              "require_signed_certificate_timestamp in default_validation_context is not "
    -                  + "supported");
    -        }
    -        if (certificateValidationContext.hasCrl()) {
    -          throw new ResourceInvalidException("crl in default_validation_context is not supported");
    -        }
    -        if (certificateValidationContext.hasCustomValidatorConfig()) {
    -          throw new ResourceInvalidException(
    -              "custom_validator_config in default_validation_context is not supported");
    -        }
    -      }
    -    }
    -  }
    -
    -  private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) {
    -    if (commonTlsContext.hasTlsCertificateProviderInstance()) {
    -      return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName();
    -    } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
    -      return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName();
    -    }
    -    return null;
    -  }
    -
    -  private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) {
    -    if (commonTlsContext.hasValidationContext()) {
    -      if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) {
    -        return commonTlsContext.getValidationContext().getCaCertificateProviderInstance()
    -            .getInstanceName();
    -      }
    -    } else if (commonTlsContext.hasCombinedValidationContext()) {
    -      CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext
    -          = commonTlsContext.getCombinedValidationContext();
    -      if (combinedCertificateValidationContext.hasDefaultValidationContext()
    -          && combinedCertificateValidationContext.getDefaultValidationContext()
    -          .hasCaCertificateProviderInstance()) {
    -        return combinedCertificateValidationContext.getDefaultValidationContext()
    -            .getCaCertificateProviderInstance().getInstanceName();
    -      } else if (combinedCertificateValidationContext
    -          .hasValidationContextCertificateProviderInstance()) {
    -        return combinedCertificateValidationContext
    -            .getValidationContextCertificateProviderInstance().getInstanceName();
    -      }
    -    }
    -    return null;
    -  }
    -
    -  private static void checkForUniqueness(Set uniqueSet,
    -      FilterChainMatch filterChainMatch) throws ResourceInvalidException {
    -    if (uniqueSet != null) {
    -      List crossProduct = getCrossProduct(filterChainMatch);
    -      for (FilterChainMatch cur : crossProduct) {
    -        if (!uniqueSet.add(cur)) {
    -          throw new ResourceInvalidException("FilterChainMatch must be unique. "
    -              + "Found duplicate: " + cur);
    -        }
    -      }
    -    }
    -  }
    -
    -  private static List getCrossProduct(FilterChainMatch filterChainMatch) {
    -    // repeating fields to process:
    -    // prefixRanges, applicationProtocols, sourcePrefixRanges, sourcePorts, serverNames
    -    List expandedList = expandOnPrefixRange(filterChainMatch);
    -    expandedList = expandOnApplicationProtocols(expandedList);
    -    expandedList = expandOnSourcePrefixRange(expandedList);
    -    expandedList = expandOnSourcePorts(expandedList);
    -    return expandOnServerNames(expandedList);
    -  }
    -
    -  private static List expandOnPrefixRange(FilterChainMatch filterChainMatch) {
    -    ArrayList expandedList = new ArrayList<>();
    -    if (filterChainMatch.prefixRanges().isEmpty()) {
    -      expandedList.add(filterChainMatch);
    -    } else {
    -      for (EnvoyServerProtoData.CidrRange cidrRange : filterChainMatch.prefixRanges()) {
    -        expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(),
    -            ImmutableList.of(cidrRange),
    -            filterChainMatch.applicationProtocols(),
    -            filterChainMatch.sourcePrefixRanges(),
    -            filterChainMatch.connectionSourceType(),
    -            filterChainMatch.sourcePorts(),
    -            filterChainMatch.serverNames(),
    -            filterChainMatch.transportProtocol()));
    -      }
    -    }
    -    return expandedList;
    -  }
    -
    -  private static List expandOnApplicationProtocols(
    -      Collection set) {
    -    ArrayList expandedList = new ArrayList<>();
    -    for (FilterChainMatch filterChainMatch : set) {
    -      if (filterChainMatch.applicationProtocols().isEmpty()) {
    -        expandedList.add(filterChainMatch);
    -      } else {
    -        for (String applicationProtocol : filterChainMatch.applicationProtocols()) {
    -          expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(),
    -              filterChainMatch.prefixRanges(),
    -              ImmutableList.of(applicationProtocol),
    -              filterChainMatch.sourcePrefixRanges(),
    -              filterChainMatch.connectionSourceType(),
    -              filterChainMatch.sourcePorts(),
    -              filterChainMatch.serverNames(),
    -              filterChainMatch.transportProtocol()));
    -        }
    -      }
    -    }
    -    return expandedList;
    -  }
    -
    -  private static List expandOnSourcePrefixRange(
    -      Collection set) {
    -    ArrayList expandedList = new ArrayList<>();
    -    for (FilterChainMatch filterChainMatch : set) {
    -      if (filterChainMatch.sourcePrefixRanges().isEmpty()) {
    -        expandedList.add(filterChainMatch);
    -      } else {
    -        for (EnvoyServerProtoData.CidrRange cidrRange : filterChainMatch.sourcePrefixRanges()) {
    -          expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(),
    -              filterChainMatch.prefixRanges(),
    -              filterChainMatch.applicationProtocols(),
    -              ImmutableList.of(cidrRange),
    -              filterChainMatch.connectionSourceType(),
    -              filterChainMatch.sourcePorts(),
    -              filterChainMatch.serverNames(),
    -              filterChainMatch.transportProtocol()));
    -        }
    -      }
    -    }
    -    return expandedList;
    -  }
    -
    -  private static List expandOnSourcePorts(Collection set) {
    -    ArrayList expandedList = new ArrayList<>();
    -    for (FilterChainMatch filterChainMatch : set) {
    -      if (filterChainMatch.sourcePorts().isEmpty()) {
    -        expandedList.add(filterChainMatch);
    -      } else {
    -        for (Integer sourcePort : filterChainMatch.sourcePorts()) {
    -          expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(),
    -              filterChainMatch.prefixRanges(),
    -              filterChainMatch.applicationProtocols(),
    -              filterChainMatch.sourcePrefixRanges(),
    -              filterChainMatch.connectionSourceType(),
    -              ImmutableList.of(sourcePort),
    -              filterChainMatch.serverNames(),
    -              filterChainMatch.transportProtocol()));
    -        }
    -      }
    -    }
    -    return expandedList;
    -  }
    -
    -  private static List expandOnServerNames(Collection set) {
    -    ArrayList expandedList = new ArrayList<>();
    -    for (FilterChainMatch filterChainMatch : set) {
    -      if (filterChainMatch.serverNames().isEmpty()) {
    -        expandedList.add(filterChainMatch);
    -      } else {
    -        for (String serverName : filterChainMatch.serverNames()) {
    -          expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(),
    -              filterChainMatch.prefixRanges(),
    -              filterChainMatch.applicationProtocols(),
    -              filterChainMatch.sourcePrefixRanges(),
    -              filterChainMatch.connectionSourceType(),
    -              filterChainMatch.sourcePorts(),
    -              ImmutableList.of(serverName),
    -              filterChainMatch.transportProtocol()));
    -        }
    -      }
    -    }
    -    return expandedList;
    -  }
    -
    -  private static FilterChainMatch parseFilterChainMatch(
    -      io.envoyproxy.envoy.config.listener.v3.FilterChainMatch proto)
    -      throws ResourceInvalidException {
    -    ImmutableList.Builder prefixRanges = ImmutableList.builder();
    -    ImmutableList.Builder sourcePrefixRanges = ImmutableList.builder();
    -    try {
    -      for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getPrefixRangesList()) {
    -        prefixRanges.add(
    -            CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue()));
    -      }
    -      for (io.envoyproxy.envoy.config.core.v3.CidrRange range
    -          : proto.getSourcePrefixRangesList()) {
    -        sourcePrefixRanges.add(
    -            CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue()));
    -      }
    -    } catch (UnknownHostException e) {
    -      throw new ResourceInvalidException("Failed to create CidrRange", e);
    -    }
    -    ConnectionSourceType sourceType;
    -    switch (proto.getSourceType()) {
    -      case ANY:
    -        sourceType = ConnectionSourceType.ANY;
    -        break;
    -      case EXTERNAL:
    -        sourceType = ConnectionSourceType.EXTERNAL;
    -        break;
    -      case SAME_IP_OR_LOOPBACK:
    -        sourceType = ConnectionSourceType.SAME_IP_OR_LOOPBACK;
    -        break;
    -      default:
    -        throw new ResourceInvalidException("Unknown source-type: " + proto.getSourceType());
    -    }
    -    return FilterChainMatch.create(
    -        proto.getDestinationPort().getValue(),
    -        prefixRanges.build(),
    -        ImmutableList.copyOf(proto.getApplicationProtocolsList()),
    -        sourcePrefixRanges.build(),
    -        sourceType,
    -        ImmutableList.copyOf(proto.getSourcePortsList()),
    -        ImmutableList.copyOf(proto.getServerNamesList()),
    -        proto.getTransportProtocol());
    -  }
    -
    -  @VisibleForTesting
    -  static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager(
    -      HttpConnectionManager proto, Set rdsResources, FilterRegistry filterRegistry,
    -      boolean parseHttpFilter, boolean isForClient) throws ResourceInvalidException {
    -    if (enableRbac && proto.getXffNumTrustedHops() != 0) {
    -      throw new ResourceInvalidException(
    -          "HttpConnectionManager with xff_num_trusted_hops unsupported");
    -    }
    -    if (enableRbac && !proto.getOriginalIpDetectionExtensionsList().isEmpty()) {
    -      throw new ResourceInvalidException("HttpConnectionManager with "
    -          + "original_ip_detection_extensions unsupported");
    -    }
    -    // Obtain max_stream_duration from Http Protocol Options.
    -    long maxStreamDuration = 0;
    -    if (proto.hasCommonHttpProtocolOptions()) {
    -      HttpProtocolOptions options = proto.getCommonHttpProtocolOptions();
    -      if (options.hasMaxStreamDuration()) {
    -        maxStreamDuration = Durations.toNanos(options.getMaxStreamDuration());
    -      }
    -    }
    -
    -    // Parse http filters.
    -    List filterConfigs = null;
    -    if (parseHttpFilter) {
    -      if (proto.getHttpFiltersList().isEmpty()) {
    -        throw new ResourceInvalidException("Missing HttpFilter in HttpConnectionManager.");
    -      }
    -      filterConfigs = new ArrayList<>();
    -      Set names = new HashSet<>();
    -      for (int i = 0; i < proto.getHttpFiltersCount(); i++) {
    -        io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter
    -                httpFilter = proto.getHttpFiltersList().get(i);
    -        String filterName = httpFilter.getName();
    -        if (!names.add(filterName)) {
    -          throw new ResourceInvalidException(
    -              "HttpConnectionManager contains duplicate HttpFilter: " + filterName);
    -        }
    -        StructOrError filterConfig =
    -            parseHttpFilter(httpFilter, filterRegistry, isForClient);
    -        if ((i == proto.getHttpFiltersCount() - 1)
    -                && (filterConfig == null || !isTerminalFilter(filterConfig.struct))) {
    -          throw new ResourceInvalidException("The last HttpFilter must be a terminal filter: "
    -                  + filterName);
    -        }
    -        if (filterConfig == null) {
    -          continue;
    -        }
    -        if (filterConfig.getErrorDetail() != null) {
    -          throw new ResourceInvalidException(
    -              "HttpConnectionManager contains invalid HttpFilter: "
    -                  + filterConfig.getErrorDetail());
    -        }
    -        if ((i < proto.getHttpFiltersCount() - 1) && isTerminalFilter(filterConfig.getStruct())) {
    -          throw new ResourceInvalidException("A terminal HttpFilter must be the last filter: "
    -                  + filterName);
    -        }
    -        filterConfigs.add(new NamedFilterConfig(filterName, filterConfig.struct));
    -      }
    -    }
    -
    -    // Parse inlined RouteConfiguration or RDS.
    -    if (proto.hasRouteConfig()) {
    -      List virtualHosts = extractVirtualHosts(
    -          proto.getRouteConfig(), filterRegistry, parseHttpFilter);
    -      return io.grpc.xds.HttpConnectionManager.forVirtualHosts(
    -          maxStreamDuration, virtualHosts, filterConfigs);
    -    }
    -    if (proto.hasRds()) {
    -      Rds rds = proto.getRds();
    -      if (!rds.hasConfigSource()) {
    -        throw new ResourceInvalidException(
    -            "HttpConnectionManager contains invalid RDS: missing config_source");
    -      }
    -      if (!rds.getConfigSource().hasAds() && !rds.getConfigSource().hasSelf()) {
    -        throw new ResourceInvalidException(
    -            "HttpConnectionManager contains invalid RDS: must specify ADS or self ConfigSource");
    -      }
    -      // Collect the RDS resource referenced by this HttpConnectionManager.
    -      rdsResources.add(rds.getRouteConfigName());
    -      return io.grpc.xds.HttpConnectionManager.forRdsName(
    -          maxStreamDuration, rds.getRouteConfigName(), filterConfigs);
    -    }
    -    throw new ResourceInvalidException(
    -        "HttpConnectionManager neither has inlined route_config nor RDS");
    -  }
    -
    -  // hard-coded: currently router config is the only terminal filter.
    -  private static boolean isTerminalFilter(FilterConfig filterConfig) {
    -    return RouterFilter.ROUTER_CONFIG.equals(filterConfig);
    -  }
    -
    -  @VisibleForTesting
    -  @Nullable // Returns null if the filter is optional but not supported.
    -  static StructOrError parseHttpFilter(
    -      io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter
    -          httpFilter, FilterRegistry filterRegistry, boolean isForClient) {
    -    String filterName = httpFilter.getName();
    -    boolean isOptional = httpFilter.getIsOptional();
    -    if (!httpFilter.hasTypedConfig()) {
    -      if (isOptional) {
    -        return null;
    -      } else {
    -        return StructOrError.fromError(
    -            "HttpFilter [" + filterName + "] is not optional and has no typed config");
    -      }
    -    }
    -    Message rawConfig = httpFilter.getTypedConfig();
    -    String typeUrl = httpFilter.getTypedConfig().getTypeUrl();
    -
    -    try {
    -      if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA)) {
    -        TypedStruct typedStruct = httpFilter.getTypedConfig().unpack(TypedStruct.class);
    -        typeUrl = typedStruct.getTypeUrl();
    -        rawConfig = typedStruct.getValue();
    -      } else if (typeUrl.equals(TYPE_URL_TYPED_STRUCT)) {
    -        com.github.xds.type.v3.TypedStruct newTypedStruct =
    -            httpFilter.getTypedConfig().unpack(com.github.xds.type.v3.TypedStruct.class);
    -        typeUrl = newTypedStruct.getTypeUrl();
    -        rawConfig = newTypedStruct.getValue();
    -      }
    -    } catch (InvalidProtocolBufferException e) {
    -      return StructOrError.fromError(
    -          "HttpFilter [" + filterName + "] contains invalid proto: " + e);
    -    }
    -    Filter filter = filterRegistry.get(typeUrl);
    -    if ((isForClient && !(filter instanceof ClientInterceptorBuilder))
    -        || (!isForClient && !(filter instanceof ServerInterceptorBuilder))) {
    -      if (isOptional) {
    -        return null;
    -      } else {
    -        return StructOrError.fromError(
    -            "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for "
    -                + (isForClient ? "client" : "server"));
    -      }
    -    }
    -    ConfigOrError filterConfig = filter.parseFilterConfig(rawConfig);
    -    if (filterConfig.errorDetail != null) {
    -      return StructOrError.fromError(
    -          "Invalid filter config for HttpFilter [" + filterName + "]: " + filterConfig.errorDetail);
    -    }
    -    return StructOrError.fromStruct(filterConfig.config);
    -  }
    -
    -  private static StructOrError parseVirtualHost(
    -      io.envoyproxy.envoy.config.route.v3.VirtualHost proto, FilterRegistry filterRegistry,
    -      boolean parseHttpFilter, Map pluginConfigMap) {
    -    String name = proto.getName();
    -    List routes = new ArrayList<>(proto.getRoutesCount());
    -    for (io.envoyproxy.envoy.config.route.v3.Route routeProto : proto.getRoutesList()) {
    -      StructOrError route = parseRoute(
    -          routeProto, filterRegistry, parseHttpFilter, pluginConfigMap);
    -      if (route == null) {
    -        continue;
    -      }
    -      if (route.getErrorDetail() != null) {
    -        return StructOrError.fromError(
    -            "Virtual host [" + name + "] contains invalid route : " + route.getErrorDetail());
    -      }
    -      routes.add(route.getStruct());
    -    }
    -    if (!parseHttpFilter) {
    -      return StructOrError.fromStruct(VirtualHost.create(
    -          name, proto.getDomainsList(), routes, new HashMap()));
    -    }
    -    StructOrError> overrideConfigs =
    -        parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry);
    -    if (overrideConfigs.errorDetail != null) {
    -      return StructOrError.fromError(
    -          "VirtualHost [" + proto.getName() + "] contains invalid HttpFilter config: "
    -              + overrideConfigs.errorDetail);
    -    }
    -    return StructOrError.fromStruct(VirtualHost.create(
    -        name, proto.getDomainsList(), routes, overrideConfigs.struct));
    -  }
    -
    -  @VisibleForTesting
    -  static StructOrError> parseOverrideFilterConfigs(
    -      Map rawFilterConfigMap, FilterRegistry filterRegistry) {
    -    Map overrideConfigs = new HashMap<>();
    -    for (String name : rawFilterConfigMap.keySet()) {
    -      Any anyConfig = rawFilterConfigMap.get(name);
    -      String typeUrl = anyConfig.getTypeUrl();
    -      boolean isOptional = false;
    -      if (typeUrl.equals(TYPE_URL_FILTER_CONFIG)) {
    -        io.envoyproxy.envoy.config.route.v3.FilterConfig filterConfig;
    -        try {
    -          filterConfig =
    -              anyConfig.unpack(io.envoyproxy.envoy.config.route.v3.FilterConfig.class);
    -        } catch (InvalidProtocolBufferException e) {
    -          return StructOrError.fromError(
    -              "FilterConfig [" + name + "] contains invalid proto: " + e);
    -        }
    -        isOptional = filterConfig.getIsOptional();
    -        anyConfig = filterConfig.getConfig();
    -        typeUrl = anyConfig.getTypeUrl();
    -      }
    -      Message rawConfig = anyConfig;
    -      try {
    -        if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA)) {
    -          TypedStruct typedStruct = anyConfig.unpack(TypedStruct.class);
    -          typeUrl = typedStruct.getTypeUrl();
    -          rawConfig = typedStruct.getValue();
    -        } else if (typeUrl.equals(TYPE_URL_TYPED_STRUCT)) {
    -          com.github.xds.type.v3.TypedStruct newTypedStruct =
    -              anyConfig.unpack(com.github.xds.type.v3.TypedStruct.class);
    -          typeUrl = newTypedStruct.getTypeUrl();
    -          rawConfig = newTypedStruct.getValue();
    -        }
    -      } catch (InvalidProtocolBufferException e) {
    -        return StructOrError.fromError(
    -            "FilterConfig [" + name + "] contains invalid proto: " + e);
    -      }
    -      Filter filter = filterRegistry.get(typeUrl);
    -      if (filter == null) {
    -        if (isOptional) {
    -          continue;
    -        }
    -        return StructOrError.fromError(
    -            "HttpFilter [" + name + "](" + typeUrl + ") is required but unsupported");
    -      }
    -      ConfigOrError filterConfig =
    -          filter.parseFilterConfigOverride(rawConfig);
    -      if (filterConfig.errorDetail != null) {
    -        return StructOrError.fromError(
    -            "Invalid filter config for HttpFilter [" + name + "]: " + filterConfig.errorDetail);
    -      }
    -      overrideConfigs.put(name, filterConfig.config);
    -    }
    -    return StructOrError.fromStruct(overrideConfigs);
    -  }
    -
    -  @VisibleForTesting
    -  @Nullable
    -  static StructOrError parseRoute(
    -      io.envoyproxy.envoy.config.route.v3.Route proto, FilterRegistry filterRegistry,
    -      boolean parseHttpFilter, Map pluginConfigMap) {
    -    StructOrError routeMatch = parseRouteMatch(proto.getMatch());
    -    if (routeMatch == null) {
    -      return null;
    -    }
    -    if (routeMatch.getErrorDetail() != null) {
    -      return StructOrError.fromError(
    -          "Route [" + proto.getName() + "] contains invalid RouteMatch: "
    -              + routeMatch.getErrorDetail());
    -    }
    -
    -    Map overrideConfigs = Collections.emptyMap();
    -    if (parseHttpFilter) {
    -      StructOrError> overrideConfigsOrError =
    -          parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry);
    -      if (overrideConfigsOrError.errorDetail != null) {
    -        return StructOrError.fromError(
    -            "Route [" + proto.getName() + "] contains invalid HttpFilter config: "
    -                + overrideConfigsOrError.errorDetail);
    -      }
    -      overrideConfigs = overrideConfigsOrError.struct;
    -    }
    -
    -    switch (proto.getActionCase()) {
    -      case ROUTE:
    -        StructOrError routeAction =
    -            parseRouteAction(proto.getRoute(), filterRegistry, parseHttpFilter, pluginConfigMap);
    -        if (routeAction == null) {
    -          return null;
    -        }
    -        if (routeAction.errorDetail != null) {
    -          return StructOrError.fromError(
    -              "Route [" + proto.getName() + "] contains invalid RouteAction: "
    -                  + routeAction.getErrorDetail());
    -        }
    -        return StructOrError.fromStruct(
    -            Route.forAction(routeMatch.struct, routeAction.struct, overrideConfigs));
    -      case NON_FORWARDING_ACTION:
    -        return StructOrError.fromStruct(
    -            Route.forNonForwardingAction(routeMatch.struct, overrideConfigs));
    -      case REDIRECT:
    -      case DIRECT_RESPONSE:
    -      case FILTER_ACTION:
    -      case ACTION_NOT_SET:
    -      default:
    -        return StructOrError.fromError(
    -            "Route [" + proto.getName() + "] with unknown action type: " + proto.getActionCase());
    -    }
    -  }
    -
    -  @VisibleForTesting
    -  @Nullable
    -  static StructOrError parseRouteMatch(
    -      io.envoyproxy.envoy.config.route.v3.RouteMatch proto) {
    -    if (proto.getQueryParametersCount() != 0) {
    -      return null;
    -    }
    -    StructOrError pathMatch = parsePathMatcher(proto);
    -    if (pathMatch.getErrorDetail() != null) {
    -      return StructOrError.fromError(pathMatch.getErrorDetail());
    -    }
    -
    -    FractionMatcher fractionMatch = null;
    -    if (proto.hasRuntimeFraction()) {
    -      StructOrError parsedFraction =
    -          parseFractionMatcher(proto.getRuntimeFraction().getDefaultValue());
    -      if (parsedFraction.getErrorDetail() != null) {
    -        return StructOrError.fromError(parsedFraction.getErrorDetail());
    -      }
    -      fractionMatch = parsedFraction.getStruct();
    -    }
    -
    -    List headerMatchers = new ArrayList<>();
    -    for (io.envoyproxy.envoy.config.route.v3.HeaderMatcher hmProto : proto.getHeadersList()) {
    -      StructOrError headerMatcher = parseHeaderMatcher(hmProto);
    -      if (headerMatcher.getErrorDetail() != null) {
    -        return StructOrError.fromError(headerMatcher.getErrorDetail());
    -      }
    -      headerMatchers.add(headerMatcher.getStruct());
    -    }
    -
    -    return StructOrError.fromStruct(RouteMatch.create(
    -        pathMatch.getStruct(), headerMatchers, fractionMatch));
    -  }
    -
    -  @VisibleForTesting
    -  static StructOrError parsePathMatcher(
    -      io.envoyproxy.envoy.config.route.v3.RouteMatch proto) {
    -    boolean caseSensitive = proto.getCaseSensitive().getValue();
    -    switch (proto.getPathSpecifierCase()) {
    -      case PREFIX:
    -        return StructOrError.fromStruct(
    -            PathMatcher.fromPrefix(proto.getPrefix(), caseSensitive));
    -      case PATH:
    -        return StructOrError.fromStruct(PathMatcher.fromPath(proto.getPath(), caseSensitive));
    -      case SAFE_REGEX:
    -        String rawPattern = proto.getSafeRegex().getRegex();
    -        Pattern safeRegEx;
    -        try {
    -          safeRegEx = Pattern.compile(rawPattern);
    -        } catch (PatternSyntaxException e) {
    -          return StructOrError.fromError("Malformed safe regex pattern: " + e.getMessage());
    -        }
    -        return StructOrError.fromStruct(PathMatcher.fromRegEx(safeRegEx));
    -      case PATHSPECIFIER_NOT_SET:
    -      default:
    -        return StructOrError.fromError("Unknown path match type");
    -    }
    -  }
    -
    -  private static StructOrError parseFractionMatcher(FractionalPercent proto) {
    -    int numerator = proto.getNumerator();
    -    int denominator = 0;
    -    switch (proto.getDenominator()) {
    -      case HUNDRED:
    -        denominator = 100;
    -        break;
    -      case TEN_THOUSAND:
    -        denominator = 10_000;
    -        break;
    -      case MILLION:
    -        denominator = 1_000_000;
    -        break;
    -      case UNRECOGNIZED:
    -      default:
    -        return StructOrError.fromError(
    -            "Unrecognized fractional percent denominator: " + proto.getDenominator());
    -    }
    -    return StructOrError.fromStruct(FractionMatcher.create(numerator, denominator));
    -  }
    -
    -  @VisibleForTesting
    -  static StructOrError parseHeaderMatcher(
    -      io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) {
    -    switch (proto.getHeaderMatchSpecifierCase()) {
    -      case EXACT_MATCH:
    -        return StructOrError.fromStruct(HeaderMatcher.forExactValue(
    -            proto.getName(), proto.getExactMatch(), proto.getInvertMatch()));
    -      case SAFE_REGEX_MATCH:
    -        String rawPattern = proto.getSafeRegexMatch().getRegex();
    -        Pattern safeRegExMatch;
    -        try {
    -          safeRegExMatch = Pattern.compile(rawPattern);
    -        } catch (PatternSyntaxException e) {
    -          return StructOrError.fromError(
    -              "HeaderMatcher [" + proto.getName() + "] contains malformed safe regex pattern: "
    -                  + e.getMessage());
    -        }
    -        return StructOrError.fromStruct(HeaderMatcher.forSafeRegEx(
    -            proto.getName(), safeRegExMatch, proto.getInvertMatch()));
    -      case RANGE_MATCH:
    -        HeaderMatcher.Range rangeMatch = HeaderMatcher.Range.create(
    -            proto.getRangeMatch().getStart(), proto.getRangeMatch().getEnd());
    -        return StructOrError.fromStruct(HeaderMatcher.forRange(
    -            proto.getName(), rangeMatch, proto.getInvertMatch()));
    -      case PRESENT_MATCH:
    -        return StructOrError.fromStruct(HeaderMatcher.forPresent(
    -            proto.getName(), proto.getPresentMatch(), proto.getInvertMatch()));
    -      case PREFIX_MATCH:
    -        return StructOrError.fromStruct(HeaderMatcher.forPrefix(
    -            proto.getName(), proto.getPrefixMatch(), proto.getInvertMatch()));
    -      case SUFFIX_MATCH:
    -        return StructOrError.fromStruct(HeaderMatcher.forSuffix(
    -            proto.getName(), proto.getSuffixMatch(), proto.getInvertMatch()));
    -      case HEADERMATCHSPECIFIER_NOT_SET:
    -      default:
    -        return StructOrError.fromError("Unknown header matcher type");
    -    }
    -  }
    -
    -  /**
    -   * Parses the RouteAction config. The returned result may contain a (parsed form)
    -   * {@link RouteAction} or an error message. Returns {@code null} if the RouteAction
    -   * should be ignored.
    -   */
    -  @VisibleForTesting
    -  @Nullable
    -  static StructOrError parseRouteAction(
    -      io.envoyproxy.envoy.config.route.v3.RouteAction proto, FilterRegistry filterRegistry,
    -      boolean parseHttpFilter, Map pluginConfigMap) {
    -    Long timeoutNano = null;
    -    if (proto.hasMaxStreamDuration()) {
    -      io.envoyproxy.envoy.config.route.v3.RouteAction.MaxStreamDuration maxStreamDuration
    -          = proto.getMaxStreamDuration();
    -      if (maxStreamDuration.hasGrpcTimeoutHeaderMax()) {
    -        timeoutNano = Durations.toNanos(maxStreamDuration.getGrpcTimeoutHeaderMax());
    -      } else if (maxStreamDuration.hasMaxStreamDuration()) {
    -        timeoutNano = Durations.toNanos(maxStreamDuration.getMaxStreamDuration());
    -      }
    -    }
    -    RetryPolicy retryPolicy = null;
    -    if (enableRetry && proto.hasRetryPolicy()) {
    -      StructOrError retryPolicyOrError = parseRetryPolicy(proto.getRetryPolicy());
    -      if (retryPolicyOrError != null) {
    -        if (retryPolicyOrError.errorDetail != null) {
    -          return StructOrError.fromError(retryPolicyOrError.errorDetail);
    -        }
    -        retryPolicy = retryPolicyOrError.struct;
    -      }
    -    }
    -    List hashPolicies = new ArrayList<>();
    -    for (io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy config
    -        : proto.getHashPolicyList()) {
    -      HashPolicy policy = null;
    -      boolean terminal = config.getTerminal();
    -      switch (config.getPolicySpecifierCase()) {
    -        case HEADER:
    -          io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.Header headerCfg =
    -              config.getHeader();
    -          Pattern regEx = null;
    -          String regExSubstitute = null;
    -          if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern()
    -              && headerCfg.getRegexRewrite().getPattern().hasGoogleRe2()) {
    -            regEx = Pattern.compile(headerCfg.getRegexRewrite().getPattern().getRegex());
    -            regExSubstitute = headerCfg.getRegexRewrite().getSubstitution();
    -          }
    -          policy = HashPolicy.forHeader(
    -              terminal, headerCfg.getHeaderName(), regEx, regExSubstitute);
    -          break;
    -        case FILTER_STATE:
    -          if (config.getFilterState().getKey().equals(HASH_POLICY_FILTER_STATE_KEY)) {
    -            policy = HashPolicy.forChannelId(terminal);
    -          }
    -          break;
    -        default:
    -          // Ignore
    -      }
    -      if (policy != null) {
    -        hashPolicies.add(policy);
    -      }
    -    }
    -
    -    switch (proto.getClusterSpecifierCase()) {
    -      case CLUSTER:
    -        return StructOrError.fromStruct(RouteAction.forCluster(
    -            proto.getCluster(), hashPolicies, timeoutNano, retryPolicy));
    -      case CLUSTER_HEADER:
    -        return null;
    -      case WEIGHTED_CLUSTERS:
    -        List clusterWeights
    -            = proto.getWeightedClusters().getClustersList();
    -        if (clusterWeights.isEmpty()) {
    -          return StructOrError.fromError("No cluster found in weighted cluster list");
    -        }
    -        List weightedClusters = new ArrayList<>();
    -        for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
    -            : clusterWeights) {
    -          StructOrError clusterWeightOrError =
    -              parseClusterWeight(clusterWeight, filterRegistry, parseHttpFilter);
    -          if (clusterWeightOrError.getErrorDetail() != null) {
    -            return StructOrError.fromError("RouteAction contains invalid ClusterWeight: "
    -                + clusterWeightOrError.getErrorDetail());
    -          }
    -          weightedClusters.add(clusterWeightOrError.getStruct());
    -        }
    -        // TODO(chengyuanzhang): validate if the sum of weights equals to total weight.
    -        return StructOrError.fromStruct(RouteAction.forWeightedClusters(
    -            weightedClusters, hashPolicies, timeoutNano, retryPolicy));
    -      case CLUSTER_SPECIFIER_PLUGIN:
    -        if (enableRouteLookup) {
    -          String pluginName = proto.getClusterSpecifierPlugin();
    -          PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
    -          if (pluginConfig == null) {
    -            return StructOrError.fromError(
    -                "ClusterSpecifierPlugin for [" + pluginName + "] not found");
    -          }
    -          NamedPluginConfig namedPluginConfig = NamedPluginConfig.create(pluginName, pluginConfig);
    -          return StructOrError.fromStruct(RouteAction.forClusterSpecifierPlugin(
    -              namedPluginConfig, hashPolicies, timeoutNano, retryPolicy));
    -        } else {
    -          return StructOrError.fromError("Support for ClusterSpecifierPlugin not enabled");
    -        }
    -      case CLUSTERSPECIFIER_NOT_SET:
    -      default:
    -        return StructOrError.fromError(
    -            "Unknown cluster specifier: " + proto.getClusterSpecifierCase());
    -    }
    -  }
    -
    -  @Nullable // Return null if we ignore the given policy.
    -  private static StructOrError parseRetryPolicy(
    -      io.envoyproxy.envoy.config.route.v3.RetryPolicy retryPolicyProto) {
    -    int maxAttempts = 2;
    -    if (retryPolicyProto.hasNumRetries()) {
    -      maxAttempts = retryPolicyProto.getNumRetries().getValue() + 1;
    -    }
    -    Duration initialBackoff = Durations.fromMillis(25);
    -    Duration maxBackoff = Durations.fromMillis(250);
    -    if (retryPolicyProto.hasRetryBackOff()) {
    -      RetryBackOff retryBackOff = retryPolicyProto.getRetryBackOff();
    -      if (!retryBackOff.hasBaseInterval()) {
    -        return StructOrError.fromError("No base_interval specified in retry_backoff");
    -      }
    -      Duration originalInitialBackoff = initialBackoff = retryBackOff.getBaseInterval();
    -      if (Durations.compare(initialBackoff, Durations.ZERO) <= 0) {
    -        return StructOrError.fromError("base_interval in retry_backoff must be positive");
    -      }
    -      if (Durations.compare(initialBackoff, Durations.fromMillis(1)) < 0) {
    -        initialBackoff = Durations.fromMillis(1);
    -      }
    -      if (retryBackOff.hasMaxInterval()) {
    -        maxBackoff = retryPolicyProto.getRetryBackOff().getMaxInterval();
    -        if (Durations.compare(maxBackoff, originalInitialBackoff) < 0) {
    -          return StructOrError.fromError(
    -              "max_interval in retry_backoff cannot be less than base_interval");
    -        }
    -        if (Durations.compare(maxBackoff, Durations.fromMillis(1)) < 0) {
    -          maxBackoff = Durations.fromMillis(1);
    -        }
    -      } else {
    -        maxBackoff = Durations.fromNanos(Durations.toNanos(initialBackoff) * 10);
    -      }
    -    }
    -    Iterable retryOns =
    -        Splitter.on(',').omitEmptyStrings().trimResults().split(retryPolicyProto.getRetryOn());
    -    ImmutableList.Builder retryableStatusCodesBuilder = ImmutableList.builder();
    -    for (String retryOn : retryOns) {
    -      Code code;
    -      try {
    -        code = Code.valueOf(retryOn.toUpperCase(Locale.US).replace('-', '_'));
    -      } catch (IllegalArgumentException e) {
    -        // unsupported value, such as "5xx"
    -        continue;
    -      }
    -      if (!SUPPORTED_RETRYABLE_CODES.contains(code)) {
    -        // unsupported value
    -        continue;
    -      }
    -      retryableStatusCodesBuilder.add(code);
    -    }
    -    List retryableStatusCodes = retryableStatusCodesBuilder.build();
    -    return StructOrError.fromStruct(
    -        RetryPolicy.create(
    -            maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff,
    -            /* perAttemptRecvTimeout= */ null));
    -  }
    -
    -  @VisibleForTesting
    -  static StructOrError parseClusterWeight(
    -      io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto,
    -      FilterRegistry filterRegistry, boolean parseHttpFilter) {
    -    if (!parseHttpFilter) {
    -      return StructOrError.fromStruct(ClusterWeight.create(
    -          proto.getName(), proto.getWeight().getValue(), new HashMap()));
    -    }
    -    StructOrError> overrideConfigs =
    -        parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry);
    -    if (overrideConfigs.errorDetail != null) {
    -      return StructOrError.fromError(
    -          "ClusterWeight [" + proto.getName() + "] contains invalid HttpFilter config: "
    -              + overrideConfigs.errorDetail);
    -    }
    -    return StructOrError.fromStruct(ClusterWeight.create(
    -        proto.getName(), proto.getWeight().getValue(), overrideConfigs.struct));
    -  }
    -
    -  @Override
    -  public void handleRdsResponse(
    -      ServerInfo serverInfo, String versionInfo, List resources, String nonce) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    Map parsedResources = new HashMap<>(resources.size());
    -    Set unpackedResources = new HashSet<>(resources.size());
    -    Set invalidResources = new HashSet<>();
    -    List errors = new ArrayList<>();
    -
    -    for (int i = 0; i < resources.size(); i++) {
    -      Any resource = resources.get(i);
    -
    -      // Unpack the RouteConfiguration.
    -      RouteConfiguration routeConfig;
    -      try {
    -        routeConfig = unpackCompatibleType(resource, RouteConfiguration.class,
    -            ResourceType.RDS.typeUrl(), ResourceType.RDS.typeUrlV2());
    -      } catch (InvalidProtocolBufferException e) {
    -        errors.add("RDS response Resource index " + i + " - can't decode RouteConfiguration: " + e);
    -        continue;
    -      }
    -      if (!isResourceNameValid(routeConfig.getName(), resource.getTypeUrl())) {
    -        errors.add(
    -            "Unsupported resource name: " + routeConfig.getName() + " for type: "
    -                + ResourceType.RDS);
    -        continue;
    -      }
    -      String routeConfigName = canonifyResourceName(routeConfig.getName());
    -      unpackedResources.add(routeConfigName);
    -
    -      // Process RouteConfiguration into RdsUpdate.
    -      RdsUpdate rdsUpdate;
    -      boolean isResourceV3 = resource.getTypeUrl().equals(ResourceType.RDS.typeUrl());
    -      try {
    -        rdsUpdate = processRouteConfiguration(
    -            routeConfig, filterRegistry, enableFaultInjection && isResourceV3);
    -      } catch (ResourceInvalidException e) {
    -        errors.add(
    -            "RDS response RouteConfiguration '" + routeConfigName + "' validation error: " + e
    -                .getMessage());
    -        invalidResources.add(routeConfigName);
    -        continue;
    -      }
    -
    -      parsedResources.put(routeConfigName, new ParsedResource(rdsUpdate, resource));
    -    }
    -    logger.log(XdsLogLevel.INFO,
    -        "Received RDS Response version {0} nonce {1}. Parsed resources: {2}",
    -        versionInfo, nonce, unpackedResources);
    -    handleResourceUpdate(
    -        serverInfo, ResourceType.RDS, parsedResources, invalidResources,
    -        Collections.emptySet(), versionInfo, nonce, errors);
    -  }
    -
    -  private static RdsUpdate processRouteConfiguration(
    -      RouteConfiguration routeConfig, FilterRegistry filterRegistry, boolean parseHttpFilter)
    -      throws ResourceInvalidException {
    -    return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry, parseHttpFilter));
    -  }
    -
    -  private static List extractVirtualHosts(
    -      RouteConfiguration routeConfig, FilterRegistry filterRegistry, boolean parseHttpFilter)
    -      throws ResourceInvalidException {
    -    Map pluginConfigMap = new HashMap<>();
    -    if (enableRouteLookup) {
    -      List plugins = routeConfig.getClusterSpecifierPluginsList();
    -      for (ClusterSpecifierPlugin plugin : plugins) {
    -        PluginConfig existing = pluginConfigMap.put(
    -            plugin.getExtension().getName(), parseClusterSpecifierPlugin(plugin));
    -        if (existing != null) {
    -          throw new ResourceInvalidException(
    -              "Multiple ClusterSpecifierPlugins with the same name: "
    -                  + plugin.getExtension().getName());
    -        }
    -      }
    -    }
    -    List virtualHosts = new ArrayList<>(routeConfig.getVirtualHostsCount());
    -    for (io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHostProto
    -        : routeConfig.getVirtualHostsList()) {
    -      StructOrError virtualHost =
    -          parseVirtualHost(virtualHostProto, filterRegistry, parseHttpFilter, pluginConfigMap);
    -      if (virtualHost.getErrorDetail() != null) {
    -        throw new ResourceInvalidException(
    -            "RouteConfiguration contains invalid virtual host: " + virtualHost.getErrorDetail());
    -      }
    -      virtualHosts.add(virtualHost.getStruct());
    -    }
    -    return virtualHosts;
    -  }
    -
    -  private static PluginConfig parseClusterSpecifierPlugin(ClusterSpecifierPlugin pluginProto)
    -      throws ResourceInvalidException {
    -    return parseClusterSpecifierPlugin(
    -        pluginProto, ClusterSpecifierPluginRegistry.getDefaultRegistry());
    -  }
    -
    -  @VisibleForTesting
    -  static PluginConfig parseClusterSpecifierPlugin(
    -      ClusterSpecifierPlugin pluginProto, ClusterSpecifierPluginRegistry registry)
    -      throws ResourceInvalidException {
    -    TypedExtensionConfig extension = pluginProto.getExtension();
    -    String pluginName = extension.getName();
    -    Any anyConfig = extension.getTypedConfig();
    -    String typeUrl = anyConfig.getTypeUrl();
    -    Message rawConfig = anyConfig;
    -    if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA) || typeUrl.equals(TYPE_URL_TYPED_STRUCT)) {
    -      try {
    -        TypedStruct typedStruct = unpackCompatibleType(
    -            anyConfig, TypedStruct.class, TYPE_URL_TYPED_STRUCT_UDPA, TYPE_URL_TYPED_STRUCT);
    -        typeUrl = typedStruct.getTypeUrl();
    -        rawConfig = typedStruct.getValue();
    -      } catch (InvalidProtocolBufferException e) {
    -        throw new ResourceInvalidException(
    -            "ClusterSpecifierPlugin [" + pluginName + "] contains invalid proto", e);
    -      }
    -    }
    -    io.grpc.xds.ClusterSpecifierPlugin plugin = registry.get(typeUrl);
    -    if (plugin == null) {
    -      throw new ResourceInvalidException("Unsupported ClusterSpecifierPlugin type: " + typeUrl);
    -    }
    -    ConfigOrError pluginConfigOrError = plugin.parsePlugin(rawConfig);
    -    if (pluginConfigOrError.errorDetail != null) {
    -      throw new ResourceInvalidException(pluginConfigOrError.errorDetail);
    -    }
    -    return pluginConfigOrError.config;
    -  }
    -
    -  @Override
    -  public void handleCdsResponse(
    -      ServerInfo serverInfo, String versionInfo, List resources, String nonce) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    Map parsedResources = new HashMap<>(resources.size());
    -    Set unpackedResources = new HashSet<>(resources.size());
    -    Set invalidResources = new HashSet<>();
    -    List errors = new ArrayList<>();
    -    Set retainedEdsResources = new HashSet<>();
    -
    -    for (int i = 0; i < resources.size(); i++) {
    -      Any resource = resources.get(i);
    -
    -      // Unpack the Cluster.
    -      Cluster cluster;
    -      try {
    -        cluster = unpackCompatibleType(
    -            resource, Cluster.class, ResourceType.CDS.typeUrl(), ResourceType.CDS.typeUrlV2());
    -      } catch (InvalidProtocolBufferException e) {
    -        errors.add("CDS response Resource index " + i + " - can't decode Cluster: " + e);
    -        continue;
    -      }
    -      if (!isResourceNameValid(cluster.getName(), resource.getTypeUrl())) {
    -        errors.add(
    -            "Unsupported resource name: " + cluster.getName() + " for type: " + ResourceType.CDS);
    -        continue;
    -      }
    -      String clusterName = canonifyResourceName(cluster.getName());
    -
    -      // Management server is required to always send newly requested resources, even if they
    -      // may have been sent previously (proactively). Thus, client does not need to cache
    -      // unrequested resources.
    -      if (!cdsResourceSubscribers.containsKey(clusterName)) {
    -        continue;
    -      }
    -      unpackedResources.add(clusterName);
    -
    -      // Process Cluster into CdsUpdate.
    -      CdsUpdate cdsUpdate;
    -      try {
    -        Set certProviderInstances = null;
    -        if (getBootstrapInfo() != null && getBootstrapInfo().certProviders() != null) {
    -          certProviderInstances = getBootstrapInfo().certProviders().keySet();
    -        }
    -        cdsUpdate =
    -            processCluster(cluster, retainedEdsResources, certProviderInstances, serverInfo);
    -      } catch (ResourceInvalidException e) {
    -        errors.add(
    -            "CDS response Cluster '" + clusterName + "' validation error: " + e.getMessage());
    -        invalidResources.add(clusterName);
    -        continue;
    -      }
    -      parsedResources.put(clusterName, new ParsedResource(cdsUpdate, resource));
    -    }
    -    logger.log(XdsLogLevel.INFO,
    -        "Received CDS Response version {0} nonce {1}. Parsed resources: {2}",
    -        versionInfo, nonce, unpackedResources);
    -    handleResourceUpdate(
    -        serverInfo, ResourceType.CDS, parsedResources, invalidResources, retainedEdsResources,
    -        versionInfo, nonce, errors);
    -  }
    -
    -  @VisibleForTesting
    -  static CdsUpdate processCluster(Cluster cluster, Set retainedEdsResources,
    -      Set certProviderInstances, ServerInfo serverInfo)
    -      throws ResourceInvalidException {
    -    StructOrError structOrError;
    -    switch (cluster.getClusterDiscoveryTypeCase()) {
    -      case TYPE:
    -        structOrError = parseNonAggregateCluster(cluster, retainedEdsResources,
    -            certProviderInstances, serverInfo);
    -        break;
    -      case CLUSTER_TYPE:
    -        structOrError = parseAggregateCluster(cluster);
    -        break;
    -      case CLUSTERDISCOVERYTYPE_NOT_SET:
    -      default:
    -        throw new ResourceInvalidException(
    -            "Cluster " + cluster.getName() + ": unspecified cluster discovery type");
    -    }
    -    if (structOrError.getErrorDetail() != null) {
    -      throw new ResourceInvalidException(structOrError.getErrorDetail());
    -    }
    -    CdsUpdate.Builder updateBuilder = structOrError.getStruct();
    -
    -    if (cluster.getLbPolicy() == LbPolicy.RING_HASH) {
    -      RingHashLbConfig lbConfig = cluster.getRingHashLbConfig();
    -      long minRingSize =
    -          lbConfig.hasMinimumRingSize()
    -              ? lbConfig.getMinimumRingSize().getValue()
    -              : DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE;
    -      long maxRingSize =
    -          lbConfig.hasMaximumRingSize()
    -              ? lbConfig.getMaximumRingSize().getValue()
    -              : DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE;
    -      if (lbConfig.getHashFunction() != RingHashLbConfig.HashFunction.XX_HASH
    -          || minRingSize > maxRingSize
    -          || maxRingSize > MAX_RING_HASH_LB_POLICY_RING_SIZE) {
    -        throw new ResourceInvalidException(
    -            "Cluster " + cluster.getName() + ": invalid ring_hash_lb_config: " + lbConfig);
    -      }
    -      updateBuilder.ringHashLbPolicy(minRingSize, maxRingSize);
    -    } else if (cluster.getLbPolicy() == LbPolicy.ROUND_ROBIN) {
    -      updateBuilder.roundRobinLbPolicy();
    -    } else if (enableLeastRequest && cluster.getLbPolicy() == LbPolicy.LEAST_REQUEST) {
    -      LeastRequestLbConfig lbConfig =  cluster.getLeastRequestLbConfig();
    -      int choiceCount =
    -              lbConfig.hasChoiceCount()
    -                ? lbConfig.getChoiceCount().getValue()
    -                : DEFAULT_LEAST_REQUEST_CHOICE_COUNT;
    -      if (choiceCount < DEFAULT_LEAST_REQUEST_CHOICE_COUNT) {
    -        throw new ResourceInvalidException(
    -                "Cluster " + cluster.getName() + ": invalid least_request_lb_config: " + lbConfig);
    -      }
    -      updateBuilder.leastRequestLbPolicy(choiceCount);
    -    } else {
    -      throw new ResourceInvalidException(
    -          "Cluster " + cluster.getName() + ": unsupported lb policy: " + cluster.getLbPolicy());
    -    }
    -
    -    return updateBuilder.build();
    -  }
    -
    -  private static StructOrError parseAggregateCluster(Cluster cluster) {
    -    String clusterName = cluster.getName();
    -    CustomClusterType customType = cluster.getClusterType();
    -    String typeName = customType.getName();
    -    if (!typeName.equals(AGGREGATE_CLUSTER_TYPE_NAME)) {
    -      return StructOrError.fromError(
    -          "Cluster " + clusterName + ": unsupported custom cluster type: " + typeName);
    -    }
    -    io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig clusterConfig;
    -    try {
    -      clusterConfig = unpackCompatibleType(customType.getTypedConfig(),
    -          io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig.class,
    -          TYPE_URL_CLUSTER_CONFIG, TYPE_URL_CLUSTER_CONFIG_V2);
    -    } catch (InvalidProtocolBufferException e) {
    -      return StructOrError.fromError("Cluster " + clusterName + ": malformed ClusterConfig: " + e);
    -    }
    -    return StructOrError.fromStruct(CdsUpdate.forAggregate(
    -        clusterName, clusterConfig.getClustersList()));
    -  }
    -
    -  private static StructOrError parseNonAggregateCluster(
    -      Cluster cluster, Set edsResources, Set certProviderInstances,
    -      ServerInfo serverInfo) {
    -    String clusterName = cluster.getName();
    -    ServerInfo lrsServerInfo = null;
    -    Long maxConcurrentRequests = null;
    -    UpstreamTlsContext upstreamTlsContext = null;
    -    if (cluster.hasLrsServer()) {
    -      if (!cluster.getLrsServer().hasSelf()) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName + ": only support LRS for the same management server");
    -      }
    -      lrsServerInfo = serverInfo;
    -    }
    -    if (cluster.hasCircuitBreakers()) {
    -      List thresholds = cluster.getCircuitBreakers().getThresholdsList();
    -      for (Thresholds threshold : thresholds) {
    -        if (threshold.getPriority() != RoutingPriority.DEFAULT) {
    -          continue;
    -        }
    -        if (threshold.hasMaxRequests()) {
    -          maxConcurrentRequests = (long) threshold.getMaxRequests().getValue();
    -        }
    -      }
    -    }
    -    if (cluster.getTransportSocketMatchesCount() > 0) {
    -      return StructOrError.fromError("Cluster " + clusterName
    -          + ": transport-socket-matches not supported.");
    -    }
    -    if (cluster.hasTransportSocket()) {
    -      if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) {
    -        return StructOrError.fromError("transport-socket with name "
    -            + cluster.getTransportSocket().getName() + " not supported.");
    -      }
    -      try {
    -        upstreamTlsContext = UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
    -                validateUpstreamTlsContext(
    -            unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(),
    -                io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class,
    -                TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2),
    -                certProviderInstances));
    -      } catch (InvalidProtocolBufferException | ResourceInvalidException e) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName + ": malformed UpstreamTlsContext: " + e);
    -      }
    -    }
    -
    -    DiscoveryType type = cluster.getType();
    -    if (type == DiscoveryType.EDS) {
    -      String edsServiceName = null;
    -      io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig edsClusterConfig =
    -          cluster.getEdsClusterConfig();
    -      if (!edsClusterConfig.getEdsConfig().hasAds()
    -          && ! edsClusterConfig.getEdsConfig().hasSelf()) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName + ": field eds_cluster_config must be set to indicate to use"
    -                + " EDS over ADS or self ConfigSource");
    -      }
    -      // If the service_name field is set, that value will be used for the EDS request.
    -      if (!edsClusterConfig.getServiceName().isEmpty()) {
    -        edsServiceName = edsClusterConfig.getServiceName();
    -        edsResources.add(edsServiceName);
    -      } else {
    -        edsResources.add(clusterName);
    -      }
    -      return StructOrError.fromStruct(CdsUpdate.forEds(
    -          clusterName, edsServiceName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext));
    -    } else if (type.equals(DiscoveryType.LOGICAL_DNS)) {
    -      if (!cluster.hasLoadAssignment()) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single host");
    -      }
    -      ClusterLoadAssignment assignment = cluster.getLoadAssignment();
    -      if (assignment.getEndpointsCount() != 1
    -          || assignment.getEndpoints(0).getLbEndpointsCount() != 1) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single "
    -                + "locality_lb_endpoint and a single lb_endpoint");
    -      }
    -      io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint lbEndpoint =
    -          assignment.getEndpoints(0).getLbEndpoints(0);
    -      if (!lbEndpoint.hasEndpoint() || !lbEndpoint.getEndpoint().hasAddress()
    -          || !lbEndpoint.getEndpoint().getAddress().hasSocketAddress()) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName
    -                + ": LOGICAL_DNS clusters must have an endpoint with address and socket_address");
    -      }
    -      SocketAddress socketAddress = lbEndpoint.getEndpoint().getAddress().getSocketAddress();
    -      if (!socketAddress.getResolverName().isEmpty()) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName
    -                + ": LOGICAL DNS clusters must NOT have a custom resolver name set");
    -      }
    -      if (socketAddress.getPortSpecifierCase() != PortSpecifierCase.PORT_VALUE) {
    -        return StructOrError.fromError(
    -            "Cluster " + clusterName
    -                + ": LOGICAL DNS clusters socket_address must have port_value");
    -      }
    -      String dnsHostName =
    -          String.format("%s:%d", socketAddress.getAddress(), socketAddress.getPortValue());
    -      return StructOrError.fromStruct(CdsUpdate.forLogicalDns(
    -          clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext));
    -    }
    -    return StructOrError.fromError(
    -        "Cluster " + clusterName + ": unsupported built-in discovery type: " + type);
    -  }
    -
    -  @Override
    -  public void handleEdsResponse(
    -      ServerInfo serverInfo, String versionInfo, List resources, String nonce) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    Map parsedResources = new HashMap<>(resources.size());
    -    Set unpackedResources = new HashSet<>(resources.size());
    -    Set invalidResources = new HashSet<>();
    -    List errors = new ArrayList<>();
    -
    -    for (int i = 0; i < resources.size(); i++) {
    -      Any resource = resources.get(i);
    -
    -      // Unpack the ClusterLoadAssignment.
    -      ClusterLoadAssignment assignment;
    -      try {
    -        assignment =
    -            unpackCompatibleType(resource, ClusterLoadAssignment.class, ResourceType.EDS.typeUrl(),
    -                ResourceType.EDS.typeUrlV2());
    -      } catch (InvalidProtocolBufferException e) {
    -        errors.add(
    -            "EDS response Resource index " + i + " - can't decode ClusterLoadAssignment: " + e);
    -        continue;
    -      }
    -      if (!isResourceNameValid(assignment.getClusterName(), resource.getTypeUrl())) {
    -        errors.add("Unsupported resource name: " + assignment.getClusterName() + " for type: "
    -            + ResourceType.EDS);
    -        continue;
    -      }
    -      String clusterName = canonifyResourceName(assignment.getClusterName());
    -
    -      // Skip information for clusters not requested.
    -      // Management server is required to always send newly requested resources, even if they
    -      // may have been sent previously (proactively). Thus, client does not need to cache
    -      // unrequested resources.
    -      if (!edsResourceSubscribers.containsKey(clusterName)) {
    -        continue;
    -      }
    -      unpackedResources.add(clusterName);
    -
    -      // Process ClusterLoadAssignment into EdsUpdate.
    -      EdsUpdate edsUpdate;
    -      try {
    -        edsUpdate = processClusterLoadAssignment(assignment);
    -      } catch (ResourceInvalidException e) {
    -        errors.add("EDS response ClusterLoadAssignment '" + clusterName
    -            + "' validation error: " + e.getMessage());
    -        invalidResources.add(clusterName);
    -        continue;
    -      }
    -      parsedResources.put(clusterName, new ParsedResource(edsUpdate, resource));
    -    }
    -    logger.log(
    -        XdsLogLevel.INFO, "Received EDS Response version {0} nonce {1}. Parsed resources: {2}",
    -        versionInfo, nonce, unpackedResources);
    -    handleResourceUpdate(
    -        serverInfo, ResourceType.EDS, parsedResources, invalidResources,
    -        Collections.emptySet(), versionInfo, nonce, errors);
    -  }
    -
    -  private static EdsUpdate processClusterLoadAssignment(ClusterLoadAssignment assignment)
    -      throws ResourceInvalidException {
    -    Set priorities = new HashSet<>();
    -    Map localityLbEndpointsMap = new LinkedHashMap<>();
    -    List dropOverloads = new ArrayList<>();
    -    int maxPriority = -1;
    -    for (io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints localityLbEndpointsProto
    -        : assignment.getEndpointsList()) {
    -      StructOrError structOrError =
    -          parseLocalityLbEndpoints(localityLbEndpointsProto);
    -      if (structOrError == null) {
    -        continue;
    -      }
    -      if (structOrError.getErrorDetail() != null) {
    -        throw new ResourceInvalidException(structOrError.getErrorDetail());
    -      }
    -
    -      LocalityLbEndpoints localityLbEndpoints = structOrError.getStruct();
    -      maxPriority = Math.max(maxPriority, localityLbEndpoints.priority());
    -      priorities.add(localityLbEndpoints.priority());
    -      // Note endpoints with health status other than HEALTHY and UNKNOWN are still
    -      // handed over to watching parties. It is watching parties' responsibility to
    -      // filter out unhealthy endpoints. See EnvoyProtoData.LbEndpoint#isHealthy().
    -      localityLbEndpointsMap.put(
    -          parseLocality(localityLbEndpointsProto.getLocality()),
    -          localityLbEndpoints);
    -    }
    -    if (priorities.size() != maxPriority + 1) {
    -      throw new ResourceInvalidException("ClusterLoadAssignment has sparse priorities");
    -    }
    -
    -    for (ClusterLoadAssignment.Policy.DropOverload dropOverloadProto
    -        : assignment.getPolicy().getDropOverloadsList()) {
    -      dropOverloads.add(parseDropOverload(dropOverloadProto));
    -    }
    -    return new EdsUpdate(assignment.getClusterName(), localityLbEndpointsMap, dropOverloads);
    -  }
    -
    -  private static Locality parseLocality(io.envoyproxy.envoy.config.core.v3.Locality proto) {
    -    return Locality.create(proto.getRegion(), proto.getZone(), proto.getSubZone());
    -  }
    -
    -  private static DropOverload parseDropOverload(
    -      io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment.Policy.DropOverload proto) {
    -    return DropOverload.create(proto.getCategory(), getRatePerMillion(proto.getDropPercentage()));
    -  }
    -
    -  @VisibleForTesting
    -  @Nullable
    -  static StructOrError parseLocalityLbEndpoints(
    -      io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) {
    -    // Filter out localities without or with 0 weight.
    -    if (!proto.hasLoadBalancingWeight() || proto.getLoadBalancingWeight().getValue() < 1) {
    -      return null;
    -    }
    -    if (proto.getPriority() < 0) {
    -      return StructOrError.fromError("negative priority");
    -    }
    -    List endpoints = new ArrayList<>(proto.getLbEndpointsCount());
    -    for (io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint endpoint : proto.getLbEndpointsList()) {
    -      // The endpoint field of each lb_endpoints must be set.
    -      // Inside of it: the address field must be set.
    -      if (!endpoint.hasEndpoint() || !endpoint.getEndpoint().hasAddress()) {
    -        return StructOrError.fromError("LbEndpoint with no endpoint/address");
    -      }
    -      io.envoyproxy.envoy.config.core.v3.SocketAddress socketAddress =
    -          endpoint.getEndpoint().getAddress().getSocketAddress();
    -      InetSocketAddress addr =
    -          new InetSocketAddress(socketAddress.getAddress(), socketAddress.getPortValue());
    -      boolean isHealthy =
    -          endpoint.getHealthStatus() == io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY
    -              || endpoint.getHealthStatus()
    -              == io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN;
    -      endpoints.add(LbEndpoint.create(
    -          new EquivalentAddressGroup(ImmutableList.of(addr)),
    -          endpoint.getLoadBalancingWeight().getValue(), isHealthy));
    -    }
    -    return StructOrError.fromStruct(LocalityLbEndpoints.create(
    -        endpoints, proto.getLoadBalancingWeight().getValue(), proto.getPriority()));
    -  }
    -
    -  /**
    -   * Helper method to unpack serialized {@link com.google.protobuf.Any} message, while replacing
    -   * Type URL {@code compatibleTypeUrl} with {@code typeUrl}.
    -   *
    -   * @param  The type of unpacked message
    -   * @param any serialized message to unpack
    -   * @param clazz the class to unpack the message to
    -   * @param typeUrl type URL to replace message Type URL, when it's compatible
    -   * @param compatibleTypeUrl compatible Type URL to be replaced with {@code typeUrl}
    -   * @return Unpacked message
    -   * @throws InvalidProtocolBufferException if the message couldn't be unpacked
    -   */
    -  private static  T unpackCompatibleType(
    -      Any any, Class clazz, String typeUrl, String compatibleTypeUrl)
    -      throws InvalidProtocolBufferException {
    -    if (any.getTypeUrl().equals(compatibleTypeUrl)) {
    -      any = any.toBuilder().setTypeUrl(typeUrl).build();
    -    }
    -    return any.unpack(clazz);
    -  }
    -
    -  private static int getRatePerMillion(FractionalPercent percent) {
    -    int numerator = percent.getNumerator();
    -    DenominatorType type = percent.getDenominator();
    -    switch (type) {
    -      case TEN_THOUSAND:
    -        numerator *= 100;
    -        break;
    -      case HUNDRED:
    -        numerator *= 10_000;
    -        break;
    -      case MILLION:
    -        break;
    -      case UNRECOGNIZED:
    -      default:
    -        throw new IllegalArgumentException("Unknown denominator type of " + percent);
    -    }
    -
    -    if (numerator > 1_000_000 || numerator < 0) {
    -      numerator = 1_000_000;
    -    }
    -    return numerator;
    -  }
    -
    -  @Override
    -  public void handleStreamClosed(Status error) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    cleanUpResourceTimers();
    -    for (ResourceSubscriber subscriber : ldsResourceSubscribers.values()) {
    -      subscriber.onError(error);
    -    }
    -    for (ResourceSubscriber subscriber : rdsResourceSubscribers.values()) {
    -      subscriber.onError(error);
    -    }
    -    for (ResourceSubscriber subscriber : cdsResourceSubscribers.values()) {
    -      subscriber.onError(error);
    -    }
    -    for (ResourceSubscriber subscriber : edsResourceSubscribers.values()) {
    -      subscriber.onError(error);
    -    }
    -  }
    -
    -  @Override
    -  public void handleStreamRestarted(ServerInfo serverInfo) {
    -    syncContext.throwIfNotInThisSynchronizationContext();
    -    for (ResourceSubscriber subscriber : ldsResourceSubscribers.values()) {
    -      if (subscriber.serverInfo.equals(serverInfo)) {
    -        subscriber.restartTimer();
    -      }
    -    }
    -    for (ResourceSubscriber subscriber : rdsResourceSubscribers.values()) {
    -      if (subscriber.serverInfo.equals(serverInfo)) {
    -        subscriber.restartTimer();
    -      }
    -    }
    -    for (ResourceSubscriber subscriber : cdsResourceSubscribers.values()) {
    -      if (subscriber.serverInfo.equals(serverInfo)) {
    -        subscriber.restartTimer();
    -      }
    -    }
    -    for (ResourceSubscriber subscriber : edsResourceSubscribers.values()) {
    -      if (subscriber.serverInfo.equals(serverInfo)) {
    -        subscriber.restartTimer();
    -      }
    -    }
    -  }
    -
    -  @Override
    -  void shutdown() {
    -    syncContext.execute(
    -        new Runnable() {
    -          @Override
    -          public void run() {
    -            if (isShutdown) {
    -              return;
    -            }
    -            isShutdown = true;
    -            for (AbstractXdsClient xdsChannel : serverChannelMap.values()) {
    -              xdsChannel.shutdown();
    -            }
    -            if (reportingLoad) {
    -              for (final LoadReportClient lrsClient : serverLrsClientMap.values()) {
    -                lrsClient.stopLoadReporting();
    -              }
    -            }
    -            cleanUpResourceTimers();
    -          }
    -        });
    -  }
    -
    -  @Override
    -  boolean isShutDown() {
    -    return isShutdown;
    -  }
    -
    -  private Map getSubscribedResourcesMap(ResourceType type) {
    -    switch (type) {
    -      case LDS:
    -        return ldsResourceSubscribers;
    -      case RDS:
    -        return rdsResourceSubscribers;
    -      case CDS:
    -        return cdsResourceSubscribers;
    -      case EDS:
    -        return edsResourceSubscribers;
    -      case UNKNOWN:
    -      default:
    -        throw new AssertionError("Unknown resource type");
    -    }
    -  }
    -
    -  @Nullable
    -  @Override
    -  public Collection getSubscribedResources(ServerInfo serverInfo, ResourceType type) {
    -    Map resources = getSubscribedResourcesMap(type);
    -    ImmutableSet.Builder builder = ImmutableSet.builder();
    -    for (String key : resources.keySet()) {
    -      if (resources.get(key).serverInfo.equals(serverInfo)) {
    -        builder.add(key);
    -      }
    -    }
    -    Collection retVal = builder.build();
    -    return retVal.isEmpty() ? null : retVal;
    -  }
    -
    -  @Override
    -  ListenableFuture>>
    -      getSubscribedResourcesMetadataSnapshot() {
    -    final SettableFuture>> future =
    -        SettableFuture.create();
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        // A map from a "resource type" to a map ("resource name": "resource metadata")
    -        ImmutableMap.Builder> metadataSnapshot =
    -            ImmutableMap.builder();
    -        for (ResourceType type : ResourceType.values()) {
    -          if (type == ResourceType.UNKNOWN) {
    -            continue;
    -          }
    -          ImmutableMap.Builder metadataMap = ImmutableMap.builder();
    -          for (Map.Entry resourceEntry
    -              : getSubscribedResourcesMap(type).entrySet()) {
    -            metadataMap.put(resourceEntry.getKey(), resourceEntry.getValue().metadata);
    -          }
    -          metadataSnapshot.put(type, metadataMap.build());
    -        }
    -        future.set(metadataSnapshot.build());
    -      }
    -    });
    -    return future;
    -  }
    -
    -  @Override
    -  TlsContextManager getTlsContextManager() {
    -    return tlsContextManager;
    -  }
    -
    -  @Override
    -  void watchLdsResource(final String resourceName, final LdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = ldsResourceSubscribers.get(resourceName);
    -        if (subscriber == null) {
    -          logger.log(XdsLogLevel.INFO, "Subscribe LDS resource {0}", resourceName);
    -          subscriber = new ResourceSubscriber(ResourceType.LDS, resourceName);
    -          ldsResourceSubscribers.put(resourceName, subscriber);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.LDS);
    -        }
    -        subscriber.addWatcher(watcher);
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void cancelLdsResourceWatch(final String resourceName, final LdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = ldsResourceSubscribers.get(resourceName);
    -        subscriber.removeWatcher(watcher);
    -        if (!subscriber.isWatched()) {
    -          subscriber.stopTimer();
    -          logger.log(XdsLogLevel.INFO, "Unsubscribe LDS resource {0}", resourceName);
    -          ldsResourceSubscribers.remove(resourceName);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.LDS);
    -        }
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void watchRdsResource(final String resourceName, final RdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = rdsResourceSubscribers.get(resourceName);
    -        if (subscriber == null) {
    -          logger.log(XdsLogLevel.INFO, "Subscribe RDS resource {0}", resourceName);
    -          subscriber = new ResourceSubscriber(ResourceType.RDS, resourceName);
    -          rdsResourceSubscribers.put(resourceName, subscriber);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.RDS);
    -        }
    -        subscriber.addWatcher(watcher);
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void cancelRdsResourceWatch(final String resourceName, final RdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = rdsResourceSubscribers.get(resourceName);
    -        subscriber.removeWatcher(watcher);
    -        if (!subscriber.isWatched()) {
    -          subscriber.stopTimer();
    -          logger.log(XdsLogLevel.INFO, "Unsubscribe RDS resource {0}", resourceName);
    -          rdsResourceSubscribers.remove(resourceName);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.RDS);
    -        }
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void watchCdsResource(final String resourceName, final CdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = cdsResourceSubscribers.get(resourceName);
    -        if (subscriber == null) {
    -          logger.log(XdsLogLevel.INFO, "Subscribe CDS resource {0}", resourceName);
    -          subscriber = new ResourceSubscriber(ResourceType.CDS, resourceName);
    -          cdsResourceSubscribers.put(resourceName, subscriber);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.CDS);
    -        }
    -        subscriber.addWatcher(watcher);
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void cancelCdsResourceWatch(final String resourceName, final CdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = cdsResourceSubscribers.get(resourceName);
    -        subscriber.removeWatcher(watcher);
    -        if (!subscriber.isWatched()) {
    -          subscriber.stopTimer();
    -          logger.log(XdsLogLevel.INFO, "Unsubscribe CDS resource {0}", resourceName);
    -          cdsResourceSubscribers.remove(resourceName);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.CDS);
    -        }
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void watchEdsResource(final String resourceName, final EdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = edsResourceSubscribers.get(resourceName);
    -        if (subscriber == null) {
    -          logger.log(XdsLogLevel.INFO, "Subscribe EDS resource {0}", resourceName);
    -          subscriber = new ResourceSubscriber(ResourceType.EDS, resourceName);
    -          edsResourceSubscribers.put(resourceName, subscriber);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.EDS);
    -        }
    -        subscriber.addWatcher(watcher);
    -      }
    -    });
    -  }
    -
    -  @Override
    -  void cancelEdsResourceWatch(final String resourceName, final EdsResourceWatcher watcher) {
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        ResourceSubscriber subscriber = edsResourceSubscribers.get(resourceName);
    -        subscriber.removeWatcher(watcher);
    -        if (!subscriber.isWatched()) {
    -          subscriber.stopTimer();
    -          logger.log(XdsLogLevel.INFO, "Unsubscribe EDS resource {0}", resourceName);
    -          edsResourceSubscribers.remove(resourceName);
    -          subscriber.xdsChannel.adjustResourceSubscription(ResourceType.EDS);
    -        }
    -      }
    -    });
    -  }
    -
    -  @Override
    -  ClusterDropStats addClusterDropStats(
    -      final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName) {
    -    ClusterDropStats dropCounter =
    -        loadStatsManager.getClusterDropStats(clusterName, edsServiceName);
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        if (!reportingLoad) {
    -          serverLrsClientMap.get(serverInfo).startLoadReporting();
    -          reportingLoad = true;
    -        }
    -      }
    -    });
    -    return dropCounter;
    -  }
    -
    -  @Override
    -  ClusterLocalityStats addClusterLocalityStats(
    -      final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName,
    -      Locality locality) {
    -    ClusterLocalityStats loadCounter =
    -        loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality);
    -    syncContext.execute(new Runnable() {
    -      @Override
    -      public void run() {
    -        if (!reportingLoad) {
    -          serverLrsClientMap.get(serverInfo).startLoadReporting();
    -          reportingLoad = true;
    -        }
    -      }
    -    });
    -    return loadCounter;
    -  }
    -
    -  @Override
    -  Bootstrapper.BootstrapInfo getBootstrapInfo() {
    -    return bootstrapInfo;
    -  }
    -  
    -  @Override
    -  public String toString() {
    -    return logId.toString();
    -  }
    -
    -  private void cleanUpResourceTimers() {
    -    for (ResourceSubscriber subscriber : ldsResourceSubscribers.values()) {
    -      subscriber.stopTimer();
    -    }
    -    for (ResourceSubscriber subscriber : rdsResourceSubscribers.values()) {
    -      subscriber.stopTimer();
    -    }
    -    for (ResourceSubscriber subscriber : cdsResourceSubscribers.values()) {
    -      subscriber.stopTimer();
    -    }
    -    for (ResourceSubscriber subscriber : edsResourceSubscribers.values()) {
    -      subscriber.stopTimer();
    -    }
    -  }
    -
    -  private void handleResourceUpdate(
    -      ServerInfo serverInfo, ResourceType type, Map parsedResources,
    -      Set invalidResources, Set retainedResources, String version, String nonce,
    -      List errors) {
    -    String errorDetail = null;
    -    if (errors.isEmpty()) {
    -      checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors");
    -      serverChannelMap.get(serverInfo).ackResponse(type, version, nonce);
    -    } else {
    -      errorDetail = Joiner.on('\n').join(errors);
    -      logger.log(XdsLogLevel.WARNING,
    -          "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}",
    -          type, version, nonce, errorDetail);
    -      serverChannelMap.get(serverInfo).nackResponse(type, nonce, errorDetail);
    -    }
    -    long updateTime = timeProvider.currentTimeNanos();
    -    for (Map.Entry entry : getSubscribedResourcesMap(type).entrySet()) {
    -      String resourceName = entry.getKey();
    -      ResourceSubscriber subscriber = entry.getValue();
    -      // Attach error details to the subscribed resources that included in the ADS update.
    -      if (invalidResources.contains(resourceName)) {
    -        subscriber.onRejected(version, updateTime, errorDetail);
    -      }
    -      // Notify the watchers.
    -      if (parsedResources.containsKey(resourceName)) {
    -        subscriber.onData(parsedResources.get(resourceName), version, updateTime);
    -      } else if (type == ResourceType.LDS || type == ResourceType.CDS) {
    -        if (subscriber.data != null && invalidResources.contains(resourceName)) {
    -          // Update is rejected but keep using the cached data.
    -          if (type == ResourceType.LDS) {
    -            LdsUpdate ldsUpdate = (LdsUpdate) subscriber.data;
    -            io.grpc.xds.HttpConnectionManager hcm = ldsUpdate.httpConnectionManager();
    -            if (hcm != null) {
    -              String rdsName = hcm.rdsName();
    -              if (rdsName != null) {
    -                retainedResources.add(rdsName);
    -              }
    -            }
    -          } else {
    -            CdsUpdate cdsUpdate = (CdsUpdate) subscriber.data;
    -            String edsName = cdsUpdate.edsServiceName();
    -            if (edsName == null) {
    -              edsName = cdsUpdate.clusterName();
    -            }
    -            retainedResources.add(edsName);
    -          }
    -        } else if (invalidResources.contains(resourceName)) {
    -          subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail));
    -        } else {
    -          // For State of the World services, notify watchers when their watched resource is missing
    -          // from the ADS update.
    -          subscriber.onAbsent();
    -        }
    -      }
    -    }
    -    // LDS/CDS responses represents the state of the world, RDS/EDS resources not referenced in
    -    // LDS/CDS resources should be deleted.
    -    if (type == ResourceType.LDS || type == ResourceType.CDS) {
    -      Map dependentSubscribers =
    -          type == ResourceType.LDS ? rdsResourceSubscribers : edsResourceSubscribers;
    -      for (String resource : dependentSubscribers.keySet()) {
    -        if (!retainedResources.contains(resource)) {
    -          dependentSubscribers.get(resource).onAbsent();
    -        }
    -      }
    -    }
    -  }
    -
    -  private static final class ParsedResource {
    -    private final ResourceUpdate resourceUpdate;
    -    private final Any rawResource;
    -
    -    private ParsedResource(ResourceUpdate resourceUpdate, Any rawResource) {
    -      this.resourceUpdate = checkNotNull(resourceUpdate, "resourceUpdate");
    -      this.rawResource = checkNotNull(rawResource, "rawResource");
    -    }
    -
    -    private ResourceUpdate getResourceUpdate() {
    -      return resourceUpdate;
    -    }
    -
    -    private Any getRawResource() {
    -      return rawResource;
    -    }
    -  }
    -
    -  /**
    -   * Tracks a single subscribed resource.
    -   */
    -  private final class ResourceSubscriber {
    -    private final ServerInfo serverInfo;
    -    private final AbstractXdsClient xdsChannel;
    -    private final ResourceType type;
    -    private final String resource;
    -    private final Set watchers = new HashSet<>();
    -    private ResourceUpdate data;
    -    private boolean absent;
    -    private ScheduledHandle respTimer;
    -    private ResourceMetadata metadata;
    -
    -    ResourceSubscriber(ResourceType type, String resource) {
    -      syncContext.throwIfNotInThisSynchronizationContext();
    -      this.type = type;
    -      this.resource = resource;
    -      this.serverInfo = getServerInfo(resource);
    -      // Initialize metadata in UNKNOWN state to cover the case when resource subscriber,
    -      // is created but not yet requested because the client is in backoff.
    -      this.metadata = ResourceMetadata.newResourceMetadataUnknown();
    -      maybeCreateXdsChannelWithLrs(serverInfo);
    -      this.xdsChannel = serverChannelMap.get(serverInfo);
    -      if (xdsChannel.isInBackoff()) {
    -        return;
    -      }
    -      restartTimer();
    -    }
    -
    -    private ServerInfo getServerInfo(String resource) {
    -      if (resource.startsWith(XDSTP_SCHEME)) {
    -        URI uri = URI.create(resource);
    -        String authority = uri.getAuthority();
    -        if (authority == null) {
    -          authority = "";
    -        }
    -        AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(authority);
    -        return authorityInfo.xdsServers().get(0);
    -      }
    -      return bootstrapInfo.servers().get(0); // use first server
    -    }
    -
    -    void addWatcher(ResourceWatcher watcher) {
    -      checkArgument(!watchers.contains(watcher), "watcher %s already registered", watcher);
    -      watchers.add(watcher);
    -      if (data != null) {
    -        notifyWatcher(watcher, data);
    -      } else if (absent) {
    -        watcher.onResourceDoesNotExist(resource);
    -      }
    -    }
    -
    -    void removeWatcher(ResourceWatcher watcher) {
    -      checkArgument(watchers.contains(watcher), "watcher %s not registered", watcher);
    -      watchers.remove(watcher);
    -    }
    -
    -    void restartTimer() {
    -      if (data != null || absent) {  // resource already resolved
    -        return;
    -      }
    -      class ResourceNotFound implements Runnable {
    -        @Override
    -        public void run() {
    -          logger.log(XdsLogLevel.INFO, "{0} resource {1} initial fetch timeout",
    -              type, resource);
    -          respTimer = null;
    -          onAbsent();
    -        }
    -
    -        @Override
    -        public String toString() {
    -          return type + this.getClass().getSimpleName();
    -        }
    -      }
    -
    -      // Initial fetch scheduled or rescheduled, transition metadata state to REQUESTED.
    -      metadata = ResourceMetadata.newResourceMetadataRequested();
    -      respTimer = syncContext.schedule(
    -          new ResourceNotFound(), INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS,
    -          timeService);
    -    }
    -
    -    void stopTimer() {
    -      if (respTimer != null && respTimer.isPending()) {
    -        respTimer.cancel();
    -        respTimer = null;
    -      }
    -    }
    -
    -    boolean isWatched() {
    -      return !watchers.isEmpty();
    -    }
    -
    -    void onData(ParsedResource parsedResource, String version, long updateTime) {
    -      if (respTimer != null && respTimer.isPending()) {
    -        respTimer.cancel();
    -        respTimer = null;
    -      }
    -      this.metadata = ResourceMetadata
    -          .newResourceMetadataAcked(parsedResource.getRawResource(), version, updateTime);
    -      ResourceUpdate oldData = this.data;
    -      this.data = parsedResource.getResourceUpdate();
    -      absent = false;
    -      if (!Objects.equals(oldData, data)) {
    -        for (ResourceWatcher watcher : watchers) {
    -          notifyWatcher(watcher, data);
    -        }
    -      }
    -    }
    -
    -    void onAbsent() {
    -      if (respTimer != null && respTimer.isPending()) {  // too early to conclude absence
    -        return;
    -      }
    -      logger.log(XdsLogLevel.INFO, "Conclude {0} resource {1} not exist", type, resource);
    -      if (!absent) {
    -        data = null;
    -        absent = true;
    -        metadata = ResourceMetadata.newResourceMetadataDoesNotExist();
    -        for (ResourceWatcher watcher : watchers) {
    -          watcher.onResourceDoesNotExist(resource);
    -        }
    -      }
    -    }
    -
    -    void onError(Status error) {
    -      if (respTimer != null && respTimer.isPending()) {
    -        respTimer.cancel();
    -        respTimer = null;
    -      }
    -      for (ResourceWatcher watcher : watchers) {
    -        watcher.onError(error);
    -      }
    -    }
    -
    -    void onRejected(String rejectedVersion, long rejectedTime, String rejectedDetails) {
    -      metadata = ResourceMetadata
    -          .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails);
    -    }
    -
    -    private void notifyWatcher(ResourceWatcher watcher, ResourceUpdate update) {
    -      switch (type) {
    -        case LDS:
    -          ((LdsResourceWatcher) watcher).onChanged((LdsUpdate) update);
    -          break;
    -        case RDS:
    -          ((RdsResourceWatcher) watcher).onChanged((RdsUpdate) update);
    -          break;
    -        case CDS:
    -          ((CdsResourceWatcher) watcher).onChanged((CdsUpdate) update);
    -          break;
    -        case EDS:
    -          ((EdsResourceWatcher) watcher).onChanged((EdsUpdate) update);
    -          break;
    -        case UNKNOWN:
    -        default:
    -          throw new AssertionError("should never be here");
    -      }
    -    }
    -  }
    -
    -  @VisibleForTesting
    -  static final class ResourceInvalidException extends Exception {
    -    private static final long serialVersionUID = 0L;
    -
    -    private ResourceInvalidException(String message) {
    -      super(message, null, false, false);
    -    }
    -
    -    private ResourceInvalidException(String message, Throwable cause) {
    -      super(cause != null ? message + ": " + cause.getMessage() : message, cause, false, false);
    -    }
    -  }
    -
    -  @VisibleForTesting
    -  static final class StructOrError {
    -
    -    /**
    -     * Returns a {@link StructOrError} for the successfully converted data object.
    -     */
    -    private static  StructOrError fromStruct(T struct) {
    -      return new StructOrError<>(struct);
    -    }
    -
    -    /**
    -     * Returns a {@link StructOrError} for the failure to convert the data object.
    -     */
    -    private static  StructOrError fromError(String errorDetail) {
    -      return new StructOrError<>(errorDetail);
    -    }
    -
    -    private final String errorDetail;
    -    private final T struct;
    -
    -    private StructOrError(T struct) {
    -      this.struct = checkNotNull(struct, "struct");
    -      this.errorDetail = null;
    -    }
    -
    -    private StructOrError(String errorDetail) {
    -      this.struct = null;
    -      this.errorDetail = checkNotNull(errorDetail, "errorDetail");
    -    }
    -
    -    /**
    -     * Returns struct if exists, otherwise null.
    -     */
    -    @VisibleForTesting
    -    @Nullable
    -    T getStruct() {
    -      return struct;
    -    }
    -
    -    /**
    -     * Returns error detail if exists, otherwise null.
    -     */
    -    @VisibleForTesting
    -    @Nullable
    -    String getErrorDetail() {
    -      return errorDetail;
    -    }
    -  }
    -
    -  abstract static class XdsChannelFactory {
    -    static final XdsChannelFactory DEFAULT_XDS_CHANNEL_FACTORY = new XdsChannelFactory() {
    -      @Override
    -      ManagedChannel create(ServerInfo serverInfo) {
    -        String target = serverInfo.target();
    -        ChannelCredentials channelCredentials = serverInfo.channelCredentials();
    -        return Grpc.newChannelBuilder(target, channelCredentials)
    -            .keepAliveTime(5, TimeUnit.MINUTES)
    -            .build();
    -      }
    -    };
    -
    -    abstract ManagedChannel create(ServerInfo serverInfo);
    -  }
    -}
    diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
    index e9796267a85..b225b01af7a 100644
    --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
    +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
    @@ -45,7 +45,7 @@
     import io.grpc.xds.XdsLogger.XdsLogLevel;
     import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider;
     import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
    -import io.grpc.xds.internal.sds.SslContextProviderSupplier;
    +import io.grpc.xds.internal.security.SslContextProviderSupplier;
     import java.util.ArrayList;
     import java.util.Collections;
     import java.util.List;
    @@ -102,7 +102,7 @@ final class ClusterImplLoadBalancer extends LoadBalancer {
       }
     
       @Override
    -  public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +  public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
         logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
         Attributes attributes = resolvedAddresses.getAttributes();
         if (xdsClientPool == null) {
    @@ -134,6 +134,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
                 .setAttributes(attributes)
                 .setLoadBalancingPolicyConfig(config.childPolicy.getConfig())
                 .build());
    +    return true;
       }
     
       @Override
    @@ -162,11 +163,6 @@ public void shutdown() {
         }
       }
     
    -  @Override
    -  public boolean canHandleEmptyAddressListFromNameResolution() {
    -    return true;
    -  }
    -
       /**
        * A decorated {@link LoadBalancer.Helper} that applies configurations for connections
        * or requests to endpoints in the cluster.
    @@ -198,17 +194,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne
     
         @Override
         public Subchannel createSubchannel(CreateSubchannelArgs args) {
    -      List addresses = new ArrayList<>();
    -      for (EquivalentAddressGroup eag : args.getAddresses()) {
    -        Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set(
    -            InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster);
    -        if (enableSecurity && sslContextProviderSupplier != null) {
    -          attrBuilder.set(
    -              InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
    -              sslContextProviderSupplier);
    -        }
    -        addresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build()));
    -      }
    +      List addresses = withAdditionalAttributes(args.getAddresses());
           Locality locality = args.getAddresses().get(0).getAttributes().get(
               InternalXdsAttributes.ATTR_LOCALITY);  // all addresses should be in the same locality
           // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain
    @@ -233,6 +219,11 @@ public void shutdown() {
               delegate().shutdown();
             }
     
    +        @Override
    +        public void updateAddresses(List addresses) {
    +          delegate().updateAddresses(withAdditionalAttributes(addresses));
    +        }
    +
             @Override
             protected Subchannel delegate() {
               return subchannel;
    @@ -240,6 +231,22 @@ protected Subchannel delegate() {
           };
         }
     
    +    private List withAdditionalAttributes(
    +        List addresses) {
    +      List newAddresses = new ArrayList<>();
    +      for (EquivalentAddressGroup eag : addresses) {
    +        Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set(
    +            InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster);
    +        if (enableSecurity && sslContextProviderSupplier != null) {
    +          attrBuilder.set(
    +              InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
    +              sslContextProviderSupplier);
    +        }
    +        newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build()));
    +      }
    +      return newAddresses;
    +    }
    +
         @Override
         protected Helper delegate()  {
           return helper;
    diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
    index 0557f3a6a8c..cce32c68246 100644
    --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
    +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
    @@ -57,6 +57,8 @@ class ClusterManagerLoadBalancer extends LoadBalancer {
       private final SynchronizationContext syncContext;
       private final ScheduledExecutorService timeService;
       private final XdsLogger logger;
    +  // Set to true if currently in the process of handling resolved addresses.
    +  private boolean resolvingAddresses;
     
       ClusterManagerLoadBalancer(Helper helper) {
         this.helper = checkNotNull(helper, "helper");
    @@ -68,7 +70,16 @@ class ClusterManagerLoadBalancer extends LoadBalancer {
       }
     
       @Override
    -  public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +  public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +    try {
    +      resolvingAddresses = true;
    +      return acceptResolvedAddressesInternal(resolvedAddresses);
    +    } finally {
    +      resolvingAddresses = false;
    +    }
    +  }
    +
    +  public boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
         logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
         ClusterManagerConfig config = (ClusterManagerConfig)
             resolvedAddresses.getLoadBalancingPolicyConfig();
    @@ -98,6 +109,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
         // Must update channel picker before return so that new RPCs will not be routed to deleted
         // clusters and resolver can remove them in service config.
         updateOverallBalancingState();
    +    return true;
       }
     
       @Override
    @@ -115,11 +127,6 @@ public void handleNameResolutionError(Status error) {
         }
       }
     
    -  @Override
    -  public boolean canHandleEmptyAddressListFromNameResolution() {
    -    return true;
    -  }
    -
       @Override
       public void shutdown() {
         logger.log(XdsLogLevel.INFO, "Shutdown");
    @@ -251,21 +258,18 @@ private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper {
           @Override
           public void updateBalancingState(final ConnectivityState newState,
               final SubchannelPicker newPicker) {
    -        syncContext.execute(new Runnable() {
    -          @Override
    -          public void run() {
    -            if (!childLbStates.containsKey(name)) {
    -              return;
    -            }
    -            // Subchannel picker and state are saved, but will only be propagated to the channel
    -            // when the child instance exits deactivated state.
    -            currentState = newState;
    -            currentPicker = newPicker;
    -            if (!deactivated) {
    -              updateOverallBalancingState();
    -            }
    -          }
    -        });
    +        // If we are already in the process of resolving addresses, the overall balancing state
    +        // will be updated at the end of it, and we don't need to trigger that update here.
    +        if (!childLbStates.containsKey(name)) {
    +          return;
    +        }
    +        // Subchannel picker and state are saved, but will only be propagated to the channel
    +        // when the child instance exits deactivated state.
    +        currentState = newState;
    +        currentPicker = newPicker;
    +        if (!deactivated && !resolvingAddresses) {
    +          updateOverallBalancingState();
    +        }
           }
     
           @Override
    diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java
    index e219c0467ac..9c97d3fe966 100644
    --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java
    +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java
    @@ -110,7 +110,7 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) {
           }
         } catch (RuntimeException e) {
           return ConfigOrError.fromError(
    -          Status.fromThrowable(e).withDescription(
    +          Status.INTERNAL.withCause(e).withDescription(
                   "Failed to parse cluster_manager LB config: " + rawConfig));
         }
         return ConfigOrError.fromConfig(new ClusterManagerConfig(parsedChildPolicies));
    diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
    index 309daf55a18..3af58ef93cb 100644
    --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
    +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
    @@ -19,7 +19,6 @@
     import static com.google.common.base.Preconditions.checkNotNull;
     import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
     import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME;
    -import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME;
     
     import com.google.common.annotations.VisibleForTesting;
     import io.grpc.Attributes;
    @@ -39,6 +38,7 @@
     import io.grpc.internal.ServiceConfigUtil.PolicySelection;
     import io.grpc.util.ForwardingLoadBalancerHelper;
     import io.grpc.util.GracefulSwitchLoadBalancer;
    +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig;
     import io.grpc.xds.Bootstrapper.ServerInfo;
     import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig;
     import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig;
    @@ -46,13 +46,14 @@
     import io.grpc.xds.Endpoints.DropOverload;
     import io.grpc.xds.Endpoints.LbEndpoint;
     import io.grpc.xds.Endpoints.LocalityLbEndpoints;
    +import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection;
    +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection;
    +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection;
     import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
     import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig;
     import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig;
    -import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection;
    -import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig;
    -import io.grpc.xds.XdsClient.EdsResourceWatcher;
    -import io.grpc.xds.XdsClient.EdsUpdate;
    +import io.grpc.xds.XdsClient.ResourceWatcher;
    +import io.grpc.xds.XdsEndpointResource.EdsUpdate;
     import io.grpc.xds.XdsLogger.XdsLogLevel;
     import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
     import java.net.URI;
    @@ -61,9 +62,13 @@
     import java.util.Arrays;
     import java.util.Collections;
     import java.util.HashMap;
    +import java.util.HashSet;
     import java.util.List;
    +import java.util.Locale;
     import java.util.Map;
     import java.util.Objects;
    +import java.util.Set;
    +import java.util.TreeMap;
     import java.util.concurrent.ScheduledExecutorService;
     import java.util.concurrent.TimeUnit;
     import javax.annotation.Nullable;
    @@ -108,7 +113,7 @@ final class ClusterResolverLoadBalancer extends LoadBalancer {
       }
     
       @Override
    -  public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +  public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
         logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
         if (xdsClientPool == null) {
           xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL);
    @@ -122,6 +127,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
           this.config = config;
           delegate.handleResolvedAddresses(resolvedAddresses);
         }
    +    return true;
       }
     
       @Override
    @@ -165,7 +171,7 @@ private final class ClusterResolverLbState extends LoadBalancer {
         }
     
         @Override
    -    public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
    +    public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
           this.resolvedAddresses = resolvedAddresses;
           ClusterResolverConfig config =
               (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
    @@ -175,7 +181,8 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
             ClusterState state;
             if (instance.type == DiscoveryMechanism.Type.EDS) {
               state = new EdsClusterState(instance.cluster, instance.edsServiceName,
    -              instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext);
    +              instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext,
    +              instance.outlierDetection);
             } else {  // logical DNS
               state = new LogicalDnsClusterState(instance.cluster, instance.dnsHostName,
                   instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext);
    @@ -183,6 +190,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
             clusterStates.put(instance.cluster, state);
             state.start();
           }
    +      return true;
         }
     
         @Override
    @@ -208,6 +216,7 @@ private void handleEndpointResourceUpdate() {
           List addresses = new ArrayList<>();
           Map priorityChildConfigs = new HashMap<>();
           List priorities = new ArrayList<>();  // totally ordered priority list
    +
           Status endpointNotFound = Status.OK;
           for (String cluster : clusters) {
             ClusterState state = clusterStates.get(cluster);
    @@ -278,10 +287,8 @@ private void handleEndpointResolutionError() {
         private final class RefreshableHelper extends ForwardingLoadBalancerHelper {
           private final Helper delegate;
     
    -      @SuppressWarnings("deprecation")
           private RefreshableHelper(Helper delegate) {
             this.delegate = checkNotNull(delegate, "delegate");
    -        delegate.ignoreRefreshNameResolutionCheck();
           }
     
           @Override
    @@ -311,6 +318,8 @@ private abstract class ClusterState {
           protected final Long maxConcurrentRequests;
           @Nullable
           protected final UpstreamTlsContext tlsContext;
    +      @Nullable
    +      protected final OutlierDetection outlierDetection;
           // Resolution status, may contain most recent error encountered.
           protected Status status = Status.OK;
           // True if has received resolution result.
    @@ -318,14 +327,17 @@ private abstract class ClusterState {
           // Most recently resolved addresses and config, or null if resource not exists.
           @Nullable
           protected ClusterResolutionResult result;
    +
           protected boolean shutdown;
     
           private ClusterState(String name, @Nullable ServerInfo lrsServerInfo,
    -          @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) {
    +          @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext,
    +          @Nullable OutlierDetection outlierDetection) {
             this.name = name;
             this.lrsServerInfo = lrsServerInfo;
             this.maxConcurrentRequests = maxConcurrentRequests;
             this.tlsContext = tlsContext;
    +        this.outlierDetection = outlierDetection;
           }
     
           abstract void start();
    @@ -335,14 +347,16 @@ void shutdown() {
           }
         }
     
    -    private final class EdsClusterState extends ClusterState implements EdsResourceWatcher {
    +    private final class EdsClusterState extends ClusterState implements ResourceWatcher {
           @Nullable
           private final String edsServiceName;
    +      private Map localityPriorityNames = Collections.emptyMap();
    +      int priorityNameGenId = 1;
     
           private EdsClusterState(String name, @Nullable String edsServiceName,
               @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests,
    -          @Nullable UpstreamTlsContext tlsContext) {
    -        super(name, lrsServerInfo, maxConcurrentRequests, tlsContext);
    +          @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) {
    +        super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, outlierDetection);
             this.edsServiceName = edsServiceName;
           }
     
    @@ -350,7 +364,7 @@ private EdsClusterState(String name, @Nullable String edsServiceName,
           void start() {
             String resourceName = edsServiceName != null ? edsServiceName : name;
             logger.log(XdsLogLevel.INFO, "Start watching EDS resource {0}", resourceName);
    -        xdsClient.watchEdsResource(resourceName, this);
    +        xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), resourceName, this);
           }
     
           @Override
    @@ -358,7 +372,7 @@ protected void shutdown() {
             super.shutdown();
             String resourceName = edsServiceName != null ? edsServiceName : name;
             logger.log(XdsLogLevel.INFO, "Stop watching EDS resource {0}", resourceName);
    -        xdsClient.cancelEdsResourceWatch(resourceName, this);
    +        xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), resourceName, this);
           }
     
           @Override
    @@ -380,10 +394,10 @@ public void run() {
                 List dropOverloads = update.dropPolicies;
                 List addresses = new ArrayList<>();
                 Map> prioritizedLocalityWeights = new HashMap<>();
    +            List sortedPriorityNames = generatePriorityNames(name, localityLbEndpoints);
                 for (Locality locality : localityLbEndpoints.keySet()) {
                   LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality);
    -              int priority = localityLbInfo.priority();
    -              String priorityName = priorityName(name, priority);
    +              String priorityName = localityPriorityNames.get(locality);
                   boolean discard = true;
                   for (LbEndpoint endpoint : localityLbInfo.endpoints()) {
                     if (endpoint.isHealthy()) {
    @@ -395,6 +409,8 @@ public void run() {
                       Attributes attr =
                           endpoint.eag().getAttributes().toBuilder()
                               .set(InternalXdsAttributes.ATTR_LOCALITY, locality)
    +                          .set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT,
    +                              localityLbInfo.localityWeight())
                               .set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight)
                               .build();
                       EquivalentAddressGroup eag = new EquivalentAddressGroup(
    @@ -420,15 +436,16 @@ public void run() {
                   logger.log(XdsLogLevel.INFO,
                       "Cluster {0} has no usable priority/locality/endpoint", update.clusterName);
                 }
    -            List priorities = new ArrayList<>(prioritizedLocalityWeights.keySet());
    -            Collections.sort(priorities);
    +            sortedPriorityNames.retainAll(prioritizedLocalityWeights.keySet());
                 Map priorityChildConfigs =
                     generateEdsBasedPriorityChildConfigs(
                         name, edsServiceName, lrsServerInfo, maxConcurrentRequests, tlsContext,
    -                    endpointLbPolicy, lbRegistry, prioritizedLocalityWeights, dropOverloads);
    +                    outlierDetection, endpointLbPolicy, lbRegistry, prioritizedLocalityWeights,
    +                    dropOverloads);
                 status = Status.OK;
                 resolved = true;
    -            result = new ClusterResolutionResult(addresses, priorityChildConfigs, priorities);
    +            result = new ClusterResolutionResult(addresses, priorityChildConfigs,
    +                sortedPriorityNames);
                 handleEndpointResourceUpdate();
               }
             }
    @@ -436,6 +453,40 @@ public void run() {
             syncContext.execute(new EndpointsUpdated());
           }
     
    +      private List generatePriorityNames(String name,
    +          Map localityLbEndpoints) {
    +        TreeMap> todo = new TreeMap<>();
    +        for (Locality locality : localityLbEndpoints.keySet()) {
    +          int priority = localityLbEndpoints.get(locality).priority();
    +          if (!todo.containsKey(priority)) {
    +            todo.put(priority, new ArrayList<>());
    +          }
    +          todo.get(priority).add(locality);
    +        }
    +        Map newNames = new HashMap<>();
    +        Set usedNames = new HashSet<>();
    +        List ret = new ArrayList<>();
    +        for (Integer priority: todo.keySet()) {
    +          String foundName = "";
    +          for (Locality locality : todo.get(priority)) {
    +            if (localityPriorityNames.containsKey(locality)
    +                && usedNames.add(localityPriorityNames.get(locality))) {
    +              foundName = localityPriorityNames.get(locality);
    +              break;
    +            }
    +          }
    +          if ("".equals(foundName)) {
    +            foundName = String.format(Locale.US, "%s[child%d]", name, priorityNameGenId++);
    +          }
    +          for (Locality locality : todo.get(priority)) {
    +            newNames.put(locality, foundName);
    +          }
    +          ret.add(foundName);
    +        }
    +        localityPriorityNames = newNames;
    +        return ret;
    +      }
    +
           @Override
           public void onResourceDoesNotExist(final String resourceName) {
             syncContext.execute(new Runnable() {
    @@ -461,7 +512,11 @@ public void run() {
                 if (shutdown) {
                   return;
                 }
    -            status = error;
    +            String resourceName = edsServiceName != null ? edsServiceName : name;
    +            status = Status.UNAVAILABLE
    +                .withDescription(String.format("Unable to load EDS %s. xDS server returned: %s: %s",
    +                      resourceName, error.getCode(), error.getDescription()))
    +                .withCause(error.getCause());
                 logger.log(XdsLogLevel.WARNING, "Received EDS error: {0}", error);
                 handleEndpointResolutionError();
               }
    @@ -482,7 +537,7 @@ private final class LogicalDnsClusterState extends ClusterState {
           private LogicalDnsClusterState(String name, String dnsHostName,
               @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests,
               @Nullable UpstreamTlsContext tlsContext) {
    -        super(name, lrsServerInfo, maxConcurrentRequests, tlsContext);
    +        super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, null);
             this.dnsHostName = checkNotNull(dnsHostName, "dnsHostName");
             nameResolverFactory =
                 checkNotNull(helper.getNameResolverRegistry().asFactory(), "nameResolverFactory");
    @@ -677,55 +732,115 @@ private static PriorityChildConfig generateDnsBasedPriorityChildConfig(
       private static Map generateEdsBasedPriorityChildConfigs(
           String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo,
           @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext,
    -      PolicySelection endpointLbPolicy, LoadBalancerRegistry lbRegistry,
    -      Map> prioritizedLocalityWeights,
    -      List dropOverloads) {
    +      @Nullable OutlierDetection outlierDetection, PolicySelection endpointLbPolicy,
    +      LoadBalancerRegistry lbRegistry, Map> prioritizedLocalityWeights, List dropOverloads) {
         Map configs = new HashMap<>();
         for (String priority : prioritizedLocalityWeights.keySet()) {
    -      PolicySelection leafPolicy =  endpointLbPolicy;
    -      // Depending on the endpoint-level load balancing policy, different LB hierarchy may be
    -      // created. If the endpoint-level LB policy is round_robin or least_request_experimental,
    -      // it creates a two-level LB hierarchy: a locality-level LB policy that balances load
    -      // according to locality weights followed by an endpoint-level LB policy that balances load
    -      // between endpoints within the locality. If the endpoint-level LB policy is
    -      // ring_hash_experimental, it creates a unified LB policy that balances load by weighing the
    -      // product of each endpoint's weight and the weight of the locality it belongs to.
    -      if (endpointLbPolicy.getProvider().getPolicyName().equals("round_robin")
    -          || endpointLbPolicy.getProvider().getPolicyName().equals("least_request_experimental")) {
    -        Map localityWeights = prioritizedLocalityWeights.get(priority);
    -        Map targets = new HashMap<>();
    -        for (Locality locality : localityWeights.keySet()) {
    -          int weight = localityWeights.get(locality);
    -          WeightedPolicySelection target = new WeightedPolicySelection(weight, endpointLbPolicy);
    -          targets.put(localityName(locality), target);
    -        }
    -        LoadBalancerProvider weightedTargetLbProvider =
    -            lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME);
    -        WeightedTargetConfig weightedTargetConfig =
    -            new WeightedTargetConfig(Collections.unmodifiableMap(targets));
    -        leafPolicy = new PolicySelection(weightedTargetLbProvider, weightedTargetConfig);
    -      }
           ClusterImplConfig clusterImplConfig =
               new ClusterImplConfig(cluster, edsServiceName, lrsServerInfo, maxConcurrentRequests,
    -              dropOverloads, leafPolicy, tlsContext);
    +              dropOverloads, endpointLbPolicy, tlsContext);
           LoadBalancerProvider clusterImplLbProvider =
               lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME);
    -      PolicySelection clusterImplPolicy =
    +      PolicySelection priorityChildPolicy =
               new PolicySelection(clusterImplLbProvider, clusterImplConfig);
    +
    +      // If outlier detection has been configured we wrap the child policy in the outlier detection
    +      // load balancer.
    +      if (outlierDetection != null) {
    +        LoadBalancerProvider outlierDetectionProvider = lbRegistry.getProvider(
    +            "outlier_detection_experimental");
    +        priorityChildPolicy = new PolicySelection(outlierDetectionProvider,
    +            buildOutlierDetectionLbConfig(outlierDetection, priorityChildPolicy));
    +      }
    +
           PriorityChildConfig priorityChildConfig =
    -          new PriorityChildConfig(clusterImplPolicy, true /* ignoreReresolution */);
    +          new PriorityChildConfig(priorityChildPolicy, true /* ignoreReresolution */);
           configs.put(priority, priorityChildConfig);
         }
         return configs;
       }
     
    +  /**
    +   * Converts {@link OutlierDetection} that represents the xDS configuration to {@link
    +   * OutlierDetectionLoadBalancerConfig} that the {@link io.grpc.util.OutlierDetectionLoadBalancer}
    +   * understands.
    +   */
    +  private static OutlierDetectionLoadBalancerConfig buildOutlierDetectionLbConfig(
    +      OutlierDetection outlierDetection, PolicySelection childPolicy) {
    +    OutlierDetectionLoadBalancerConfig.Builder configBuilder
    +        = new OutlierDetectionLoadBalancerConfig.Builder();
    +
    +    configBuilder.setChildPolicy(childPolicy);
    +
    +    if (outlierDetection.intervalNanos() != null) {
    +      configBuilder.setIntervalNanos(outlierDetection.intervalNanos());
    +    }
    +    if (outlierDetection.baseEjectionTimeNanos() != null) {
    +      configBuilder.setBaseEjectionTimeNanos(outlierDetection.baseEjectionTimeNanos());
    +    }
    +    if (outlierDetection.maxEjectionTimeNanos() != null) {
    +      configBuilder.setMaxEjectionTimeNanos(outlierDetection.maxEjectionTimeNanos());
    +    }
    +    if (outlierDetection.maxEjectionPercent() != null) {
    +      configBuilder.setMaxEjectionPercent(outlierDetection.maxEjectionPercent());
    +    }
    +
    +    SuccessRateEjection successRate = outlierDetection.successRateEjection();
    +    if (successRate != null) {
    +      OutlierDetectionLoadBalancerConfig.SuccessRateEjection.Builder
    +          successRateConfigBuilder = new OutlierDetectionLoadBalancerConfig
    +          .SuccessRateEjection.Builder();
    +
    +      if (successRate.stdevFactor() != null) {
    +        successRateConfigBuilder.setStdevFactor(successRate.stdevFactor());
    +      }
    +      if (successRate.enforcementPercentage() != null) {
    +        successRateConfigBuilder.setEnforcementPercentage(successRate.enforcementPercentage());
    +      }
    +      if (successRate.minimumHosts() != null) {
    +        successRateConfigBuilder.setMinimumHosts(successRate.minimumHosts());
    +      }
    +      if (successRate.requestVolume() != null) {
    +        successRateConfigBuilder.setRequestVolume(successRate.requestVolume());
    +      }
    +
    +      configBuilder.setSuccessRateEjection(successRateConfigBuilder.build());
    +    }
    +
    +    FailurePercentageEjection failurePercentage = outlierDetection.failurePercentageEjection();
    +    if (failurePercentage != null) {
    +      OutlierDetectionLoadBalancerConfig.FailurePercentageEjection.Builder
    +          failurePercentageConfigBuilder = new OutlierDetectionLoadBalancerConfig
    +          .FailurePercentageEjection.Builder();
    +
    +      if (failurePercentage.threshold() != null) {
    +        failurePercentageConfigBuilder.setThreshold(failurePercentage.threshold());
    +      }
    +      if (failurePercentage.enforcementPercentage() != null) {
    +        failurePercentageConfigBuilder.setEnforcementPercentage(
    +            failurePercentage.enforcementPercentage());
    +      }
    +      if (failurePercentage.minimumHosts() != null) {
    +        failurePercentageConfigBuilder.setMinimumHosts(failurePercentage.minimumHosts());
    +      }
    +      if (failurePercentage.requestVolume() != null) {
    +        failurePercentageConfigBuilder.setRequestVolume(failurePercentage.requestVolume());
    +      }
    +
    +      configBuilder.setFailurePercentageEjection(failurePercentageConfigBuilder.build());
    +    }
    +
    +    return configBuilder.build();
    +  }
    +
       /**
        * Generates a string that represents the priority in the LB policy config. The string is unique
        * across priorities in all clusters and priorityName(c, p1) < priorityName(c, p2) iff p1 < p2.
        * The ordering is undefined for priorities in different clusters.
        */
       private static String priorityName(String cluster, int priority) {
    -    return cluster + "[priority" + priority + "]";
    +    return cluster + "[child" + priority + "]";
       }
     
       /**
    diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java
    index 6f6f887e925..38da1f465c1 100644
    --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java
    +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java
    @@ -26,6 +26,7 @@
     import io.grpc.NameResolver.ConfigOrError;
     import io.grpc.internal.ServiceConfigUtil.PolicySelection;
     import io.grpc.xds.Bootstrapper.ServerInfo;
    +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection;
     import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
     import java.util.List;
     import java.util.Map;
    @@ -124,6 +125,8 @@ static final class DiscoveryMechanism {
           // Hostname for resolving endpoints via DNS. Only valid for LOGICAL_DNS clusters.
           @Nullable
           final String dnsHostName;
    +      @Nullable
    +      final OutlierDetection outlierDetection;
     
           enum Type {
             EDS,
    @@ -132,7 +135,8 @@ enum Type {
     
           private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServiceName,
               @Nullable String dnsHostName, @Nullable ServerInfo lrsServerInfo,
    -          @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) {
    +          @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext,
    +          @Nullable OutlierDetection outlierDetection) {
             this.cluster = checkNotNull(cluster, "cluster");
             this.type = checkNotNull(type, "type");
             this.edsServiceName = edsServiceName;
    @@ -140,20 +144,22 @@ private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServic
             this.lrsServerInfo = lrsServerInfo;
             this.maxConcurrentRequests = maxConcurrentRequests;
             this.tlsContext = tlsContext;
    +        this.outlierDetection = outlierDetection;
           }
     
           static DiscoveryMechanism forEds(String cluster, @Nullable String edsServiceName,
               @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests,
    -          @Nullable UpstreamTlsContext tlsContext) {
    +          @Nullable UpstreamTlsContext tlsContext,
    +          OutlierDetection outlierDetection) {
             return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, null, lrsServerInfo,
    -            maxConcurrentRequests, tlsContext);
    +            maxConcurrentRequests, tlsContext, outlierDetection);
           }
     
           static DiscoveryMechanism forLogicalDns(String cluster, String dnsHostName,
               @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests,
               @Nullable UpstreamTlsContext tlsContext) {
             return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, dnsHostName,
    -            lrsServerInfo, maxConcurrentRequests, tlsContext);
    +            lrsServerInfo, maxConcurrentRequests, tlsContext, null);
           }
     
           @Override
    diff --git a/xds/src/main/java/io/grpc/xds/CsdsService.java b/xds/src/main/java/io/grpc/xds/CsdsService.java
    index edee01f95f1..3aab66d94c9 100644
    --- a/xds/src/main/java/io/grpc/xds/CsdsService.java
    +++ b/xds/src/main/java/io/grpc/xds/CsdsService.java
    @@ -33,7 +33,6 @@
     import io.grpc.StatusException;
     import io.grpc.internal.ObjectPool;
     import io.grpc.stub.StreamObserver;
    -import io.grpc.xds.AbstractXdsClient.ResourceType;
     import io.grpc.xds.XdsClient.ResourceMetadata;
     import io.grpc.xds.XdsClient.ResourceMetadata.ResourceMetadataStatus;
     import io.grpc.xds.XdsClient.ResourceMetadata.UpdateFailureState;
    @@ -156,12 +155,12 @@ static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws Inte
         ClientConfig.Builder builder = ClientConfig.newBuilder()
             .setNode(xdsClient.getBootstrapInfo().node().toEnvoyProtoNode());
     
    -    Map> metadataByType =
    +    Map, Map> metadataByType =
             awaitSubscribedResourcesMetadata(xdsClient.getSubscribedResourcesMetadataSnapshot());
     
    -    for (Map.Entry> metadataByTypeEntry
    +    for (Map.Entry, Map> metadataByTypeEntry
             : metadataByType.entrySet()) {
    -      ResourceType type = metadataByTypeEntry.getKey();
    +      XdsResourceType type = metadataByTypeEntry.getKey();
           Map metadataByResourceName = metadataByTypeEntry.getValue();
           for (Map.Entry metadataEntry : metadataByResourceName.entrySet()) {
             String resourceName = metadataEntry.getKey();
    @@ -187,8 +186,9 @@ static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws Inte
         return builder.build();
       }
     
    -  private static Map> awaitSubscribedResourcesMetadata(
    -      ListenableFuture>> future)
    +  private static Map, Map>
    +      awaitSubscribedResourcesMetadata(
    +      ListenableFuture, Map>> future)
           throws InterruptedException {
         try {
           // Normally this shouldn't take long, but add some slack for cases like a cold JVM.
    diff --git a/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java
    index 8274b23a5d2..cb4fc4ee30d 100644
    --- a/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java
    +++ b/xds/src/main/java/io/grpc/xds/EnvoyProtoData.java
    @@ -270,38 +270,6 @@ public io.envoyproxy.envoy.config.core.v3.Node toEnvoyProtoNode() {
           builder.addAllClientFeatures(clientFeatures);
           return builder.build();
         }
    -
    -    @SuppressWarnings("deprecation") // Deprecated v2 API setBuildVersion().
    -    public io.envoyproxy.envoy.api.v2.core.Node toEnvoyProtoNodeV2() {
    -      io.envoyproxy.envoy.api.v2.core.Node.Builder builder =
    -          io.envoyproxy.envoy.api.v2.core.Node.newBuilder();
    -      builder.setId(id);
    -      builder.setCluster(cluster);
    -      if (metadata != null) {
    -        Struct.Builder structBuilder = Struct.newBuilder();
    -        for (Map.Entry entry : metadata.entrySet()) {
    -          structBuilder.putFields(entry.getKey(), convertToValue(entry.getValue()));
    -        }
    -        builder.setMetadata(structBuilder);
    -      }
    -      if (locality != null) {
    -        builder.setLocality(
    -            io.envoyproxy.envoy.api.v2.core.Locality.newBuilder()
    -                .setRegion(locality.region())
    -                .setZone(locality.zone())
    -                .setSubZone(locality.subZone()));
    -      }
    -      for (Address address : listeningAddresses) {
    -        builder.addListeningAddresses(address.toEnvoyProtoAddressV2());
    -      }
    -      builder.setBuildVersion(buildVersion);
    -      builder.setUserAgentName(userAgentName);
    -      if (userAgentVersion != null) {
    -        builder.setUserAgentVersion(userAgentVersion);
    -      }
    -      builder.addAllClientFeatures(clientFeatures);
    -      return builder.build();
    -    }
       }
     
       /**
    diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java
    index e53439755be..5015c56ba92 100644
    --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java
    +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java
    @@ -19,9 +19,10 @@
     import com.google.auto.value.AutoValue;
     import com.google.common.annotations.VisibleForTesting;
     import com.google.common.collect.ImmutableList;
    +import com.google.protobuf.util.Durations;
     import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
     import io.grpc.Internal;
    -import io.grpc.xds.internal.sds.SslContextProviderSupplier;
    +import io.grpc.xds.internal.security.SslContextProviderSupplier;
     import java.net.InetAddress;
     import java.net.UnknownHostException;
     import java.util.Objects;
    @@ -254,4 +255,148 @@ static Listener create(
               defaultFilterChain);
         }
       }
    +
    +  /**
    +   * Corresponds to Envoy proto message {@link
    +   * io.envoyproxy.envoy.config.cluster.v3.OutlierDetection}. Only the fields supported by gRPC are
    +   * included.
    +   *
    +   * 

    Protobuf Duration fields are represented in their string format (e.g. "10s"). + */ + @AutoValue + abstract static class OutlierDetection { + + @Nullable + abstract Long intervalNanos(); + + @Nullable + abstract Long baseEjectionTimeNanos(); + + @Nullable + abstract Long maxEjectionTimeNanos(); + + @Nullable + abstract Integer maxEjectionPercent(); + + @Nullable + abstract SuccessRateEjection successRateEjection(); + + @Nullable + abstract FailurePercentageEjection failurePercentageEjection(); + + static OutlierDetection create( + @Nullable Long intervalNanos, + @Nullable Long baseEjectionTimeNanos, + @Nullable Long maxEjectionTimeNanos, + @Nullable Integer maxEjectionPercentage, + @Nullable SuccessRateEjection successRateEjection, + @Nullable FailurePercentageEjection failurePercentageEjection) { + return new AutoValue_EnvoyServerProtoData_OutlierDetection(intervalNanos, + baseEjectionTimeNanos, maxEjectionTimeNanos, maxEjectionPercentage, successRateEjection, + failurePercentageEjection); + } + + static OutlierDetection fromEnvoyOutlierDetection( + io.envoyproxy.envoy.config.cluster.v3.OutlierDetection envoyOutlierDetection) { + + Long intervalNanos = envoyOutlierDetection.hasInterval() + ? Durations.toNanos(envoyOutlierDetection.getInterval()) : null; + Long baseEjectionTimeNanos = envoyOutlierDetection.hasBaseEjectionTime() + ? Durations.toNanos(envoyOutlierDetection.getBaseEjectionTime()) : null; + Long maxEjectionTimeNanos = envoyOutlierDetection.hasMaxEjectionTime() + ? Durations.toNanos(envoyOutlierDetection.getMaxEjectionTime()) : null; + Integer maxEjectionPercentage = envoyOutlierDetection.hasMaxEjectionPercent() + ? envoyOutlierDetection.getMaxEjectionPercent().getValue() : null; + + SuccessRateEjection successRateEjection; + // If success rate enforcement has been turned completely off, don't configure this ejection. + if (envoyOutlierDetection.hasEnforcingSuccessRate() + && envoyOutlierDetection.getEnforcingSuccessRate().getValue() == 0) { + successRateEjection = null; + } else { + Integer stdevFactor = envoyOutlierDetection.hasSuccessRateStdevFactor() + ? envoyOutlierDetection.getSuccessRateStdevFactor().getValue() : null; + Integer enforcementPercentage = envoyOutlierDetection.hasEnforcingSuccessRate() + ? envoyOutlierDetection.getEnforcingSuccessRate().getValue() : null; + Integer minimumHosts = envoyOutlierDetection.hasSuccessRateMinimumHosts() + ? envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; + Integer requestVolume = envoyOutlierDetection.hasSuccessRateRequestVolume() + ? envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; + + successRateEjection = SuccessRateEjection.create(stdevFactor, enforcementPercentage, + minimumHosts, requestVolume); + } + + FailurePercentageEjection failurePercentageEjection; + if (envoyOutlierDetection.hasEnforcingFailurePercentage() + && envoyOutlierDetection.getEnforcingFailurePercentage().getValue() == 0) { + failurePercentageEjection = null; + } else { + Integer threshold = envoyOutlierDetection.hasFailurePercentageThreshold() + ? envoyOutlierDetection.getFailurePercentageThreshold().getValue() : null; + Integer enforcementPercentage = envoyOutlierDetection.hasEnforcingFailurePercentage() + ? envoyOutlierDetection.getEnforcingFailurePercentage().getValue() : null; + Integer minimumHosts = envoyOutlierDetection.hasFailurePercentageMinimumHosts() + ? envoyOutlierDetection.getFailurePercentageMinimumHosts().getValue() : null; + Integer requestVolume = envoyOutlierDetection.hasFailurePercentageRequestVolume() + ? envoyOutlierDetection.getFailurePercentageRequestVolume().getValue() : null; + + failurePercentageEjection = FailurePercentageEjection.create(threshold, + enforcementPercentage, minimumHosts, requestVolume); + } + + return create(intervalNanos, baseEjectionTimeNanos, maxEjectionTimeNanos, + maxEjectionPercentage, successRateEjection, failurePercentageEjection); + } + } + + @AutoValue + abstract static class SuccessRateEjection { + + @Nullable + abstract Integer stdevFactor(); + + @Nullable + abstract Integer enforcementPercentage(); + + @Nullable + abstract Integer minimumHosts(); + + @Nullable + abstract Integer requestVolume(); + + static SuccessRateEjection create( + @Nullable Integer stdevFactor, + @Nullable Integer enforcementPercentage, + @Nullable Integer minimumHosts, + @Nullable Integer requestVolume) { + return new AutoValue_EnvoyServerProtoData_SuccessRateEjection(stdevFactor, + enforcementPercentage, minimumHosts, requestVolume); + } + } + + @AutoValue + abstract static class FailurePercentageEjection { + + @Nullable + abstract Integer threshold(); + + @Nullable + abstract Integer enforcementPercentage(); + + @Nullable + abstract Integer minimumHosts(); + + @Nullable + abstract Integer requestVolume(); + + static FailurePercentageEjection create( + @Nullable Integer threshold, + @Nullable Integer enforcementPercentage, + @Nullable Integer minimumHosts, + @Nullable Integer requestVolume) { + return new AutoValue_EnvoyServerProtoData_FailurePercentageEjection(threshold, + enforcementPercentage, minimumHosts, requestVolume); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index f0c37b20df4..d46b3d30f5a 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -48,6 +48,7 @@ import io.grpc.xds.FaultConfig.FaultDelay; import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; +import java.util.Locale; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -255,6 +256,7 @@ public void onClose(Status status, Metadata trailers) { // only sent out after the delay. There could be a race between local and // remote, but it is rather rare.) String description = String.format( + Locale.US, "Deadline exceeded after up to %d ns of fault-injected delay", finalDelayNanos); if (status.getDescription() != null) { @@ -400,7 +402,10 @@ public void run() { activeFaultCounter.decrementAndGet(); } } - setCall(callSupplier.get()); + Runnable toRun = setCall(callSupplier.get()); + if (toRun != null) { + toRun.run(); + } } }, delayNanos, diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index e75440225dc..fa03b2add4d 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -20,7 +20,7 @@ import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; -import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -41,7 +41,7 @@ import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; import io.grpc.xds.internal.Matchers.CidrMatcher; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; diff --git a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java new file mode 100644 index 00000000000..54e6c748cd5 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; +import io.grpc.Internal; +import io.grpc.ServerInterceptor; +import io.grpc.xds.RbacConfig; +import io.grpc.xds.RbacFilter; + +/** This class exposes some functionality in RbacFilter to other packages. */ +@Internal +public final class InternalRbacFilter { + + private InternalRbacFilter() {} + + /** Parses RBAC filter config and creates AuthorizationServerInterceptor. */ + public static ServerInterceptor createInterceptor(RBAC rbac) { + ConfigOrError filterConfig = RbacFilter.parseRbacConfig(rbac); + if (filterConfig.errorDetail != null) { + throw new IllegalArgumentException( + String.format("Failed to parse Rbac policy: %s", filterConfig.errorDetail)); + } + return new RbacFilter().buildServerInterceptor(filterConfig.config, null); + } +} diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java new file mode 100644 index 00000000000..114300c9281 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -0,0 +1,33 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Internal; +import java.util.Map; + +/** + * Accessor for global factory for managing XdsClient instance. + */ +@Internal +public final class InternalSharedXdsClientPoolProvider { + // Prevent instantiation + private InternalSharedXdsClientPoolProvider() {} + + public static void setDefaultProviderBootstrapOverride(Map bootstrap) { + SharedXdsClientPoolProvider.getDefaultProvider().setBootstrapOverride(bootstrap); + } +} diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index 410a64df9ca..bd21a8ac13e 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -23,7 +23,7 @@ import io.grpc.NameResolver; import io.grpc.internal.ObjectPool; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.SslContextProviderSupplier; /** * Internal attributes used for xDS implementation. Do not use. @@ -36,7 +36,7 @@ public final class InternalXdsAttributes { @Grpc.TransportAttr public static final Attributes.Key ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.sds.SslContextProviderSupplier"); + Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); /** * Attribute key for passing around the XdsClient object pool across NameResolver/LoadBalancers. @@ -53,6 +53,13 @@ public final class InternalXdsAttributes { static final Attributes.Key CALL_COUNTER_PROVIDER = Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.callCounterProvider"); + /** + * Map from localities to their weights. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key ATTR_LOCALITY_WEIGHT = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.localityWeight"); + /** * Name of the cluster that provides this EquivalentAddressGroup. */ diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index 584ac2dd16f..b4aa39821d2 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -88,7 +88,13 @@ final class LeastRequestLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } LeastRequestConfig config = (LeastRequestConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); // Config may be null if least_request is used outside xDS @@ -146,6 +152,8 @@ public void onSubchannelState(ConnectivityStateInfo state) { for (Subchannel removedSubchannel : removedSubchannels) { shutdownSubchannel(removedSubchannel); } + + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java index 3abac1d2f0d..f9281f695ca 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java @@ -67,13 +67,13 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { choiceCount = DEFAULT_CHOICE_COUNT; } if (choiceCount < MIN_CHOICE_COUNT) { - return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription( - "Invalid 'choiceCount'")); + return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( + "Invalid 'choiceCount' in least_request_experimental config")); } return ConfigOrError.fromConfig(new LeastRequestConfig(choiceCount)); } catch (RuntimeException e) { return ConfigOrError.fromError( - Status.fromThrowable(e).withDescription( + Status.UNAVAILABLE.withCause(e).withDescription( "Failed to parse least_request_experimental LB config: " + rawConfig)); } } diff --git a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java new file mode 100644 index 00000000000..ce3e95f03d1 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -0,0 +1,362 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Struct; +import com.google.protobuf.util.JsonFormat; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; +import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; +import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.grpc.InternalLogId; +import io.grpc.LoadBalancerRegistry; +import io.grpc.internal.JsonParser; +import io.grpc.xds.LoadBalancerConfigFactory.LoadBalancingPolicyConverter.MaxRecursionReachedException; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsLogger.XdsLogLevel; +import java.io.IOException; +import java.util.Map; + +/** + * Creates service config JSON load balancer config objects for a given xDS Cluster message. + * Supports both the "legacy" configuration style and the new, more advanced one that utilizes the + * xDS "typed extension" mechanism. + * + *

    Legacy configuration is done by setting the lb_policy enum field and any supporting + * configuration fields needed by the particular policy. + * + *

    The new approach is to set the load_balancing_policy field that contains both the policy + * selection as well as any supporting configuration data. Providing a list of acceptable policies + * is also supported. Note that if this field is used, it will override any configuration set using + * the legacy approach. The new configuration approach is explained in detail in the Custom LB Policies + * gRFC + */ +class LoadBalancerConfigFactory { + + private static final XdsLogger logger = XdsLogger.withLogId( + InternalLogId.allocate("xds-client-lbconfig-factory", null)); + + static final String ROUND_ROBIN_FIELD_NAME = "round_robin"; + + static final String RING_HASH_FIELD_NAME = "ring_hash_experimental"; + static final String MIN_RING_SIZE_FIELD_NAME = "minRingSize"; + static final String MAX_RING_SIZE_FIELD_NAME = "maxRingSize"; + + static final String LEAST_REQUEST_FIELD_NAME = "least_request_experimental"; + static final String CHOICE_COUNT_FIELD_NAME = "choiceCount"; + + static final String WRR_LOCALITY_FIELD_NAME = "wrr_locality_experimental"; + static final String CHILD_POLICY_FIELD = "childPolicy"; + + /** + * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link + * Cluster}. + * + * @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration. + */ + static ImmutableMap newConfig(Cluster cluster, boolean enableLeastRequest, + boolean enableCustomLbConfig) + throws ResourceInvalidException { + // The new load_balancing_policy will always be used if it is set, but for backward + // compatibility we will fall back to using the old lb_policy field if the new field is not set. + if (cluster.hasLoadBalancingPolicy() && enableCustomLbConfig) { + try { + return LoadBalancingPolicyConverter.convertToServiceConfig(cluster.getLoadBalancingPolicy(), + 0); + } catch (MaxRecursionReachedException e) { + throw new ResourceInvalidException("Maximum LB config recursion depth reached", e); + } + } else { + return LegacyLoadBalancingPolicyConverter.convertToServiceConfig(cluster, enableLeastRequest); + } + } + + /** + * Builds a service config JSON object for the ring_hash load balancer config based on the given + * config values. + */ + private static ImmutableMap buildRingHashConfig(Long minRingSize, Long maxRingSize) { + ImmutableMap.Builder configBuilder = ImmutableMap.builder(); + if (minRingSize != null) { + configBuilder.put(MIN_RING_SIZE_FIELD_NAME, minRingSize.doubleValue()); + } + if (maxRingSize != null) { + configBuilder.put(MAX_RING_SIZE_FIELD_NAME, maxRingSize.doubleValue()); + } + return ImmutableMap.of(RING_HASH_FIELD_NAME, configBuilder.buildOrThrow()); + } + + /** + * Builds a service config JSON object for the least_request load balancer config based on the + * given config values.. + */ + private static ImmutableMap buildLeastRequestConfig(Integer choiceCount) { + ImmutableMap.Builder configBuilder = ImmutableMap.builder(); + if (choiceCount != null) { + configBuilder.put(CHOICE_COUNT_FIELD_NAME, choiceCount.doubleValue()); + } + return ImmutableMap.of(LEAST_REQUEST_FIELD_NAME, configBuilder.buildOrThrow()); + } + + /** + * Builds a service config JSON wrr_locality by wrapping another policy config. + */ + private static ImmutableMap buildWrrLocalityConfig( + ImmutableMap childConfig) { + return ImmutableMap.builder().put(WRR_LOCALITY_FIELD_NAME, + ImmutableMap.of(CHILD_POLICY_FIELD, ImmutableList.of(childConfig))).buildOrThrow(); + } + + /** + * Builds an empty service config JSON config object for round robin (it is not configurable). + */ + private static ImmutableMap buildRoundRobinConfig() { + return ImmutableMap.of(ROUND_ROBIN_FIELD_NAME, ImmutableMap.of()); + } + + /** + * Responsible for converting from a {@code envoy.config.cluster.v3.LoadBalancingPolicy} proto + * message to a gRPC service config format. + */ + static class LoadBalancingPolicyConverter { + + private static final int MAX_RECURSION = 16; + + /** + * Converts a {@link LoadBalancingPolicy} object to a service config JSON object. + */ + private static ImmutableMap convertToServiceConfig( + LoadBalancingPolicy loadBalancingPolicy, int recursionDepth) + throws ResourceInvalidException, MaxRecursionReachedException { + if (recursionDepth > MAX_RECURSION) { + throw new MaxRecursionReachedException(); + } + ImmutableMap serviceConfig = null; + + for (Policy policy : loadBalancingPolicy.getPoliciesList()) { + Any typedConfig = policy.getTypedExtensionConfig().getTypedConfig(); + try { + if (typedConfig.is(RingHash.class)) { + serviceConfig = convertRingHashConfig(typedConfig.unpack(RingHash.class)); + } else if (typedConfig.is(WrrLocality.class)) { + serviceConfig = convertWrrLocalityConfig(typedConfig.unpack(WrrLocality.class), + recursionDepth); + } else if (typedConfig.is(RoundRobin.class)) { + serviceConfig = convertRoundRobinConfig(); + } else if (typedConfig.is(LeastRequest.class)) { + serviceConfig = convertLeastRequestConfig(typedConfig.unpack(LeastRequest.class)); + } else if (typedConfig.is(com.github.xds.type.v3.TypedStruct.class)) { + serviceConfig = convertCustomConfig( + typedConfig.unpack(com.github.xds.type.v3.TypedStruct.class)); + } else if (typedConfig.is(com.github.udpa.udpa.type.v1.TypedStruct.class)) { + serviceConfig = convertCustomConfig( + typedConfig.unpack(com.github.udpa.udpa.type.v1.TypedStruct.class)); + } + + // TODO: support least_request once it is added to the envoy protos. + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException( + "Unable to unpack typedConfig for: " + typedConfig.getTypeUrl(), e); + } + // The service config is expected to have a single root entry, where the name of that entry + // is the name of the policy. A Load balancer with this name must exist in the registry. + if (serviceConfig == null || LoadBalancerRegistry.getDefaultRegistry() + .getProvider(Iterables.getOnlyElement(serviceConfig.keySet())) == null) { + logger.log(XdsLogLevel.WARNING, "Policy {0} not found in the LB registry, skipping", + typedConfig.getTypeUrl()); + continue; + } else { + return serviceConfig; + } + } + + // If we could not find a Policy that we could both convert as well as find a provider for + // then we have an invalid LB policy configuration. + throw new ResourceInvalidException("Invalid LoadBalancingPolicy: " + loadBalancingPolicy); + } + + /** + * Converts a ring_hash {@link Any} configuration to service config format. + */ + private static ImmutableMap convertRingHashConfig(RingHash ringHash) + throws ResourceInvalidException { + // The hash function needs to be validated here as it is not exposed in the returned + // configuration for later validation. + if (RingHash.HashFunction.XX_HASH != ringHash.getHashFunction()) { + throw new ResourceInvalidException( + "Invalid ring hash function: " + ringHash.getHashFunction()); + } + + return buildRingHashConfig( + ringHash.hasMinimumRingSize() ? ringHash.getMinimumRingSize().getValue() : null, + ringHash.hasMaximumRingSize() ? ringHash.getMaximumRingSize().getValue() : null); + } + + /** + * Converts a wrr_locality {@link Any} configuration to service config format. + */ + private static ImmutableMap convertWrrLocalityConfig(WrrLocality wrrLocality, + int recursionDepth) throws ResourceInvalidException, + MaxRecursionReachedException { + return buildWrrLocalityConfig( + convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), recursionDepth + 1)); + } + + /** + * "Converts" a round_robin configuration to service config format. + */ + private static ImmutableMap convertRoundRobinConfig() { + return buildRoundRobinConfig(); + } + + /** + * Converts a least_request {@link Any} configuration to service config format. + */ + private static ImmutableMap convertLeastRequestConfig(LeastRequest leastRequest) + throws ResourceInvalidException { + return buildLeastRequestConfig( + leastRequest.hasChoiceCount() ? leastRequest.getChoiceCount().getValue() : null); + } + + /** + * Converts a custom TypedStruct LB config to service config format. + */ + @SuppressWarnings("unchecked") + private static ImmutableMap convertCustomConfig( + com.github.xds.type.v3.TypedStruct configTypedStruct) + throws ResourceInvalidException { + return ImmutableMap.of(parseCustomConfigTypeName(configTypedStruct.getTypeUrl()), + (Map) parseCustomConfigJson(configTypedStruct.getValue())); + } + + /** + * Converts a custom UDPA (legacy) TypedStruct LB config to service config format. + */ + @SuppressWarnings("unchecked") + private static ImmutableMap convertCustomConfig( + com.github.udpa.udpa.type.v1.TypedStruct configTypedStruct) + throws ResourceInvalidException { + return ImmutableMap.of(parseCustomConfigTypeName(configTypedStruct.getTypeUrl()), + (Map) parseCustomConfigJson(configTypedStruct.getValue())); + } + + /** + * Print the config Struct into JSON and then parse that into our internal representation. + */ + private static Object parseCustomConfigJson(Struct configStruct) + throws ResourceInvalidException { + Object rawJsonConfig = null; + try { + rawJsonConfig = JsonParser.parse(JsonFormat.printer().print(configStruct)); + } catch (IOException e) { + throw new ResourceInvalidException("Unable to parse custom LB config JSON", e); + } + + if (!(rawJsonConfig instanceof Map)) { + throw new ResourceInvalidException("Custom LB config does not contain a JSON object"); + } + return rawJsonConfig; + } + + + private static String parseCustomConfigTypeName(String customConfigTypeName) { + if (customConfigTypeName.contains("/")) { + customConfigTypeName = customConfigTypeName.substring( + customConfigTypeName.lastIndexOf("/") + 1); + } + return customConfigTypeName; + } + + // Used to signal that the LB config goes too deep. + static class MaxRecursionReachedException extends Exception { + static final long serialVersionUID = 1L; + } + } + + /** + * Builds a JSON LB configuration based on the old style of using the xDS Cluster proto message. + * The lb_policy field is used to select the policy and configuration is extracted from various + * policy specific fields in Cluster. + */ + static class LegacyLoadBalancingPolicyConverter { + + /** + * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link + * Cluster}. + * + * @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration. + */ + static ImmutableMap convertToServiceConfig(Cluster cluster, + boolean enableLeastRequest) throws ResourceInvalidException { + switch (cluster.getLbPolicy()) { + case RING_HASH: + return convertRingHashConfig(cluster); + case ROUND_ROBIN: + return buildWrrLocalityConfig(buildRoundRobinConfig()); + case LEAST_REQUEST: + if (enableLeastRequest) { + return buildWrrLocalityConfig(convertLeastRequestConfig(cluster)); + } + break; + default: + } + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": unsupported lb policy: " + cluster.getLbPolicy()); + } + + /** + * Creates a new ring_hash service config JSON object based on the old {@link RingHashLbConfig} + * config message. + */ + private static ImmutableMap convertRingHashConfig(Cluster cluster) + throws ResourceInvalidException { + RingHashLbConfig lbConfig = cluster.getRingHashLbConfig(); + + // The hash function needs to be validated here as it is not exposed in the returned + // configuration for later validation. + if (lbConfig.getHashFunction() != RingHashLbConfig.HashFunction.XX_HASH) { + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": invalid ring hash function: " + lbConfig); + } + + return buildRingHashConfig( + lbConfig.hasMinimumRingSize() ? (Long) lbConfig.getMinimumRingSize().getValue() : null, + lbConfig.hasMaximumRingSize() ? (Long) lbConfig.getMaximumRingSize().getValue() : null); + } + + /** + * Creates a new least_request service config JSON object based on the old {@link + * LeastRequestLbConfig} config message. + */ + private static ImmutableMap convertLeastRequestConfig(Cluster cluster) { + LeastRequestLbConfig lbConfig = cluster.getLeastRequestLbConfig(); + return buildLeastRequestConfig( + lbConfig.hasChoiceCount() ? (Integer) lbConfig.getChoiceCount().getValue() : null); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/LoadReportClient.java b/xds/src/main/java/io/grpc/xds/LoadReportClient.java index af2a673e9f7..d6a3679d4cd 100644 --- a/xds/src/main/java/io/grpc/xds/LoadReportClient.java +++ b/xds/src/main/java/io/grpc/xds/LoadReportClient.java @@ -25,7 +25,6 @@ import com.google.common.base.Supplier; import com.google.protobuf.util.Durations; import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc; -import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc.LoadReportingServiceStub; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; import io.grpc.Channel; @@ -57,7 +56,6 @@ final class LoadReportClient { private final XdsLogger logger; private final Channel channel; private final Context context; - private final boolean useProtocolV3; private final Node node; private final SynchronizationContext syncContext; private final ScheduledExecutorService timerService; @@ -77,7 +75,6 @@ final class LoadReportClient { LoadStatsManager2 loadStatsManager, Channel channel, Context context, - boolean useProtocolV3, Node node, SynchronizationContext syncContext, ScheduledExecutorService scheduledExecutorService, @@ -86,7 +83,6 @@ final class LoadReportClient { this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager"); this.channel = checkNotNull(channel, "xdsChannel"); this.context = checkNotNull(context, "context"); - this.useProtocolV3 = useProtocolV3; this.syncContext = checkNotNull(syncContext, "syncContext"); this.timerService = checkNotNull(scheduledExecutorService, "timeService"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); @@ -161,11 +157,7 @@ private void startLrsRpc() { return; } checkState(lrsStream == null, "previous lbStream has not been cleared yet"); - if (useProtocolV3) { - lrsStream = new LrsStreamV3(); - } else { - lrsStream = new LrsStreamV2(); - } + lrsStream = new LrsStream(); retryStopwatch.reset().start(); Context prevContext = context.attach(); try { @@ -175,22 +167,73 @@ private void startLrsRpc() { } } - private abstract class LrsStream { + private final class LrsStream { boolean initialResponseReceived; boolean closed; long intervalNano = -1; boolean reportAllClusters; List clusterNames; // clusters to report loads for, if not report all. ScheduledHandle loadReportTimer; + StreamObserver lrsRequestWriterV3; + + void start() { + StreamObserver lrsResponseReaderV3 = + new StreamObserver() { + @Override + public void onNext(final LoadStatsResponse response) { + syncContext.execute(new Runnable() { + @Override + public void run() { + logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); + handleRpcResponse(response.getClustersList(), response.getSendAllClusters(), + Durations.toNanos(response.getLoadReportingInterval())); + } + }); + } + + @Override + public void onError(final Throwable t) { + syncContext.execute(new Runnable() { + @Override + public void run() { + handleRpcError(t); + } + }); + } - abstract void start(); + @Override + public void onCompleted() { + syncContext.execute(new Runnable() { + @Override + public void run() { + handleRpcCompleted(); + } + }); + } + }; + lrsRequestWriterV3 = LoadReportingServiceGrpc.newStub(channel).withWaitForReady() + .streamLoadStats(lrsResponseReaderV3); + logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); + sendLoadStatsRequest(Collections.emptyList()); + } - abstract void sendLoadStatsRequest(List clusterStatsList); + void sendLoadStatsRequest(List clusterStatsList) { + LoadStatsRequest.Builder requestBuilder = + LoadStatsRequest.newBuilder().setNode(node.toEnvoyProtoNode()); + for (ClusterStats stats : clusterStatsList) { + requestBuilder.addClusterStats(buildClusterStats(stats)); + } + LoadStatsRequest request = requestBuilder.build(); + lrsRequestWriterV3.onNext(request); + logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); + } - abstract void sendError(Exception error); + void sendError(Exception error) { + lrsRequestWriterV3.onError(error); + } - final void handleRpcResponse(List clusters, boolean sendAllClusters, - long loadReportIntervalNano) { + void handleRpcResponse(List clusters, boolean sendAllClusters, + long loadReportIntervalNano) { if (closed) { return; } @@ -210,11 +253,11 @@ final void handleRpcResponse(List clusters, boolean sendAllClusters, scheduleNextLoadReport(); } - final void handleRpcError(Throwable t) { + void handleRpcError(Throwable t) { handleStreamClosed(Status.fromThrowable(t)); } - final void handleRpcCompleted() { + void handleRpcCompleted() { handleStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server")); } @@ -259,20 +302,16 @@ private void handleStreamClosed(Status status) { closed = true; cleanUp(); - long delayNanos = 0; if (initialResponseReceived || lrsRpcRetryPolicy == null) { // Reset the backoff sequence if balancer has sent the initial response, or backoff sequence // has never been initialized. lrsRpcRetryPolicy = backoffPolicyProvider.get(); } - // Backoff only when balancer wasn't working previously. - if (!initialResponseReceived) { - // The back-off policy determines the interval between consecutive RPC upstarts, thus the - // actual delay may be smaller than the value from the back-off policy, or even negative, - // depending how much time was spent in the previous RPC. - delayNanos = - lrsRpcRetryPolicy.nextBackoffNanos() - retryStopwatch.elapsed(TimeUnit.NANOSECONDS); - } + // The back-off policy determines the interval between consecutive RPC upstarts, thus the + // actual delay may be smaller than the value from the back-off policy, or even negative, + // depending how much time was spent in the previous RPC. + long delayNanos = + lrsRpcRetryPolicy.nextBackoffNanos() - retryStopwatch.elapsed(TimeUnit.NANOSECONDS); logger.log(XdsLogLevel.INFO, "Retry LRS stream in {0} ns", delayNanos); if (delayNanos <= 0) { startLrsRpc(); @@ -300,169 +339,6 @@ private void cleanUp() { lrsStream = null; } } - } - - private final class LrsStreamV2 extends LrsStream { - StreamObserver lrsRequestWriterV2; - - @Override - void start() { - StreamObserver - lrsResponseReaderV2 = - new StreamObserver() { - @Override - public void onNext( - final io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) { - syncContext.execute(new Runnable() { - @Override - public void run() { - logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response); - handleRpcResponse(response.getClustersList(), response.getSendAllClusters(), - Durations.toNanos(response.getLoadReportingInterval())); - } - }); - } - - @Override - public void onError(final Throwable t) { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcError(t); - } - }); - } - - @Override - public void onCompleted() { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcCompleted(); - } - }); - } - }; - io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub - stubV2 = io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.newStub( - channel); - lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2); - logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); - sendLoadStatsRequest(Collections.emptyList()); - } - - @Override - void sendLoadStatsRequest(List clusterStatsList) { - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder requestBuilder = - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder() - .setNode(node.toEnvoyProtoNodeV2()); - for (ClusterStats stats : clusterStatsList) { - requestBuilder.addClusterStats(buildClusterStats(stats)); - } - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest request = requestBuilder.build(); - lrsRequestWriterV2.onNext(requestBuilder.build()); - logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); - } - - @Override - void sendError(Exception error) { - lrsRequestWriterV2.onError(error); - } - - private io.envoyproxy.envoy.api.v2.endpoint.ClusterStats buildClusterStats( - ClusterStats stats) { - io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.Builder builder = - io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.newBuilder() - .setClusterName(stats.clusterName()); - if (stats.clusterServiceName() != null) { - builder.setClusterServiceName(stats.clusterServiceName()); - } - for (UpstreamLocalityStats upstreamLocalityStats : stats.upstreamLocalityStatsList()) { - builder.addUpstreamLocalityStats( - io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats.newBuilder() - .setLocality( - io.envoyproxy.envoy.api.v2.core.Locality.newBuilder() - .setRegion(upstreamLocalityStats.locality().region()) - .setZone(upstreamLocalityStats.locality().zone()) - .setSubZone(upstreamLocalityStats.locality().subZone())) - .setTotalSuccessfulRequests(upstreamLocalityStats.totalSuccessfulRequests()) - .setTotalErrorRequests(upstreamLocalityStats.totalErrorRequests()) - .setTotalRequestsInProgress(upstreamLocalityStats.totalRequestsInProgress()) - .setTotalIssuedRequests(upstreamLocalityStats.totalIssuedRequests())); - } - for (DroppedRequests droppedRequests : stats.droppedRequestsList()) { - builder.addDroppedRequests( - io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests.newBuilder() - .setCategory(droppedRequests.category()) - .setDroppedCount(droppedRequests.droppedCount())); - } - return builder.setTotalDroppedRequests(stats.totalDroppedRequests()) - .setLoadReportInterval(Durations.fromNanos(stats.loadReportIntervalNano())).build(); - } - } - - private final class LrsStreamV3 extends LrsStream { - StreamObserver lrsRequestWriterV3; - - @Override - void start() { - StreamObserver lrsResponseReaderV3 = - new StreamObserver() { - @Override - public void onNext(final LoadStatsResponse response) { - syncContext.execute(new Runnable() { - @Override - public void run() { - logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); - handleRpcResponse(response.getClustersList(), response.getSendAllClusters(), - Durations.toNanos(response.getLoadReportingInterval())); - } - }); - } - - @Override - public void onError(final Throwable t) { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcError(t); - } - }); - } - - @Override - public void onCompleted() { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcCompleted(); - } - }); - } - }; - LoadReportingServiceStub stubV3 = - LoadReportingServiceGrpc.newStub(channel); - lrsRequestWriterV3 = stubV3.withWaitForReady().streamLoadStats(lrsResponseReaderV3); - logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); - sendLoadStatsRequest(Collections.emptyList()); - } - - @Override - void sendLoadStatsRequest(List clusterStatsList) { - LoadStatsRequest.Builder requestBuilder = - LoadStatsRequest.newBuilder().setNode(node.toEnvoyProtoNode()); - for (ClusterStats stats : clusterStatsList) { - requestBuilder.addClusterStats(buildClusterStats(stats)); - } - LoadStatsRequest request = requestBuilder.build(); - lrsRequestWriterV3.onNext(request); - logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); - } - - @Override - void sendError(Exception error) { - lrsRequestWriterV3.onError(error); - } private io.envoyproxy.envoy.config.endpoint.v3.ClusterStats buildClusterStats( ClusterStats stats) { diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java index f29239331b2..e833b3777b8 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java @@ -37,10 +37,13 @@ import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -57,7 +60,9 @@ final class PriorityLoadBalancer extends LoadBalancer { private final XdsLogger logger; // Includes all active and deactivated children. Mutable. New entries are only added from priority - // 0 up to the selected priority. An entry is only deleted 15 minutes after the its deactivation. + // 0 up to the selected priority. An entry is only deleted 15 minutes after its deactivation. + // Note that calling into a child can cause the child to call back into the LB policy and modify + // the map. Therefore copy values before looping over them. private final Map children = new HashMap<>(); // Following fields are only null initially. @@ -66,8 +71,11 @@ final class PriorityLoadBalancer extends LoadBalancer { private List priorityNames; // Config for each priority. private Map priorityConfigs; + @Nullable private String currentPriority; private ConnectivityState currentConnectivityState; private SubchannelPicker currentPicker; + // Set to true if currently in the process of handling resolved addresses. + private boolean handlingResolvedAddresses; PriorityLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); @@ -79,7 +87,7 @@ final class PriorityLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); this.resolvedAddresses = resolvedAddresses; PriorityLbConfig config = (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -87,61 +95,72 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { priorityNames = config.priorities; priorityConfigs = config.childConfigs; Set prioritySet = new HashSet<>(config.priorities); - for (String priority : children.keySet()) { + ArrayList childKeys = new ArrayList<>(children.keySet()); + for (String priority : childKeys) { if (!prioritySet.contains(priority)) { - children.get(priority).deactivate(); + ChildLbState childLbState = children.get(priority); + if (childLbState != null) { + childLbState.deactivate(); + } } } + handlingResolvedAddresses = true; for (String priority : priorityNames) { - if (children.containsKey(priority)) { - children.get(priority).updateResolvedAddresses(); + ChildLbState childLbState = children.get(priority); + if (childLbState != null) { + childLbState.updateResolvedAddresses(); } } - // Not to report connecting in case a pending priority bumps up on top of the current READY - // priority. - tryNextPriority(false); + handlingResolvedAddresses = false; + tryNextPriority(); + return true; } @Override public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); boolean gotoTransientFailure = true; - for (ChildLbState child : children.values()) { + Collection childValues = new ArrayList<>(children.values()); + for (ChildLbState child : childValues) { if (priorityNames.contains(child.priority)) { child.lb.handleNameResolutionError(error); gotoTransientFailure = false; } } if (gotoTransientFailure) { - updateOverallState(TRANSIENT_FAILURE, new ErrorPicker(error)); + updateOverallState(null, TRANSIENT_FAILURE, new ErrorPicker(error)); } } @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); - for (ChildLbState child : children.values()) { + Collection childValues = new ArrayList<>(children.values()); + for (ChildLbState child : childValues) { child.tearDown(); } children.clear(); } - private void tryNextPriority(boolean reportConnecting) { + private void tryNextPriority() { for (int i = 0; i < priorityNames.size(); i++) { String priority = priorityNames.get(i); if (!children.containsKey(priority)) { ChildLbState child = new ChildLbState(priority, priorityConfigs.get(priority).ignoreReresolution); children.put(priority, child); + updateOverallState(priority, CONNECTING, BUFFER_PICKER); + // Calling the child's updateResolvedAddresses() can result in tryNextPriority() being + // called recursively. We need to be sure to be done with processing here before it is + // called. child.updateResolvedAddresses(); - updateOverallState(CONNECTING, BUFFER_PICKER); return; // Give priority i time to connect. } ChildLbState child = children.get(priority); child.reactivate(); if (child.connectivityState.equals(READY) || child.connectivityState.equals(IDLE)) { logger.log(XdsLogLevel.DEBUG, "Shifted to priority {0}", priority); - updateOverallState(child.connectivityState, child.picker); + updateOverallState(priority, child.connectivityState, child.picker); for (int j = i + 1; j < priorityNames.size(); j++) { String p = priorityNames.get(j); if (children.containsKey(p)) { @@ -151,21 +170,27 @@ private void tryNextPriority(boolean reportConnecting) { return; } if (child.failOverTimer != null && child.failOverTimer.isPending()) { - if (reportConnecting) { - updateOverallState(CONNECTING, BUFFER_PICKER); - } + updateOverallState(priority, child.connectivityState, child.picker); return; // Give priority i time to connect. } + if (priority.equals(currentPriority) && child.connectivityState != TRANSIENT_FAILURE) { + // If the current priority is not changed into TRANSIENT_FAILURE, keep using it. + updateOverallState(priority, child.connectivityState, child.picker); + return; + } } // TODO(zdapeng): Include error details of each priority. logger.log(XdsLogLevel.DEBUG, "All priority failed"); String lastPriority = priorityNames.get(priorityNames.size() - 1); SubchannelPicker errorPicker = children.get(lastPriority).picker; - updateOverallState(TRANSIENT_FAILURE, errorPicker); + updateOverallState(lastPriority, TRANSIENT_FAILURE, errorPicker); } - private void updateOverallState(ConnectivityState state, SubchannelPicker picker) { - if (!state.equals(currentConnectivityState) || !picker.equals(currentPicker)) { + private void updateOverallState( + @Nullable String priority, ConnectivityState state, SubchannelPicker picker) { + if (!Objects.equals(priority, currentPriority) || !state.equals(currentConnectivityState) + || !picker.equals(currentPicker)) { + currentPriority = priority; currentConnectivityState = state; currentPicker = picker; helper.updateBalancingState(state, picker); @@ -178,7 +203,8 @@ private final class ChildLbState { final GracefulSwitchLoadBalancer lb; // Timer to fail over to the next priority if not connected in 10 sec. Scheduled only once at // child initialization. - final ScheduledHandle failOverTimer; + ScheduledHandle failOverTimer; + boolean seenReadyOrIdleSinceTransientFailure = false; // Timer to delay shutdown and deletion of the priority. Scheduled whenever the child is // deactivated. @Nullable ScheduledHandle deletionTimer; @@ -190,23 +216,23 @@ private final class ChildLbState { this.priority = priority; childHelper = new ChildHelper(ignoreReresolution); lb = new GracefulSwitchLoadBalancer(childHelper); + failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, executor); + logger.log(XdsLogLevel.DEBUG, "Priority created: {0}", priority); + } - class FailOverTask implements Runnable { - @Override - public void run() { - if (deletionTimer != null && deletionTimer.isPending()) { - // The child is deactivated. - return; - } - picker = new ErrorPicker( - Status.UNAVAILABLE.withDescription("Connection timeout for priority " + priority)); - logger.log(XdsLogLevel.DEBUG, "Priority {0} failed over to next", priority); - tryNextPriority(true); + final class FailOverTask implements Runnable { + @Override + public void run() { + if (deletionTimer != null && deletionTimer.isPending()) { + // The child is deactivated. + return; } + picker = new ErrorPicker( + Status.UNAVAILABLE.withDescription("Connection timeout for priority " + priority)); + logger.log(XdsLogLevel.DEBUG, "Priority {0} failed over to next", priority); + currentPriority = null; // reset currentPriority to guarantee failover happen + tryNextPriority(); } - - failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, executor); - logger.log(XdsLogLevel.DEBUG, "Priority created: {0}", priority); } /** @@ -290,26 +316,33 @@ public void refreshNameResolution() { @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!children.containsKey(priority)) { - return; - } - connectivityState = newState; - picker = newPicker; - if (deletionTimer != null && deletionTimer.isPending()) { - return; - } - if (failOverTimer.isPending()) { - if (newState.equals(READY) || newState.equals(IDLE) - || newState.equals(TRANSIENT_FAILURE)) { - failOverTimer.cancel(); - } - } - tryNextPriority(true); + if (!children.containsKey(priority)) { + return; + } + connectivityState = newState; + picker = newPicker; + + if (deletionTimer != null && deletionTimer.isPending()) { + return; + } + if (newState.equals(CONNECTING)) { + if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) { + failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, + executor); } - }); + } else if (newState.equals(READY) || newState.equals(IDLE)) { + seenReadyOrIdleSinceTransientFailure = true; + failOverTimer.cancel(); + } else if (newState.equals(TRANSIENT_FAILURE)) { + seenReadyOrIdleSinceTransientFailure = false; + failOverTimer.cancel(); + } + + // If we are currently handling newly resolved addresses, let's not try to reconfigure as + // the address handling process will take care of that to provide an atomic config update. + if (!handlingResolvedAddresses) { + tryNextPriority(); + } } @Override diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java index 6e178c62c1b..df35361bed0 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java @@ -47,7 +47,7 @@ public int getPriority() { @Override public String getPolicyName() { - return "priority_experimental"; + return XdsLbPolicies.PRIORITY_POLICY_NAME; } @Override diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index a8f517a8967..436eca8ec5d 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -26,7 +26,10 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.base.MoreObjects; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; import com.google.common.collect.Sets; +import com.google.common.primitives.UnsignedInteger; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -37,12 +40,18 @@ import io.grpc.SynchronizationContext; import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; /** * A {@link LoadBalancer} that provides consistent hashing based load balancing to upstream hosts. @@ -67,6 +76,8 @@ final class RingHashLoadBalancer extends LoadBalancer { private List ring; private ConnectivityState currentState; + private Iterator connectionAttemptIterator = subchannels.values().iterator(); + private final Random random = new Random(); RingHashLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); @@ -76,14 +87,13 @@ final class RingHashLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); List addrList = resolvedAddresses.getAddresses(); - if (addrList.isEmpty()) { - handleNameResolutionError(Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS " - + "resolution was successful, but returned server addresses are empty.")); - return; + if (!validateAddrList(addrList)) { + return false; } + Map latestAddrs = stripAttrs(addrList); Set removedAddrs = Sets.newHashSet(Sets.difference(subchannels.keySet(), latestAddrs.keySet())); @@ -142,6 +152,14 @@ public void onSubchannelState(ConnectivityStateInfo newState) { for (EquivalentAddressGroup addr : removedAddrs) { removedSubchannels.add(subchannels.remove(addr)); } + // If we need to proactively start connecting, iterate through all the subchannels, starting + // at a random position. + // Alternatively, we should better start at the same position. + connectionAttemptIterator = subchannels.values().iterator(); + int randomAdvance = random.nextInt(subchannels.size()); + while (randomAdvance-- > 0) { + connectionAttemptIterator.next(); + } // Update the picker before shutting down the subchannels, to reduce the chance of race // between picking a subchannel and shutting it down. @@ -149,6 +167,78 @@ public void onSubchannelState(ConnectivityStateInfo newState) { for (Subchannel subchann : removedSubchannels) { shutdownSubchannel(subchann); } + + return true; + } + + private boolean validateAddrList(List addrList) { + if (addrList.isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS " + + "resolution was successful, but returned server addresses are empty.")); + return false; + } + + String dupAddrString = validateNoDuplicateAddresses(addrList); + if (dupAddrString != null) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS " + + "resolution was successful, but there were duplicate addresses: " + dupAddrString)); + return false; + } + + long totalWeight = 0; + for (EquivalentAddressGroup eag : addrList) { + Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); + + if (weight == null) { + weight = 1L; + } + + if (weight < 0) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + String.format("Ring hash lb error: EDS resolution was successful, but returned a " + + "negative weight for %s.", stripAttrs(eag)))); + return false; + } + if (weight > UnsignedInteger.MAX_VALUE.longValue()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + String.format("Ring hash lb error: EDS resolution was successful, but returned a weight" + + " too large to fit in an unsigned int for %s.", stripAttrs(eag)))); + return false; + } + totalWeight += weight; + } + + if (totalWeight > UnsignedInteger.MAX_VALUE.longValue()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + String.format( + "Ring hash lb error: EDS resolution was successful, but returned a sum of weights too" + + " large to fit in an unsigned int (%d).", totalWeight))); + return false; + } + + return true; + } + + @Nullable + private String validateNoDuplicateAddresses(List addrList) { + Set addresses = new HashSet<>(); + Multiset dups = HashMultiset.create(); + for (EquivalentAddressGroup eag : addrList) { + for (SocketAddress address : eag.getAddresses()) { + if (!addresses.add(address)) { + dups.add(address.toString()); + } + } + } + + if (!dups.isEmpty()) { + return dups.entrySet().stream() + .map((dup) -> + String.format("Address: %s, count: %d", dup.getElement(), dup.getCount() + 1)) + .collect(Collectors.joining("; ")); + } + + return null; } private static List buildRing( @@ -162,6 +252,7 @@ private static List buildRing( // TODO(chengyuanzhang): is using the list of socket address correct? StringBuilder sb = new StringBuilder(addrKey.getAddresses().toString()); sb.append('_'); + int lengthWithoutCounter = sb.length(); targetHashes += scale * normalizedWeight; long i = 0L; while (currentHashes < targetHashes) { @@ -170,7 +261,7 @@ private static List buildRing( ring.add(new RingEntry(hash, addrKey)); i++; currentHashes++; - sb.setLength(sb.length() - 1); + sb.setLength(lengthWithoutCounter); } } Collections.sort(ring); @@ -203,53 +294,77 @@ public void shutdown() { * TRANSIENT_FAILURE *

  • If there is at least one subchannel in CONNECTING state, overall state is * CONNECTING
  • + *
  • If there is one subchannel in TRANSIENT_FAILURE state and there is + * more than one subchannel, report CONNECTING
  • *
  • If there is at least one subchannel in IDLE state, overall state is IDLE
  • *
  • Otherwise, overall state is TRANSIENT_FAILURE
  • * */ private void updateBalancingState() { checkState(!subchannels.isEmpty(), "no subchannel has been created"); - int failureCount = 0; - boolean hasConnecting = false; - Subchannel idleSubchannel = null; - ConnectivityState overallState = null; + boolean startConnectionAttempt = false; + int numIdle = 0; + int numReady = 0; + int numConnecting = 0; + int numTransientFailure = 0; for (Subchannel subchannel : subchannels.values()) { ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState(); if (state == READY) { - overallState = READY; + numReady++; break; - } - if (state == TRANSIENT_FAILURE) { - failureCount++; - } else if (state == CONNECTING) { - hasConnecting = true; + } else if (state == TRANSIENT_FAILURE) { + numTransientFailure++; + } else if (state == CONNECTING ) { + numConnecting++; } else if (state == IDLE) { - if (idleSubchannel == null) { - idleSubchannel = subchannel; - } + numIdle++; } } - if (overallState == null) { - if (failureCount >= 2) { - // This load balancer may not get any pick requests from the upstream if it's reporting - // TRANSIENT_FAILURE. It needs to recover by itself by attempting to connect to at least - // one subchannel that has not failed at any given time. - if (!hasConnecting && idleSubchannel != null) { - idleSubchannel.requestConnection(); - } - overallState = TRANSIENT_FAILURE; - } else if (hasConnecting) { - overallState = CONNECTING; - } else if (idleSubchannel != null) { - overallState = IDLE; - } else { - overallState = TRANSIENT_FAILURE; - } + ConnectivityState overallState; + if (numReady > 0) { + overallState = READY; + } else if (numTransientFailure >= 2) { + overallState = TRANSIENT_FAILURE; + startConnectionAttempt = (numConnecting == 0); + } else if (numConnecting > 0) { + overallState = CONNECTING; + } else if (numTransientFailure == 1 && subchannels.size() > 1) { + overallState = CONNECTING; + startConnectionAttempt = true; + } else if (numIdle > 0) { + overallState = IDLE; + } else { + overallState = TRANSIENT_FAILURE; + startConnectionAttempt = true; } RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels); // TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates helper.updateBalancingState(overallState, picker); currentState = overallState; + // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will + // not be getting any pick requests from the priority policy. + // However, because the ring_hash policy does not attempt to + // reconnect to subchannels unless it is getting pick requests, + // it will need special handling to ensure that it will eventually + // recover from TRANSIENT_FAILURE state once the problem is resolved. + // Specifically, it will make sure that it is attempting to connect to + // at least one subchannel at any given time. After a given subchannel + // fails a connection attempt, it will move on to the next subchannel + // in the ring. It will keep doing this until one of the subchannels + // successfully connects, at which point it will report READY and stop + // proactively trying to connect. The policy will remain in + // TRANSIENT_FAILURE until at least one subchannel becomes connected, + // even if subchannels are in state CONNECTING during that time. + // + // Note that we do the same thing when the policy is in state + // CONNECTING, just to ensure that we don't remain in CONNECTING state + // indefinitely if there are no new picks coming in. + if (startConnectionAttempt) { + if (!connectionAttemptIterator.hasNext()) { + connectionAttemptIterator = subchannels.values().iterator(); + } + connectionAttemptIterator.next().requestConnection(); + } } private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { @@ -259,18 +374,22 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { helper.refreshNameResolution(); } - Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); + updateConnectivityState(subchannel, stateInfo); + updateBalancingState(); + } + private void updateConnectivityState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); + ConnectivityState previousConnectivityState = subchannelStateRef.value.getState(); // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected. // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in // TRANSIENT_FAILURE until it becomes READY. - if (subchannelStateRef.value.getState() == TRANSIENT_FAILURE) { + if (previousConnectivityState == TRANSIENT_FAILURE) { if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { return; } } subchannelStateRef.value = stateInfo; - updateBalancingState(); } private static void shutdownSubchannel(Subchannel subchannel) { @@ -359,10 +478,12 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { // Try finding a READY subchannel. Starting from the ring entry next to the RPC's hash. // If the one of the first two subchannels is not in TRANSIENT_FAILURE, return result // based on that subchannel. Otherwise, fail the pick unless a READY subchannel is found. - // Meanwhile, trigger connection for the first subchannel that is in IDLE if no subchannel - // before it is in CONNECTING or READY. - boolean hasPending = false; // true if having subchannel(s) in CONNECTING or IDLE - boolean canBuffer = true; // true if RPCs can be buffered with a pending subchannel + // Meanwhile, trigger connection for the channel and status: + // For the first subchannel that is in IDLE or TRANSIENT_FAILURE; + // And for the second subchannel that is in IDLE or TRANSIENT_FAILURE; + // And for each of the following subchannels that is in TRANSIENT_FAILURE or IDLE, + // stop until we find the first subchannel that is in CONNECTING or IDLE status. + boolean foundFirstNonFailed = false; // true if having subchannel(s) in CONNECTING or IDLE Subchannel firstSubchannel = null; Subchannel secondSubchannel = null; for (int i = 0; i < ring.size(); i++) { @@ -377,36 +498,50 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { // are failed unless there is a READY connection. if (firstSubchannel == null) { firstSubchannel = subchannel.subchannel; - } else if (subchannel.subchannel != firstSubchannel) { - if (secondSubchannel == null) { - secondSubchannel = subchannel.subchannel; - } else if (subchannel.subchannel != secondSubchannel) { - canBuffer = false; + PickResult maybeBuffer = pickSubchannelsNonReady(subchannel); + if (maybeBuffer != null) { + return maybeBuffer; } - } - if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE) { - continue; - } - if (!hasPending) { // first non-failing subchannel - if (subchannel.stateInfo.getState() == IDLE) { - final Subchannel finalSubchannel = subchannel.subchannel; - syncContext.execute(new Runnable() { - @Override - public void run() { - finalSubchannel.requestConnection(); - } - }); + } else if (subchannel.subchannel != firstSubchannel && secondSubchannel == null) { + secondSubchannel = subchannel.subchannel; + PickResult maybeBuffer = pickSubchannelsNonReady(subchannel); + if (maybeBuffer != null) { + return maybeBuffer; } - if (canBuffer) { // done if this is the first or second two subchannel - return PickResult.withNoResult(); // queue the pick and re-process later + } else if (subchannel.subchannel != firstSubchannel + && subchannel.subchannel != secondSubchannel) { + if (!foundFirstNonFailed) { + pickSubchannelsNonReady(subchannel); + if (subchannel.stateInfo.getState() != TRANSIENT_FAILURE) { + foundFirstNonFailed = true; + } } - hasPending = true; } } // Fail the pick with error status of the original subchannel hit by hash. SubchannelView originalSubchannel = pickableSubchannels.get(ring.get(mid).addrKey); return PickResult.withError(originalSubchannel.stateInfo.getStatus()); } + + @Nullable + private PickResult pickSubchannelsNonReady(SubchannelView subchannel) { + if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE + || subchannel.stateInfo.getState() == IDLE ) { + final Subchannel finalSubchannel = subchannel.subchannel; + syncContext.execute(new Runnable() { + @Override + public void run() { + finalSubchannel.requestConnection(); + } + }); + } + if (subchannel.stateInfo.getState() == CONNECTING + || subchannel.stateInfo.getState() == IDLE) { + return PickResult.withNoResult(); + } else { + return null; + } + } } /** diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index 0df4c8c2984..5bba0dc9b0a 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -26,6 +26,7 @@ import io.grpc.Status; import io.grpc.internal.JsonUtil; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; +import io.grpc.xds.RingHashOptions; import java.util.Map; /** @@ -39,11 +40,7 @@ public final class RingHashLoadBalancerProvider extends LoadBalancerProvider { static final long DEFAULT_MIN_RING_SIZE = 1024L; // Same as ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE @VisibleForTesting - static final long DEFAULT_MAX_RING_SIZE = 8 * 1024 * 1024L; - // Maximum number of ring entries allowed. Setting this too large can result in slow - // ring construction and OOM error. - // Same as ClientXdsClient.MAX_RING_HASH_LB_POLICY_RING_SIZE - static final long MAX_RING_SIZE = 8 * 1024 * 1024L; + static final long DEFAULT_MAX_RING_SIZE = 4 * 1024L; private static final boolean enableRingHash = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")) @@ -73,15 +70,21 @@ public String getPolicyName() { public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); + long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; } if (maxRingSize == null) { maxRingSize = DEFAULT_MAX_RING_SIZE; } - if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize - || maxRingSize > MAX_RING_SIZE) { - return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription( + if (minRingSize > maxRingSizeCap) { + minRingSize = maxRingSizeCap; + } + if (maxRingSize > maxRingSizeCap) { + maxRingSize = maxRingSizeCap; + } + if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize) { + return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( "Invalid 'mingRingSize'/'maxRingSize'")); } return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize)); diff --git a/xds/src/main/java/io/grpc/xds/RingHashOptions.java b/xds/src/main/java/io/grpc/xds/RingHashOptions.java new file mode 100644 index 00000000000..6bb3fc7887e --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/RingHashOptions.java @@ -0,0 +1,62 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.ExperimentalApi; + +/** + * Utility class that provides a way to configure ring hash size limits. This is applicable + * for clients that use the ring hash load balancing policy. Note that size limits involve + * a tradeoff between client memory consumption and accuracy of load balancing weight + * representations. Also see https://github.com/grpc/proposal/pull/338. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9718") +public final class RingHashOptions { + // Same as ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE + @VisibleForTesting + static final long MAX_RING_SIZE_CAP = 8 * 1024 * 1024L; + @VisibleForTesting + // Same as RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE + static final long DEFAULT_RING_SIZE_CAP = 4 * 1024L; + + // Limits ring hash sizes to restrict client memory usage. + private static volatile long ringSizeCap = DEFAULT_RING_SIZE_CAP; + + private RingHashOptions() {} // Prevent instantiation + + /** + * Set the global limit for the min and max number of ring hash entries per ring. + * Note that this limit is clamped between 1 entry and 8,388,608 entries, and new + * limits lying outside that range will be silently moved to the nearest number within + * that range. Defaults initially to 4096 entries. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9718") + public static void setRingSizeCap(long ringSizeCap) { + ringSizeCap = Math.max(1, ringSizeCap); + ringSizeCap = Math.min(MAX_RING_SIZE_CAP, ringSizeCap); + RingHashOptions.ringSizeCap = ringSizeCap; + } + + /** + * Get the global limit for min and max ring hash sizes. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9718") + public static long getRingSizeCap() { + return RingHashOptions.ringSizeCap; + } +} diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 1c8fe0bad6d..5aabd976085 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -26,9 +26,9 @@ import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.TimeProvider; import io.grpc.xds.Bootstrapper.BootstrapInfo; -import io.grpc.xds.ClientXdsClient.XdsChannelFactory; +import io.grpc.xds.XdsClientImpl.XdsChannelFactory; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; -import io.grpc.xds.internal.sds.TlsContextManagerImpl; +import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicReference; @@ -123,7 +123,7 @@ public XdsClient getObject() { synchronized (lock) { if (refCount == 0) { scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - xdsClient = new ClientXdsClient( + xdsClient = new XdsClientImpl( XdsChannelFactory.DEFAULT_XDS_CHANNEL_FACTORY, bootstrapInfo, context, diff --git a/xds/src/main/java/io/grpc/xds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/TlsContextManager.java index e35eb68f219..772a6cff102 100644 --- a/xds/src/main/java/io/grpc/xds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/TlsContextManager.java @@ -19,7 +19,7 @@ import io.grpc.Internal; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.SslContextProvider; +import io.grpc.xds.internal.security.SslContextProvider; @Internal public interface TlsContextManager { diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java index ee8c0308fce..825e4a8eca0 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -28,7 +28,6 @@ import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.Status; -import io.grpc.SynchronizationContext; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; @@ -49,20 +48,29 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { private final Map childBalancers = new HashMap<>(); private final Map childHelpers = new HashMap<>(); private final Helper helper; - private final SynchronizationContext syncContext; private Map targets = ImmutableMap.of(); + // Set to true if currently in the process of handling resolved addresses. + private boolean resolvingAddresses; WeightedTargetLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); logger = XdsLogger.withLogId( InternalLogId.allocate("weighted-target-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + try { + resolvingAddresses = true; + return acceptResolvedAddressesInternal(resolvedAddresses); + } finally { + resolvingAddresses = false; + } + } + + public boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbConfig, "missing weighted_target lb config"); @@ -101,6 +109,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { childBalancers.keySet().retainAll(targets.keySet()); childHelpers.keySet().retainAll(targets.keySet()); updateOverallBalancingState(); + return true; } @Override @@ -191,17 +200,14 @@ private ChildHelper(String name) { @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!childBalancers.containsKey(name)) { - return; - } - currentState = newState; - currentPicker = newPicker; - updateOverallBalancingState(); - } - }); + currentState = newState; + currentPicker = newPicker; + + // If we are already in the process of resolving addresses, the overall balancing state + // will be updated at the end of it, and we don't need to trigger that update here. + if (!resolvingAddresses && childBalancers.containsKey(name)) { + updateOverallBalancingState(); + } } @Override diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java index 1ac3aa13be1..c6a0893db02 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java @@ -117,7 +117,7 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { return ConfigOrError.fromConfig(new WeightedTargetConfig(parsedChildConfigs)); } catch (RuntimeException e) { return ConfigOrError.fromError( - Status.fromThrowable(e).withDescription( + Status.INTERNAL.withCause(e).withDescription( "Failed to parse weighted_target LB config: " + rawConfig)); } } diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java new file mode 100644 index 00000000000..b9196492624 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -0,0 +1,164 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; + +import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.InternalLogId; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Status; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; +import io.grpc.xds.XdsLogger.XdsLogLevel; +import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * This load balancer acts as a parent for the {@link WeightedTargetLoadBalancer} and configures + * it with a child policy in its configuration and locality weights it gets from an attribute in + * {@link io.grpc.LoadBalancer.ResolvedAddresses}. + */ +final class WrrLocalityLoadBalancer extends LoadBalancer { + + private final XdsLogger logger; + private final Helper helper; + private final GracefulSwitchLoadBalancer switchLb; + private final LoadBalancerRegistry lbRegistry; + + WrrLocalityLoadBalancer(Helper helper) { + this(helper, LoadBalancerRegistry.getDefaultRegistry()); + } + + WrrLocalityLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) { + this.helper = checkNotNull(helper, "helper"); + this.lbRegistry = lbRegistry; + switchLb = new GracefulSwitchLoadBalancer(helper); + logger = XdsLogger.withLogId( + InternalLogId.allocate("xds-wrr-locality-lb", helper.getAuthority())); + logger.log(XdsLogLevel.INFO, "Created"); + } + + @Override + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); + + // The configuration with the child policy is combined with the locality weights + // to produce the weighted target LB config. + WrrLocalityConfig wrrLocalityConfig + = (WrrLocalityConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + + // A map of locality weights is built up from the locality weight attributes in each address. + Map localityWeights = new HashMap<>(); + for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { + Attributes eagAttrs = eag.getAttributes(); + Locality locality = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY); + Integer localityWeight = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT); + + if (locality == null) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker( + Status.UNAVAILABLE.withDescription("wrr_locality error: no locality provided"))); + return false; + } + if (localityWeight == null) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker( + Status.UNAVAILABLE.withDescription( + "wrr_locality error: no weight provided for locality " + locality))); + return false; + } + + if (!localityWeights.containsKey(locality)) { + localityWeights.put(locality, localityWeight); + } else if (!localityWeights.get(locality).equals(localityWeight)) { + logger.log(XdsLogLevel.WARNING, + "Locality {0} has both weights {1} and {2}, using weight {1}", locality, + localityWeights.get(locality), localityWeight); + } + } + + // Weighted target LB expects a WeightedPolicySelection for each locality as it will create a + // child LB for each. + Map weightedPolicySelections = new HashMap<>(); + for (Locality locality : localityWeights.keySet()) { + weightedPolicySelections.put(locality.toString(), + new WeightedPolicySelection(localityWeights.get(locality), + wrrLocalityConfig.childPolicy)); + } + + switchLb.switchTo(lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME)); + switchLb.handleResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections)) + .build()); + + return true; + } + + @Override + public void handleNameResolutionError(Status error) { + logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); + switchLb.handleNameResolutionError(error); + } + + @Override + public void shutdown() { + switchLb.shutdown(); + } + + /** + * The LB config for {@link WrrLocalityLoadBalancer}. + */ + static final class WrrLocalityConfig { + + final PolicySelection childPolicy; + + WrrLocalityConfig(PolicySelection childPolicy) { + this.childPolicy = childPolicy; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WrrLocalityConfig that = (WrrLocalityConfig) o; + return Objects.equals(childPolicy, that.childPolicy); + } + + @Override + public int hashCode() { + return Objects.hashCode(childPolicy); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("childPolicy", childPolicy).toString(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java new file mode 100644 index 00000000000..31a4e128140 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java @@ -0,0 +1,85 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; +import java.util.List; +import java.util.Map; + +/** + * The provider for {@link WrrLocalityLoadBalancer}. An instance of this class should be acquired + * through {@link LoadBalancerRegistry#getProvider} by using the name + * "xds_wrr_locality_experimental". + */ +@Internal +public final class WrrLocalityLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new WrrLocalityLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return XdsLbPolicies.WRR_LOCALITY_POLICY_NAME; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + try { + List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(rawConfig, "childPolicy")); + if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "No child policy in wrr_locality LB policy: " + + rawConfig)); + } + ConfigOrError selectedConfig = + ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, + LoadBalancerRegistry.getDefaultRegistry()); + if (selectedConfig.getError() != null) { + return selectedConfig; + } + PolicySelection policySelection = (PolicySelection) selectedConfig.getConfig(); + return ConfigOrError.fromConfig(new WrrLocalityConfig(policySelection)); + } catch (RuntimeException e) { + return ConfigOrError.fromError(Status.INTERNAL.withCause(e) + .withDescription("Failed to parse wrr_locality LB config: " + rawConfig)); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java b/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java index d07f10555c6..189ebc0ca19 100644 --- a/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java +++ b/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java @@ -22,7 +22,7 @@ import io.grpc.ExperimentalApi; import io.grpc.netty.InternalNettyChannelCredentials; import io.grpc.netty.InternalProtocolNegotiator; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7514") public class XdsChannelCredentials { @@ -40,6 +40,6 @@ public static ChannelCredentials create(ChannelCredentials fallback) { InternalProtocolNegotiator.ClientFactory fallbackNegotiator = InternalNettyChannelCredentials.toNegotiator(checkNotNull(fallback, "fallback")); return InternalNettyChannelCredentials.create( - SdsProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackNegotiator)); + SecurityProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackNegotiator)); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java index 1c231b83a19..591c4d7f339 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClient.java @@ -19,21 +19,13 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.Bootstrapper.XDSTP_SCHEME; -import com.google.auto.value.AutoValue; import com.google.common.base.Joiner; -import com.google.common.base.MoreObjects; import com.google.common.base.Splitter; -import com.google.common.collect.ImmutableList; import com.google.common.net.UrlEscapers; import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.Any; import io.grpc.Status; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.Endpoints.DropOverload; -import io.grpc.xds.Endpoints.LocalityLbEndpoints; -import io.grpc.xds.EnvoyServerProtoData.Listener; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.LoadStatsManager2.ClusterLocalityStats; import java.net.URI; @@ -41,10 +33,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import javax.annotation.Nullable; /** @@ -116,288 +106,23 @@ static String percentEncodePath(String input) { return Joiner.on('/').join(encodedSegs); } - @AutoValue - abstract static class LdsUpdate implements ResourceUpdate { - // Http level api listener configuration. - @Nullable - abstract HttpConnectionManager httpConnectionManager(); - - // Tcp level listener configuration. - @Nullable - abstract Listener listener(); - - static LdsUpdate forApiListener(HttpConnectionManager httpConnectionManager) { - checkNotNull(httpConnectionManager, "httpConnectionManager"); - return new AutoValue_XdsClient_LdsUpdate(httpConnectionManager, null); - } - - static LdsUpdate forTcpListener(Listener listener) { - checkNotNull(listener, "listener"); - return new AutoValue_XdsClient_LdsUpdate(null, listener); - } - } - - static final class RdsUpdate implements ResourceUpdate { - // The list virtual hosts that make up the route table. - final List virtualHosts; - - RdsUpdate(List virtualHosts) { - this.virtualHosts = Collections.unmodifiableList( - new ArrayList<>(checkNotNull(virtualHosts, "virtualHosts"))); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("virtualHosts", virtualHosts) - .toString(); - } - - @Override - public int hashCode() { - return Objects.hash(virtualHosts); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - RdsUpdate that = (RdsUpdate) o; - return Objects.equals(virtualHosts, that.virtualHosts); - } - } - - /** xDS resource update for cluster-level configuration. */ - @AutoValue - abstract static class CdsUpdate implements ResourceUpdate { - abstract String clusterName(); - - abstract ClusterType clusterType(); - - // Endpoint-level load balancing policy. - abstract LbPolicy lbPolicy(); - - // Only valid if lbPolicy is "ring_hash_experimental". - abstract long minRingSize(); - - // Only valid if lbPolicy is "ring_hash_experimental". - abstract long maxRingSize(); - - // Only valid if lbPolicy is "least_request_experimental". - abstract int choiceCount(); - - // Alternative resource name to be used in EDS requests. - /// Only valid for EDS cluster. - @Nullable - abstract String edsServiceName(); - - // Corresponding DNS name to be used if upstream endpoints of the cluster is resolvable - // via DNS. - // Only valid for LOGICAL_DNS cluster. - @Nullable - abstract String dnsHostName(); - - // Load report server info for reporting loads via LRS. - // Only valid for EDS or LOGICAL_DNS cluster. - @Nullable - abstract ServerInfo lrsServerInfo(); - - // Max number of concurrent requests can be sent to this cluster. - // Only valid for EDS or LOGICAL_DNS cluster. - @Nullable - abstract Long maxConcurrentRequests(); - - // TLS context used to connect to connect to this cluster. - // Only valid for EDS or LOGICAL_DNS cluster. - @Nullable - abstract UpstreamTlsContext upstreamTlsContext(); - - // List of underlying clusters making of this aggregate cluster. - // Only valid for AGGREGATE cluster. - @Nullable - abstract ImmutableList prioritizedClusterNames(); - - static Builder forAggregate(String clusterName, List prioritizedClusterNames) { - checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); - return new AutoValue_XdsClient_CdsUpdate.Builder() - .clusterName(clusterName) - .clusterType(ClusterType.AGGREGATE) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) - .prioritizedClusterNames(ImmutableList.copyOf(prioritizedClusterNames)); - } - - static Builder forEds(String clusterName, @Nullable String edsServiceName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext upstreamTlsContext) { - return new AutoValue_XdsClient_CdsUpdate.Builder() - .clusterName(clusterName) - .clusterType(ClusterType.EDS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) - .edsServiceName(edsServiceName) - .lrsServerInfo(lrsServerInfo) - .maxConcurrentRequests(maxConcurrentRequests) - .upstreamTlsContext(upstreamTlsContext); - } - - static Builder forLogicalDns(String clusterName, String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext upstreamTlsContext) { - return new AutoValue_XdsClient_CdsUpdate.Builder() - .clusterName(clusterName) - .clusterType(ClusterType.LOGICAL_DNS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) - .dnsHostName(dnsHostName) - .lrsServerInfo(lrsServerInfo) - .maxConcurrentRequests(maxConcurrentRequests) - .upstreamTlsContext(upstreamTlsContext); - } - - enum ClusterType { - EDS, LOGICAL_DNS, AGGREGATE - } - - enum LbPolicy { - ROUND_ROBIN, RING_HASH, LEAST_REQUEST - } - - // FIXME(chengyuanzhang): delete this after UpstreamTlsContext's toString() is fixed. - @Override - public final String toString() { - return MoreObjects.toStringHelper(this) - .add("clusterName", clusterName()) - .add("clusterType", clusterType()) - .add("lbPolicy", lbPolicy()) - .add("minRingSize", minRingSize()) - .add("maxRingSize", maxRingSize()) - .add("choiceCount", choiceCount()) - .add("edsServiceName", edsServiceName()) - .add("dnsHostName", dnsHostName()) - .add("lrsServerInfo", lrsServerInfo()) - .add("maxConcurrentRequests", maxConcurrentRequests()) - // Exclude upstreamTlsContext as its string representation is cumbersome. - .add("prioritizedClusterNames", prioritizedClusterNames()) - .toString(); - } - - @AutoValue.Builder - abstract static class Builder { - // Private, use one of the static factory methods instead. - protected abstract Builder clusterName(String clusterName); - - // Private, use one of the static factory methods instead. - protected abstract Builder clusterType(ClusterType clusterType); - - // Private, use roundRobinLbPolicy() or ringHashLbPolicy(long, long). - protected abstract Builder lbPolicy(LbPolicy lbPolicy); - - Builder roundRobinLbPolicy() { - return this.lbPolicy(LbPolicy.ROUND_ROBIN); - } - - Builder ringHashLbPolicy(long minRingSize, long maxRingSize) { - return this.lbPolicy(LbPolicy.RING_HASH).minRingSize(minRingSize).maxRingSize(maxRingSize); - } - - Builder leastRequestLbPolicy(int choiceCount) { - return this.lbPolicy(LbPolicy.LEAST_REQUEST).choiceCount(choiceCount); - } - - // Private, use leastRequestLbPolicy(int). - protected abstract Builder choiceCount(int choiceCount); - - // Private, use ringHashLbPolicy(long, long). - protected abstract Builder minRingSize(long minRingSize); - - // Private, use ringHashLbPolicy(long, long). - protected abstract Builder maxRingSize(long maxRingSize); - - // Private, use CdsUpdate.forEds() instead. - protected abstract Builder edsServiceName(String edsServiceName); - - // Private, use CdsUpdate.forLogicalDns() instead. - protected abstract Builder dnsHostName(String dnsHostName); - - // Private, use one of the static factory methods instead. - protected abstract Builder lrsServerInfo(ServerInfo lrsServerInfo); - - // Private, use one of the static factory methods instead. - protected abstract Builder maxConcurrentRequests(Long maxConcurrentRequests); - - // Private, use one of the static factory methods instead. - protected abstract Builder upstreamTlsContext(UpstreamTlsContext upstreamTlsContext); - - // Private, use CdsUpdate.forAggregate() instead. - protected abstract Builder prioritizedClusterNames(List prioritizedClusterNames); - - abstract CdsUpdate build(); - } - } - - static final class EdsUpdate implements ResourceUpdate { - final String clusterName; - final Map localityLbEndpointsMap; - final List dropPolicies; - - EdsUpdate(String clusterName, Map localityLbEndpoints, - List dropPolicies) { - this.clusterName = checkNotNull(clusterName, "clusterName"); - this.localityLbEndpointsMap = Collections.unmodifiableMap( - new LinkedHashMap<>(checkNotNull(localityLbEndpoints, "localityLbEndpoints"))); - this.dropPolicies = Collections.unmodifiableList( - new ArrayList<>(checkNotNull(dropPolicies, "dropPolicies"))); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - EdsUpdate that = (EdsUpdate) o; - return Objects.equals(clusterName, that.clusterName) - && Objects.equals(localityLbEndpointsMap, that.localityLbEndpointsMap) - && Objects.equals(dropPolicies, that.dropPolicies); - } - - @Override - public int hashCode() { - return Objects.hash(clusterName, localityLbEndpointsMap, dropPolicies); - } - - @Override - public String toString() { - return - MoreObjects - .toStringHelper(this) - .add("clusterName", clusterName) - .add("localityLbEndpointsMap", localityLbEndpointsMap) - .add("dropPolicies", dropPolicies) - .toString(); - } - } - interface ResourceUpdate { } /** * Watcher interface for a single requested xDS resource. */ - interface ResourceWatcher { + interface ResourceWatcher { /** * Called when the resource discovery RPC encounters some transient error. + * + *

    Note that we expect that the implementer to: + * - Comply with the guarantee to not generate certain statuses by the library: + * https://grpc.github.io/grpc/core/md_doc_statuscodes.html. If the code needs to be + * propagated to the channel, override it with {@link Status.Code#UNAVAILABLE}. + * - Keep {@link Status} description in one form or another, as it contains valuable debugging + * information. */ void onError(Status error); @@ -407,22 +132,8 @@ interface ResourceWatcher { * @param resourceName name of the resource requested in discovery request. */ void onResourceDoesNotExist(String resourceName); - } - - interface LdsResourceWatcher extends ResourceWatcher { - void onChanged(LdsUpdate update); - } - - interface RdsResourceWatcher extends ResourceWatcher { - void onChanged(RdsUpdate update); - } - interface CdsResourceWatcher extends ResourceWatcher { - void onChanged(CdsUpdate update); - } - - interface EdsResourceWatcher extends ResourceWatcher { - void onChanged(EdsUpdate update); + void onChanged(T update); } /** @@ -584,64 +295,25 @@ TlsContextManager getTlsContextManager() { * a map ("resource name": "resource metadata"). */ // Must be synchronized. - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { throw new UnsupportedOperationException(); } /** - * Registers a data watcher for the given LDS resource. - */ - void watchLdsResource(String resourceName, LdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Unregisters the given LDS resource watcher. - */ - void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Registers a data watcher for the given RDS resource. - */ - void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Unregisters the given RDS resource watcher. - */ - void cancelRdsResourceWatch(String resourceName, RdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Registers a data watcher for the given CDS resource. - */ - void watchCdsResource(String resourceName, CdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Unregisters the given CDS resource watcher. - */ - void cancelCdsResourceWatch(String resourceName, CdsResourceWatcher watcher) { - throw new UnsupportedOperationException(); - } - - /** - * Registers a data watcher for the given EDS resource. + * Registers a data watcher for the given Xds resource. */ - void watchEdsResource(String resourceName, EdsResourceWatcher watcher) { + void watchXdsResource(XdsResourceType type, String resourceName, + ResourceWatcher watcher) { throw new UnsupportedOperationException(); } /** - * Unregisters the given EDS resource watcher. + * Unregisters the given resource watcher. */ - void cancelEdsResourceWatch(String resourceName, EdsResourceWatcher watcher) { + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { throw new UnsupportedOperationException(); } @@ -672,21 +344,10 @@ ClusterLocalityStats addClusterLocalityStats( } interface XdsResponseHandler { - /** Called when an LDS response is received. */ - void handleLdsResponse( - ServerInfo serverInfo, String versionInfo, List resources, String nonce); - - /** Called when an RDS response is received. */ - void handleRdsResponse( - ServerInfo serverInfo, String versionInfo, List resources, String nonce); - - /** Called when an CDS response is received. */ - void handleCdsResponse( - ServerInfo serverInfo, String versionInfo, List resources, String nonce); - - /** Called when an EDS response is received. */ - void handleEdsResponse( - ServerInfo serverInfo, String versionInfo, List resources, String nonce); + /** Called when a xds response is received. */ + void handleResourceResponse( + XdsResourceType resourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce); /** Called when the ADS stream is closed passively. */ // Must be synchronized. @@ -707,6 +368,17 @@ interface ResourceStore { */ // Must be synchronized. @Nullable - Collection getSubscribedResources(ServerInfo serverInfo, ResourceType type); + Collection getSubscribedResources(ServerInfo serverInfo, + XdsResourceType type); + + Map> getSubscribedResourceTypesWithTypeUrl(); + } + + interface TimerLaunch { + /** + * For all subscriber's for the specified server, if the resource hasn't yet been + * resolved then start a timer for it. + */ + void startSubscriberTimersIfNeeded(ServerInfo serverInfo); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsClientImpl.java b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java new file mode 100644 index 00000000000..e67bff12871 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java @@ -0,0 +1,740 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.Bootstrapper.XDSTP_SCHEME; +import static io.grpc.xds.XdsResourceType.ParsedResource; +import static io.grpc.xds.XdsResourceType.ValidatedResourceUpdate; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.Any; +import io.grpc.ChannelCredentials; +import io.grpc.Context; +import io.grpc.Grpc; +import io.grpc.InternalLogId; +import io.grpc.LoadBalancerRegistry; +import io.grpc.ManagedChannel; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.TimeProvider; +import io.grpc.xds.Bootstrapper.AuthorityInfo; +import io.grpc.xds.Bootstrapper.ServerInfo; +import io.grpc.xds.LoadStatsManager2.ClusterDropStats; +import io.grpc.xds.LoadStatsManager2.ClusterLocalityStats; +import io.grpc.xds.XdsClient.ResourceStore; +import io.grpc.xds.XdsClient.TimerLaunch; +import io.grpc.xds.XdsClient.XdsResponseHandler; +import io.grpc.xds.XdsLogger.XdsLogLevel; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * XdsClient implementation. + */ +final class XdsClientImpl extends XdsClient + implements XdsResponseHandler, ResourceStore, TimerLaunch { + + private static boolean LOG_XDS_NODE_ID = Boolean.parseBoolean( + System.getenv("GRPC_LOG_XDS_NODE_ID")); + private static final Logger classLogger = Logger.getLogger(XdsClientImpl.class.getName()); + + // Longest time to wait, since the subscription to some resource, for concluding its absence. + @VisibleForTesting + static final int INITIAL_RESOURCE_FETCH_TIMEOUT_SEC = 15; + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log( + XdsLogLevel.ERROR, + "Uncaught exception in XdsClient SynchronizationContext. Panic!", + e); + // TODO(chengyuanzhang): better error handling. + throw new AssertionError(e); + } + }); + private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); + private final LoadBalancerRegistry loadBalancerRegistry + = LoadBalancerRegistry.getDefaultRegistry(); + private final Map serverChannelMap = new HashMap<>(); + private final Map, + Map>> + resourceSubscribers = new HashMap<>(); + private final Map> subscribedResourceTypeUrls = new HashMap<>(); + private final LoadStatsManager2 loadStatsManager; + private final Map serverLrsClientMap = new HashMap<>(); + private final XdsChannelFactory xdsChannelFactory; + private final Bootstrapper.BootstrapInfo bootstrapInfo; + private final Context context; + private final ScheduledExecutorService timeService; + private final BackoffPolicy.Provider backoffPolicyProvider; + private final Supplier stopwatchSupplier; + private final TimeProvider timeProvider; + private boolean reportingLoad; + private final TlsContextManager tlsContextManager; + private final InternalLogId logId; + private final XdsLogger logger; + private volatile boolean isShutdown; + + XdsClientImpl( + XdsChannelFactory xdsChannelFactory, + Bootstrapper.BootstrapInfo bootstrapInfo, + Context context, + ScheduledExecutorService timeService, + BackoffPolicy.Provider backoffPolicyProvider, + Supplier stopwatchSupplier, + TimeProvider timeProvider, + TlsContextManager tlsContextManager) { + this.xdsChannelFactory = xdsChannelFactory; + this.bootstrapInfo = bootstrapInfo; + this.context = context; + this.timeService = timeService; + loadStatsManager = new LoadStatsManager2(stopwatchSupplier); + this.backoffPolicyProvider = backoffPolicyProvider; + this.stopwatchSupplier = stopwatchSupplier; + this.timeProvider = timeProvider; + this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); + logId = InternalLogId.allocate("xds-client", null); + logger = XdsLogger.withLogId(logId); + logger.log(XdsLogLevel.INFO, "Created"); + if (LOG_XDS_NODE_ID) { + classLogger.log(Level.INFO, "xDS node ID: {0}", bootstrapInfo.node().getId()); + } + } + + private void maybeCreateXdsChannelWithLrs(ServerInfo serverInfo) { + syncContext.throwIfNotInThisSynchronizationContext(); + if (serverChannelMap.containsKey(serverInfo)) { + return; + } + AbstractXdsClient xdsChannel = new AbstractXdsClient( + xdsChannelFactory, + serverInfo, + bootstrapInfo.node(), + this, + this, + context, + timeService, + syncContext, + backoffPolicyProvider, + stopwatchSupplier, + this); + LoadReportClient lrsClient = new LoadReportClient( + loadStatsManager, xdsChannel.channel(), context, bootstrapInfo.node(), syncContext, + timeService, backoffPolicyProvider, stopwatchSupplier); + serverChannelMap.put(serverInfo, xdsChannel); + serverLrsClientMap.put(serverInfo, lrsClient); + } + + @Override + public void handleResourceResponse( + XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce) { + checkNotNull(xdsResourceType, "xdsResourceType"); + syncContext.throwIfNotInThisSynchronizationContext(); + Set toParseResourceNames = null; + if (!(xdsResourceType == XdsListenerResource.getInstance() + || xdsResourceType == XdsRouteConfigureResource.getInstance()) + && resourceSubscribers.containsKey(xdsResourceType)) { + toParseResourceNames = resourceSubscribers.get(xdsResourceType).keySet(); + } + XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, versionInfo, nonce, + bootstrapInfo, filterRegistry, loadBalancerRegistry, tlsContextManager, + toParseResourceNames); + handleResourceUpdate(args, resources, xdsResourceType); + } + + @Override + public void handleStreamClosed(Status error) { + syncContext.throwIfNotInThisSynchronizationContext(); + cleanUpResourceTimers(); + for (Map> subscriberMap : + resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + if (!subscriber.hasResult()) { + subscriber.onError(error); + } + } + } + } + + @Override + public void handleStreamRestarted(ServerInfo serverInfo) { + syncContext.throwIfNotInThisSynchronizationContext(); + for (Map> subscriberMap : + resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + if (subscriber.serverInfo.equals(serverInfo)) { + subscriber.restartTimer(); + } + } + } + } + + @Override + void shutdown() { + syncContext.execute( + new Runnable() { + @Override + public void run() { + if (isShutdown) { + return; + } + isShutdown = true; + for (AbstractXdsClient xdsChannel : serverChannelMap.values()) { + xdsChannel.shutdown(); + } + if (reportingLoad) { + for (final LoadReportClient lrsClient : serverLrsClientMap.values()) { + lrsClient.stopLoadReporting(); + } + } + cleanUpResourceTimers(); + } + }); + } + + @Override + boolean isShutDown() { + return isShutdown; + } + + @Override + public Map> getSubscribedResourceTypesWithTypeUrl() { + return Collections.unmodifiableMap(subscribedResourceTypeUrls); + } + + @Nullable + @Override + public Collection getSubscribedResources(ServerInfo serverInfo, + XdsResourceType type) { + Map> resources = + resourceSubscribers.getOrDefault(type, Collections.emptyMap()); + ImmutableSet.Builder builder = ImmutableSet.builder(); + for (String key : resources.keySet()) { + if (resources.get(key).serverInfo.equals(serverInfo)) { + builder.add(key); + } + } + Collection retVal = builder.build(); + return retVal.isEmpty() ? null : retVal; + } + + // As XdsClient APIs becomes resource agnostic, subscribed resource types are dynamic. + // ResourceTypes that do not have subscribers does not show up in the snapshot keys. + @Override + ListenableFuture, Map>> + getSubscribedResourcesMetadataSnapshot() { + final SettableFuture, Map>> future = + SettableFuture.create(); + syncContext.execute(new Runnable() { + @Override + public void run() { + // A map from a "resource type" to a map ("resource name": "resource metadata") + ImmutableMap.Builder, Map> metadataSnapshot = + ImmutableMap.builder(); + for (XdsResourceType resourceType: resourceSubscribers.keySet()) { + ImmutableMap.Builder metadataMap = ImmutableMap.builder(); + for (Map.Entry> resourceEntry + : resourceSubscribers.get(resourceType).entrySet()) { + metadataMap.put(resourceEntry.getKey(), resourceEntry.getValue().metadata); + } + metadataSnapshot.put(resourceType, metadataMap.buildOrThrow()); + } + future.set(metadataSnapshot.buildOrThrow()); + } + }); + return future; + } + + @Override + TlsContextManager getTlsContextManager() { + return tlsContextManager; + } + + @Override + void watchXdsResource(XdsResourceType type, String resourceName, + ResourceWatcher watcher) { + syncContext.execute(new Runnable() { + @Override + @SuppressWarnings("unchecked") + public void run() { + if (!resourceSubscribers.containsKey(type)) { + resourceSubscribers.put(type, new HashMap<>()); + subscribedResourceTypeUrls.put(type.typeUrl(), type); + } + ResourceSubscriber subscriber = + (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName);; + if (subscriber == null) { + logger.log(XdsLogLevel.INFO, "Subscribe {0} resource {1}", type, resourceName); + subscriber = new ResourceSubscriber<>(type, resourceName); + resourceSubscribers.get(type).put(resourceName, subscriber); + if (subscriber.xdsChannel != null) { + subscriber.xdsChannel.adjustResourceSubscription(type); + } + } + subscriber.addWatcher(watcher); + } + }); + } + + @Override + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { + syncContext.execute(new Runnable() { + @Override + @SuppressWarnings("unchecked") + public void run() { + ResourceSubscriber subscriber = + (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName);; + subscriber.removeWatcher(watcher); + if (!subscriber.isWatched()) { + subscriber.cancelResourceWatch(); + resourceSubscribers.get(type).remove(resourceName); + if (subscriber.xdsChannel != null) { + subscriber.xdsChannel.adjustResourceSubscription(type); + } + if (resourceSubscribers.get(type).isEmpty()) { + resourceSubscribers.remove(type); + subscribedResourceTypeUrls.remove(type.typeUrl()); + + } + } + } + }); + } + + @Override + ClusterDropStats addClusterDropStats( + final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName) { + ClusterDropStats dropCounter = + loadStatsManager.getClusterDropStats(clusterName, edsServiceName); + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!reportingLoad) { + serverLrsClientMap.get(serverInfo).startLoadReporting(); + reportingLoad = true; + } + } + }); + return dropCounter; + } + + @Override + ClusterLocalityStats addClusterLocalityStats( + final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, + Locality locality) { + ClusterLocalityStats loadCounter = + loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!reportingLoad) { + serverLrsClientMap.get(serverInfo).startLoadReporting(); + reportingLoad = true; + } + } + }); + return loadCounter; + } + + @Override + Bootstrapper.BootstrapInfo getBootstrapInfo() { + return bootstrapInfo; + } + + @Override + public String toString() { + return logId.toString(); + } + + @Override + public void startSubscriberTimersIfNeeded(ServerInfo serverInfo) { + if (isShutDown()) { + return; + } + + syncContext.execute(new Runnable() { + @Override + public void run() { + if (isShutDown()) { + return; + } + + for (Map> subscriberMap : resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + if (subscriber.serverInfo.equals(serverInfo) && subscriber.respTimer == null) { + subscriber.restartTimer(); + } + } + } + } + }); + } + + private void cleanUpResourceTimers() { + for (Map> subscriberMap : resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + subscriber.stopTimer(); + } + } + } + + @SuppressWarnings("unchecked") + private void handleResourceUpdate(XdsResourceType.Args args, + List resources, + XdsResourceType xdsResourceType) { + ValidatedResourceUpdate result = xdsResourceType.parse(args, resources); + logger.log(XdsLogger.XdsLogLevel.INFO, + "Received {0} Response version {1} nonce {2}. Parsed resources: {3}", + xdsResourceType.typeName(), args.versionInfo, args.nonce, result.unpackedResources); + Map> parsedResources = result.parsedResources; + Set invalidResources = result.invalidResources; + List errors = result.errors; + String errorDetail = null; + if (errors.isEmpty()) { + checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors"); + serverChannelMap.get(args.serverInfo).ackResponse(xdsResourceType, args.versionInfo, + args.nonce); + } else { + errorDetail = Joiner.on('\n').join(errors); + logger.log(XdsLogLevel.WARNING, + "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", + xdsResourceType.typeName(), args.versionInfo, args.nonce, errorDetail); + serverChannelMap.get(args.serverInfo).nackResponse(xdsResourceType, args.nonce, errorDetail); + } + + long updateTime = timeProvider.currentTimeNanos(); + Map> subscribedResources = + resourceSubscribers.getOrDefault(xdsResourceType, Collections.emptyMap()); + for (Map.Entry> entry : subscribedResources.entrySet()) { + String resourceName = entry.getKey(); + ResourceSubscriber subscriber = (ResourceSubscriber) entry.getValue(); + + if (parsedResources.containsKey(resourceName)) { + // Happy path: the resource updated successfully. Notify the watchers of the update. + subscriber.onData(parsedResources.get(resourceName), args.versionInfo, updateTime); + continue; + } + + if (invalidResources.contains(resourceName)) { + // The resource update is invalid. Capture the error without notifying the watchers. + subscriber.onRejected(args.versionInfo, updateTime, errorDetail); + } + + // Nothing else to do for incremental ADS resources. + if (!xdsResourceType.isFullStateOfTheWorld()) { + continue; + } + + // Handle State of the World ADS: invalid resources. + if (invalidResources.contains(resourceName)) { + // The resource is missing. Reuse the cached resource if possible. + if (subscriber.data == null) { + // No cached data. Notify the watchers of an invalid update. + subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail)); + } + continue; + } + + // For State of the World services, notify watchers when their watched resource is missing + // from the ADS update. + subscriber.onAbsent(); + } + } + + /** + * Tracks a single subscribed resource. + */ + private final class ResourceSubscriber { + @Nullable private final ServerInfo serverInfo; + @Nullable private final AbstractXdsClient xdsChannel; + private final XdsResourceType type; + private final String resource; + private final Set> watchers = new HashSet<>(); + @Nullable private T data; + private boolean absent; + // Tracks whether the deletion has been ignored per bootstrap server feature. + // See https://github.com/grpc/proposal/blob/master/A53-xds-ignore-resource-deletion.md + private boolean resourceDeletionIgnored; + @Nullable private ScheduledHandle respTimer; + @Nullable private ResourceMetadata metadata; + @Nullable private String errorDescription; + + ResourceSubscriber(XdsResourceType type, String resource) { + syncContext.throwIfNotInThisSynchronizationContext(); + this.type = type; + this.resource = resource; + this.serverInfo = getServerInfo(resource); + if (serverInfo == null) { + this.errorDescription = "Wrong configuration: xds server does not exist for resource " + + resource; + this.xdsChannel = null; + return; + } + // Initialize metadata in UNKNOWN state to cover the case when resource subscriber, + // is created but not yet requested because the client is in backoff. + this.metadata = ResourceMetadata.newResourceMetadataUnknown(); + + AbstractXdsClient xdsChannelTemp = null; + try { + maybeCreateXdsChannelWithLrs(serverInfo); + xdsChannelTemp = serverChannelMap.get(serverInfo); + if (xdsChannelTemp.isInBackoff()) { + return; + } + } catch (IllegalArgumentException e) { + xdsChannelTemp = null; + this.errorDescription = "Bad configuration: " + e.getMessage(); + return; + } finally { + this.xdsChannel = xdsChannelTemp; + } + + restartTimer(); + } + + @Nullable + private ServerInfo getServerInfo(String resource) { + if (BootstrapperImpl.enableFederation && resource.startsWith(XDSTP_SCHEME)) { + URI uri = URI.create(resource); + String authority = uri.getAuthority(); + if (authority == null) { + authority = ""; + } + AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(authority); + if (authorityInfo == null || authorityInfo.xdsServers().isEmpty()) { + return null; + } + return authorityInfo.xdsServers().get(0); + } + return bootstrapInfo.servers().get(0); // use first server + } + + void addWatcher(ResourceWatcher watcher) { + checkArgument(!watchers.contains(watcher), "watcher %s already registered", watcher); + watchers.add(watcher); + if (errorDescription != null) { + watcher.onError(Status.INVALID_ARGUMENT.withDescription(errorDescription)); + return; + } + if (data != null) { + notifyWatcher(watcher, data); + } else if (absent) { + watcher.onResourceDoesNotExist(resource); + } + } + + void removeWatcher(ResourceWatcher watcher) { + checkArgument(watchers.contains(watcher), "watcher %s not registered", watcher); + watchers.remove(watcher); + } + + void restartTimer() { + if (data != null || absent) { // resource already resolved + return; + } + if (!xdsChannel.isReady()) { // When channel becomes ready, it will trigger a restartTimer + return; + } + + class ResourceNotFound implements Runnable { + @Override + public void run() { + logger.log(XdsLogLevel.INFO, "{0} resource {1} initial fetch timeout", + type, resource); + respTimer = null; + onAbsent(); + } + + @Override + public String toString() { + return type + this.getClass().getSimpleName(); + } + } + + // Initial fetch scheduled or rescheduled, transition metadata state to REQUESTED. + metadata = ResourceMetadata.newResourceMetadataRequested(); + + respTimer = syncContext.schedule( + new ResourceNotFound(), INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS, + timeService); + } + + void stopTimer() { + if (respTimer != null && respTimer.isPending()) { + respTimer.cancel(); + respTimer = null; + } + } + + void cancelResourceWatch() { + if (isWatched()) { + throw new IllegalStateException("Can't cancel resource watch with active watchers present"); + } + stopTimer(); + String message = "Unsubscribing {0} resource {1} from server {2}"; + XdsLogLevel logLevel = XdsLogLevel.INFO; + if (resourceDeletionIgnored) { + message += " for which we previously ignored a deletion"; + logLevel = XdsLogLevel.FORCE_INFO; + } + logger.log(logLevel, message, type, resource, + serverInfo != null ? serverInfo.target() : "unknown"); + } + + boolean isWatched() { + return !watchers.isEmpty(); + } + + boolean hasResult() { + return data != null || absent; + } + + void onData(ParsedResource parsedResource, String version, long updateTime) { + if (respTimer != null && respTimer.isPending()) { + respTimer.cancel(); + respTimer = null; + } + this.metadata = ResourceMetadata + .newResourceMetadataAcked(parsedResource.getRawResource(), version, updateTime); + ResourceUpdate oldData = this.data; + this.data = parsedResource.getResourceUpdate(); + absent = false; + if (resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_INFO, "xds server {0}: server returned new version " + + "of resource for which we previously ignored a deletion: type {1} name {2}", + serverInfo != null ? serverInfo.target() : "unknown", type, resource); + resourceDeletionIgnored = false; + } + if (!Objects.equals(oldData, data)) { + for (ResourceWatcher watcher : watchers) { + notifyWatcher(watcher, data); + } + } + } + + void onAbsent() { + if (respTimer != null && respTimer.isPending()) { // too early to conclude absence + return; + } + + // Ignore deletion of State of the World resources when this feature is on, + // and the resource is reusable. + boolean ignoreResourceDeletionEnabled = + serverInfo != null && serverInfo.ignoreResourceDeletion(); + if (ignoreResourceDeletionEnabled && type.isFullStateOfTheWorld() && data != null) { + if (!resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_WARNING, + "xds server {0}: ignoring deletion for resource type {1} name {2}}", + serverInfo.target(), type, resource); + resourceDeletionIgnored = true; + } + return; + } + + logger.log(XdsLogLevel.INFO, "Conclude {0} resource {1} not exist", type, resource); + if (!absent) { + data = null; + absent = true; + metadata = ResourceMetadata.newResourceMetadataDoesNotExist(); + for (ResourceWatcher watcher : watchers) { + watcher.onResourceDoesNotExist(resource); + } + } + } + + void onError(Status error) { + if (respTimer != null && respTimer.isPending()) { + respTimer.cancel(); + respTimer = null; + } + + // Include node ID in xds failures to allow cross-referencing with control plane logs + // when debugging. + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorAugmented = Status.fromCode(error.getCode()) + .withDescription(description + "nodeID: " + bootstrapInfo.node().getId()) + .withCause(error.getCause()); + + for (ResourceWatcher watcher : watchers) { + watcher.onError(errorAugmented); + } + } + + void onRejected(String rejectedVersion, long rejectedTime, String rejectedDetails) { + metadata = ResourceMetadata + .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails); + } + + private void notifyWatcher(ResourceWatcher watcher, T update) { + watcher.onChanged(update); + } + } + + static final class ResourceInvalidException extends Exception { + private static final long serialVersionUID = 0L; + + ResourceInvalidException(String message) { + super(message, null, false, false); + } + + ResourceInvalidException(String message, Throwable cause) { + super(cause != null ? message + ": " + cause.getMessage() : message, cause, false, false); + } + } + + abstract static class XdsChannelFactory { + static final XdsChannelFactory DEFAULT_XDS_CHANNEL_FACTORY = new XdsChannelFactory() { + @Override + ManagedChannel create(ServerInfo serverInfo) { + String target = serverInfo.target(); + ChannelCredentials channelCredentials = serverInfo.channelCredentials(); + return Grpc.newChannelBuilder(target, channelCredentials) + .keepAliveTime(5, TimeUnit.MINUTES) + .build(); + } + }; + + abstract ManagedChannel create(ServerInfo serverInfo); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java new file mode 100644 index 00000000000..33f6176474b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -0,0 +1,671 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.Bootstrapper.ServerInfo; + +import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Duration; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers.Thresholds; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.core.v3.RoutingPriority; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import javax.annotation.Nullable; + +class XdsClusterResource extends XdsResourceType { + static final String ADS_TYPE_URL_CDS = + "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT = + "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext"; + private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT_V2 = + "type.googleapis.com/envoy.api.v2.auth.UpstreamTlsContext"; + + private static final XdsClusterResource instance = new XdsClusterResource(); + + public static XdsClusterResource getInstance() { + return instance; + } + + @Override + @Nullable + String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof Cluster)) { + return null; + } + return ((Cluster) unpackedResource).getName(); + } + + @Override + String typeName() { + return "CDS"; + } + + @Override + String typeUrl() { + return ADS_TYPE_URL_CDS; + } + + @Override + boolean isFullStateOfTheWorld() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + Class unpackedClassName() { + return Cluster.class; + } + + @Override + CdsUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + if (!(unpackedMessage instanceof Cluster)) { + throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); + } + Set certProviderInstances = null; + if (args.bootstrapInfo != null && args.bootstrapInfo.certProviders() != null) { + certProviderInstances = args.bootstrapInfo.certProviders().keySet(); + } + return processCluster((Cluster) unpackedMessage, certProviderInstances, + args.serverInfo, args.loadBalancerRegistry); + } + + @VisibleForTesting + static CdsUpdate processCluster(Cluster cluster, + Set certProviderInstances, + Bootstrapper.ServerInfo serverInfo, + LoadBalancerRegistry loadBalancerRegistry) + throws ResourceInvalidException { + StructOrError structOrError; + switch (cluster.getClusterDiscoveryTypeCase()) { + case TYPE: + structOrError = parseNonAggregateCluster(cluster, + certProviderInstances, serverInfo); + break; + case CLUSTER_TYPE: + structOrError = parseAggregateCluster(cluster); + break; + case CLUSTERDISCOVERYTYPE_NOT_SET: + default: + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": unspecified cluster discovery type"); + } + if (structOrError.getErrorDetail() != null) { + throw new ResourceInvalidException(structOrError.getErrorDetail()); + } + CdsUpdate.Builder updateBuilder = structOrError.getStruct(); + + ImmutableMap lbPolicyConfig = LoadBalancerConfigFactory.newConfig(cluster, + enableLeastRequest, enableCustomLbConfig); + + // Validate the LB config by trying to parse it with the corresponding LB provider. + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(lbPolicyConfig); + NameResolver.ConfigOrError configOrError = loadBalancerRegistry.getProvider( + lbConfig.getPolicyName()).parseLoadBalancingPolicyConfig( + lbConfig.getRawConfigValue()); + if (configOrError.getError() != null) { + throw new ResourceInvalidException(structOrError.getErrorDetail()); + } + + updateBuilder.lbPolicyConfig(lbPolicyConfig); + + return updateBuilder.build(); + } + + private static StructOrError parseAggregateCluster(Cluster cluster) { + String clusterName = cluster.getName(); + Cluster.CustomClusterType customType = cluster.getClusterType(); + String typeName = customType.getName(); + if (!typeName.equals(AGGREGATE_CLUSTER_TYPE_NAME)) { + return StructOrError.fromError( + "Cluster " + clusterName + ": unsupported custom cluster type: " + typeName); + } + io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig clusterConfig; + try { + clusterConfig = unpackCompatibleType(customType.getTypedConfig(), + io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig.class, + TYPE_URL_CLUSTER_CONFIG, null); + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError("Cluster " + clusterName + ": malformed ClusterConfig: " + e); + } + return StructOrError.fromStruct(CdsUpdate.forAggregate( + clusterName, clusterConfig.getClustersList())); + } + + private static StructOrError parseNonAggregateCluster( + Cluster cluster, Set certProviderInstances, Bootstrapper.ServerInfo serverInfo) { + String clusterName = cluster.getName(); + Bootstrapper.ServerInfo lrsServerInfo = null; + Long maxConcurrentRequests = null; + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = null; + OutlierDetection outlierDetection = null; + if (cluster.hasLrsServer()) { + if (!cluster.getLrsServer().hasSelf()) { + return StructOrError.fromError( + "Cluster " + clusterName + ": only support LRS for the same management server"); + } + lrsServerInfo = serverInfo; + } + if (cluster.hasCircuitBreakers()) { + List thresholds = cluster.getCircuitBreakers().getThresholdsList(); + for (Thresholds threshold : thresholds) { + if (threshold.getPriority() != RoutingPriority.DEFAULT) { + continue; + } + if (threshold.hasMaxRequests()) { + maxConcurrentRequests = (long) threshold.getMaxRequests().getValue(); + } + } + } + if (cluster.getTransportSocketMatchesCount() > 0) { + return StructOrError.fromError("Cluster " + clusterName + + ": transport-socket-matches not supported."); + } + if (cluster.hasTransportSocket()) { + if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) { + return StructOrError.fromError("transport-socket with name " + + cluster.getTransportSocket().getName() + " not supported."); + } + try { + upstreamTlsContext = UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( + validateUpstreamTlsContext( + unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(), + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class, + TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), + certProviderInstances)); + } catch (InvalidProtocolBufferException | ResourceInvalidException e) { + return StructOrError.fromError( + "Cluster " + clusterName + ": malformed UpstreamTlsContext: " + e); + } + } + + if (cluster.hasOutlierDetection() && enableOutlierDetection) { + try { + outlierDetection = OutlierDetection.fromEnvoyOutlierDetection( + validateOutlierDetection(cluster.getOutlierDetection())); + } catch (ResourceInvalidException e) { + return StructOrError.fromError( + "Cluster " + clusterName + ": malformed outlier_detection: " + e); + } + } + + Cluster.DiscoveryType type = cluster.getType(); + if (type == Cluster.DiscoveryType.EDS) { + String edsServiceName = null; + io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig edsClusterConfig = + cluster.getEdsClusterConfig(); + if (!edsClusterConfig.getEdsConfig().hasAds() + && ! edsClusterConfig.getEdsConfig().hasSelf()) { + return StructOrError.fromError( + "Cluster " + clusterName + ": field eds_cluster_config must be set to indicate to use" + + " EDS over ADS or self ConfigSource"); + } + // If the service_name field is set, that value will be used for the EDS request. + if (!edsClusterConfig.getServiceName().isEmpty()) { + edsServiceName = edsClusterConfig.getServiceName(); + } + return StructOrError.fromStruct(CdsUpdate.forEds( + clusterName, edsServiceName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext, + outlierDetection)); + } else if (type.equals(Cluster.DiscoveryType.LOGICAL_DNS)) { + if (!cluster.hasLoadAssignment()) { + return StructOrError.fromError( + "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single host"); + } + ClusterLoadAssignment assignment = cluster.getLoadAssignment(); + if (assignment.getEndpointsCount() != 1 + || assignment.getEndpoints(0).getLbEndpointsCount() != 1) { + return StructOrError.fromError( + "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single " + + "locality_lb_endpoint and a single lb_endpoint"); + } + io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint lbEndpoint = + assignment.getEndpoints(0).getLbEndpoints(0); + if (!lbEndpoint.hasEndpoint() || !lbEndpoint.getEndpoint().hasAddress() + || !lbEndpoint.getEndpoint().getAddress().hasSocketAddress()) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL_DNS clusters must have an endpoint with address and socket_address"); + } + SocketAddress socketAddress = lbEndpoint.getEndpoint().getAddress().getSocketAddress(); + if (!socketAddress.getResolverName().isEmpty()) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL DNS clusters must NOT have a custom resolver name set"); + } + if (socketAddress.getPortSpecifierCase() != SocketAddress.PortSpecifierCase.PORT_VALUE) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL DNS clusters socket_address must have port_value"); + } + String dnsHostName = String.format( + Locale.US, "%s:%d", socketAddress.getAddress(), socketAddress.getPortValue()); + return StructOrError.fromStruct(CdsUpdate.forLogicalDns( + clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext)); + } + return StructOrError.fromError( + "Cluster " + clusterName + ": unsupported built-in discovery type: " + type); + } + + static io.envoyproxy.envoy.config.cluster.v3.OutlierDetection validateOutlierDetection( + io.envoyproxy.envoy.config.cluster.v3.OutlierDetection outlierDetection) + throws ResourceInvalidException { + if (outlierDetection.hasInterval()) { + if (!Durations.isValid(outlierDetection.getInterval())) { + throw new ResourceInvalidException("outlier_detection interval is not a valid Duration"); + } + if (hasNegativeValues(outlierDetection.getInterval())) { + throw new ResourceInvalidException("outlier_detection interval has a negative value"); + } + } + if (outlierDetection.hasBaseEjectionTime()) { + if (!Durations.isValid(outlierDetection.getBaseEjectionTime())) { + throw new ResourceInvalidException( + "outlier_detection base_ejection_time is not a valid Duration"); + } + if (hasNegativeValues(outlierDetection.getBaseEjectionTime())) { + throw new ResourceInvalidException( + "outlier_detection base_ejection_time has a negative value"); + } + } + if (outlierDetection.hasMaxEjectionTime()) { + if (!Durations.isValid(outlierDetection.getMaxEjectionTime())) { + throw new ResourceInvalidException( + "outlier_detection max_ejection_time is not a valid Duration"); + } + if (hasNegativeValues(outlierDetection.getMaxEjectionTime())) { + throw new ResourceInvalidException( + "outlier_detection max_ejection_time has a negative value"); + } + } + if (outlierDetection.hasMaxEjectionPercent() + && outlierDetection.getMaxEjectionPercent().getValue() > 100) { + throw new ResourceInvalidException( + "outlier_detection max_ejection_percent is > 100"); + } + if (outlierDetection.hasEnforcingSuccessRate() + && outlierDetection.getEnforcingSuccessRate().getValue() > 100) { + throw new ResourceInvalidException( + "outlier_detection enforcing_success_rate is > 100"); + } + if (outlierDetection.hasFailurePercentageThreshold() + && outlierDetection.getFailurePercentageThreshold().getValue() > 100) { + throw new ResourceInvalidException( + "outlier_detection failure_percentage_threshold is > 100"); + } + if (outlierDetection.hasEnforcingFailurePercentage() + && outlierDetection.getEnforcingFailurePercentage().getValue() > 100) { + throw new ResourceInvalidException( + "outlier_detection enforcing_failure_percentage is > 100"); + } + + return outlierDetection; + } + + static boolean hasNegativeValues(Duration duration) { + return duration.getSeconds() < 0 || duration.getNanos() < 0; + } + + @VisibleForTesting + static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + validateUpstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext, + Set certProviderInstances) + throws ResourceInvalidException { + if (upstreamTlsContext.hasCommonTlsContext()) { + validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), certProviderInstances, + false); + } else { + throw new ResourceInvalidException("common-tls-context is required in upstream-tls-context"); + } + return upstreamTlsContext; + } + + @VisibleForTesting + static void validateCommonTlsContext( + CommonTlsContext commonTlsContext, Set certProviderInstances, boolean server) + throws ResourceInvalidException { + if (commonTlsContext.hasCustomHandshaker()) { + throw new ResourceInvalidException( + "common-tls-context with custom_handshaker is not supported"); + } + if (commonTlsContext.hasTlsParams()) { + throw new ResourceInvalidException("common-tls-context with tls_params is not supported"); + } + if (commonTlsContext.hasValidationContextSdsSecretConfig()) { + throw new ResourceInvalidException( + "common-tls-context with validation_context_sds_secret_config is not supported"); + } + if (commonTlsContext.hasValidationContextCertificateProvider()) { + throw new ResourceInvalidException( + "common-tls-context with validation_context_certificate_provider is not supported"); + } + if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { + throw new ResourceInvalidException( + "common-tls-context with validation_context_certificate_provider_instance is not" + + " supported"); + } + String certInstanceName = getIdentityCertInstanceName(commonTlsContext); + if (certInstanceName == null) { + if (server) { + throw new ResourceInvalidException( + "tls_certificate_provider_instance is required in downstream-tls-context"); + } + if (commonTlsContext.getTlsCertificatesCount() > 0) { + throw new ResourceInvalidException( + "tls_certificate_provider_instance is unset"); + } + if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { + throw new ResourceInvalidException( + "tls_certificate_provider_instance is unset"); + } + if (commonTlsContext.hasTlsCertificateCertificateProvider()) { + throw new ResourceInvalidException( + "tls_certificate_provider_instance is unset"); + } + } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { + throw new ResourceInvalidException( + "CertificateProvider instance name '" + certInstanceName + + "' not defined in the bootstrap file."); + } + String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); + if (rootCaInstanceName == null) { + if (!server) { + throw new ResourceInvalidException( + "ca_certificate_provider_instance is required in upstream-tls-context"); + } + } else { + if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { + throw new ResourceInvalidException( + "ca_certificate_provider_instance name '" + rootCaInstanceName + + "' not defined in the bootstrap file."); + } + CertificateValidationContext certificateValidationContext = null; + if (commonTlsContext.hasValidationContext()) { + certificateValidationContext = commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext + .getCombinedValidationContext().hasDefaultValidationContext()) { + certificateValidationContext = commonTlsContext.getCombinedValidationContext() + .getDefaultValidationContext(); + } + if (certificateValidationContext != null) { + if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { + throw new ResourceInvalidException( + "match_subject_alt_names only allowed in upstream_tls_context"); + } + if (certificateValidationContext.getVerifyCertificateSpkiCount() > 0) { + throw new ResourceInvalidException( + "verify_certificate_spki in default_validation_context is not supported"); + } + if (certificateValidationContext.getVerifyCertificateHashCount() > 0) { + throw new ResourceInvalidException( + "verify_certificate_hash in default_validation_context is not supported"); + } + if (certificateValidationContext.hasRequireSignedCertificateTimestamp()) { + throw new ResourceInvalidException( + "require_signed_certificate_timestamp in default_validation_context is not " + + "supported"); + } + if (certificateValidationContext.hasCrl()) { + throw new ResourceInvalidException("crl in default_validation_context is not supported"); + } + if (certificateValidationContext.hasCustomValidatorConfig()) { + throw new ResourceInvalidException( + "custom_validator_config in default_validation_context is not supported"); + } + } + } + } + + private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); + } + return null; + } + + private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) { + return commonTlsContext.getValidationContext().getCaCertificateProviderInstance() + .getInstanceName(); + } + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext + = commonTlsContext.getCombinedValidationContext(); + if (combinedCertificateValidationContext.hasDefaultValidationContext() + && combinedCertificateValidationContext.getDefaultValidationContext() + .hasCaCertificateProviderInstance()) { + return combinedCertificateValidationContext.getDefaultValidationContext() + .getCaCertificateProviderInstance().getInstanceName(); + } else if (combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance()) { + return combinedCertificateValidationContext + .getValidationContextCertificateProviderInstance().getInstanceName(); + } + } + return null; + } + + /** xDS resource update for cluster-level configuration. */ + @AutoValue + abstract static class CdsUpdate implements ResourceUpdate { + abstract String clusterName(); + + abstract ClusterType clusterType(); + + abstract ImmutableMap lbPolicyConfig(); + + // Only valid if lbPolicy is "ring_hash_experimental". + abstract long minRingSize(); + + // Only valid if lbPolicy is "ring_hash_experimental". + abstract long maxRingSize(); + + // Only valid if lbPolicy is "least_request_experimental". + abstract int choiceCount(); + + // Alternative resource name to be used in EDS requests. + /// Only valid for EDS cluster. + @Nullable + abstract String edsServiceName(); + + // Corresponding DNS name to be used if upstream endpoints of the cluster is resolvable + // via DNS. + // Only valid for LOGICAL_DNS cluster. + @Nullable + abstract String dnsHostName(); + + // Load report server info for reporting loads via LRS. + // Only valid for EDS or LOGICAL_DNS cluster. + @Nullable + abstract ServerInfo lrsServerInfo(); + + // Max number of concurrent requests can be sent to this cluster. + // Only valid for EDS or LOGICAL_DNS cluster. + @Nullable + abstract Long maxConcurrentRequests(); + + // TLS context used to connect to connect to this cluster. + // Only valid for EDS or LOGICAL_DNS cluster. + @Nullable + abstract UpstreamTlsContext upstreamTlsContext(); + + // List of underlying clusters making of this aggregate cluster. + // Only valid for AGGREGATE cluster. + @Nullable + abstract ImmutableList prioritizedClusterNames(); + + // Outlier detection configuration. + @Nullable + abstract OutlierDetection outlierDetection(); + + static Builder forAggregate(String clusterName, List prioritizedClusterNames) { + checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); + return new AutoValue_XdsClusterResource_CdsUpdate.Builder() + .clusterName(clusterName) + .clusterType(ClusterType.AGGREGATE) + .minRingSize(0) + .maxRingSize(0) + .choiceCount(0) + .prioritizedClusterNames(ImmutableList.copyOf(prioritizedClusterNames)); + } + + static Builder forEds(String clusterName, @Nullable String edsServiceName, + @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext upstreamTlsContext, + @Nullable OutlierDetection outlierDetection) { + return new AutoValue_XdsClusterResource_CdsUpdate.Builder() + .clusterName(clusterName) + .clusterType(ClusterType.EDS) + .minRingSize(0) + .maxRingSize(0) + .choiceCount(0) + .edsServiceName(edsServiceName) + .lrsServerInfo(lrsServerInfo) + .maxConcurrentRequests(maxConcurrentRequests) + .upstreamTlsContext(upstreamTlsContext) + .outlierDetection(outlierDetection); + } + + static Builder forLogicalDns(String clusterName, String dnsHostName, + @Nullable ServerInfo lrsServerInfo, + @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext upstreamTlsContext) { + return new AutoValue_XdsClusterResource_CdsUpdate.Builder() + .clusterName(clusterName) + .clusterType(ClusterType.LOGICAL_DNS) + .minRingSize(0) + .maxRingSize(0) + .choiceCount(0) + .dnsHostName(dnsHostName) + .lrsServerInfo(lrsServerInfo) + .maxConcurrentRequests(maxConcurrentRequests) + .upstreamTlsContext(upstreamTlsContext); + } + + enum ClusterType { + EDS, LOGICAL_DNS, AGGREGATE + } + + enum LbPolicy { + ROUND_ROBIN, RING_HASH, LEAST_REQUEST + } + + // FIXME(chengyuanzhang): delete this after UpstreamTlsContext's toString() is fixed. + @Override + public final String toString() { + return MoreObjects.toStringHelper(this) + .add("clusterName", clusterName()) + .add("clusterType", clusterType()) + .add("lbPolicyConfig", lbPolicyConfig()) + .add("minRingSize", minRingSize()) + .add("maxRingSize", maxRingSize()) + .add("choiceCount", choiceCount()) + .add("edsServiceName", edsServiceName()) + .add("dnsHostName", dnsHostName()) + .add("lrsServerInfo", lrsServerInfo()) + .add("maxConcurrentRequests", maxConcurrentRequests()) + // Exclude upstreamTlsContext and outlierDetection as their string representations are + // cumbersome. + .add("prioritizedClusterNames", prioritizedClusterNames()) + .toString(); + } + + @AutoValue.Builder + abstract static class Builder { + // Private, use one of the static factory methods instead. + protected abstract Builder clusterName(String clusterName); + + // Private, use one of the static factory methods instead. + protected abstract Builder clusterType(ClusterType clusterType); + + protected abstract Builder lbPolicyConfig(ImmutableMap lbPolicyConfig); + + Builder roundRobinLbPolicy() { + return this.lbPolicyConfig(ImmutableMap.of("round_robin", ImmutableMap.of())); + } + + Builder ringHashLbPolicy(Long minRingSize, Long maxRingSize) { + return this.lbPolicyConfig(ImmutableMap.of("ring_hash_experimental", + ImmutableMap.of("minRingSize", minRingSize.doubleValue(), "maxRingSize", + maxRingSize.doubleValue()))); + } + + Builder leastRequestLbPolicy(Integer choiceCount) { + return this.lbPolicyConfig(ImmutableMap.of("least_request_experimental", + ImmutableMap.of("choiceCount", choiceCount.doubleValue()))); + } + + // Private, use leastRequestLbPolicy(int). + protected abstract Builder choiceCount(int choiceCount); + + // Private, use ringHashLbPolicy(long, long). + protected abstract Builder minRingSize(long minRingSize); + + // Private, use ringHashLbPolicy(long, long). + protected abstract Builder maxRingSize(long maxRingSize); + + // Private, use CdsUpdate.forEds() instead. + protected abstract Builder edsServiceName(String edsServiceName); + + // Private, use CdsUpdate.forLogicalDns() instead. + protected abstract Builder dnsHostName(String dnsHostName); + + // Private, use one of the static factory methods instead. + protected abstract Builder lrsServerInfo(ServerInfo lrsServerInfo); + + // Private, use one of the static factory methods instead. + protected abstract Builder maxConcurrentRequests(Long maxConcurrentRequests); + + // Private, use one of the static factory methods instead. + protected abstract Builder upstreamTlsContext(UpstreamTlsContext upstreamTlsContext); + + // Private, use CdsUpdate.forAggregate() instead. + protected abstract Builder prioritizedClusterNames(List prioritizedClusterNames); + + protected abstract Builder outlierDetection(OutlierDetection outlierDetection); + + abstract CdsUpdate build(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsCredentialsProvider.java b/xds/src/main/java/io/grpc/xds/XdsCredentialsProvider.java new file mode 100644 index 00000000000..e9466f37a0a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsCredentialsProvider.java @@ -0,0 +1,74 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.ChannelCredentials; +import io.grpc.Internal; +import java.util.Map; + +/** + * Provider of credentials which can be consumed by clients for xds communications. The actual + * credential to be used for a particular xds communication will be chosen based on the bootstrap + * configuration. + * + *

    Implementations can be automatically discovered by gRPC via Java's SPI mechanism. For + * automatic discovery, the implementation must have a zero-argument constructor and include + * a resource named {@code META-INF/services/io.grpc.xds.XdsCredentialsProvider} in their JAR. The + * file's contents should be the implementation's class name. + * Implementations that need arguments in their constructor can be manually registered by + * {@link XdsCredentialsRegistry#register}. + * + *

    Implementations should not throw. If they do, it may interrupt class loading. If + * exceptions may reasonably occur for implementation-specific reasons, implementations should + * generally handle the exception gracefully and return {@code false} from {@link #isAvailable()}. + */ +@Internal +public abstract class XdsCredentialsProvider { + /** + * Creates a {@link ChannelCredentials} from the given jsonConfig, or + * {@code null} if the given config is invalid. The provider is free to ignore + * the config if it's not needed for producing the channel credentials. + * + * @param jsonConfig json config that can be consumed by the provider to create + * the channel credentials + * + */ + protected abstract ChannelCredentials newChannelCredentials(Map jsonConfig); + + /** + * Returns the xDS credential name associated with this provider which makes it selectable + * via {@link XdsCredentialsRegistry#getProvider}. This is called only when the class is loaded. + * It shouldn't change, and there is no point doing so. + */ + protected abstract String getName(); + + /** + * Whether this provider is available for use, taking the current environment + * into consideration. + * If {@code false}, {@link #newChannelCredentials} is not safe to be called. + */ + public abstract boolean isAvailable(); + + /** + * A priority, from 0 to 10 that this provider should be used, taking the + * current environment into consideration. + * 5 should be considered the default, and then tweaked based on + * environment detection. A priority of 0 does not imply that the provider + * wouldn't work; just that it should be last in line. + */ + public abstract int priority(); +} diff --git a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java new file mode 100644 index 00000000000..c33b3cd2f85 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java @@ -0,0 +1,189 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import io.grpc.InternalServiceProviders; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Registry of {@link XdsCredentialsProvider}s. The {@link #getDefaultRegistry default + * instance} loads providers at runtime through the Java service provider mechanism. + */ +@ThreadSafe +final class XdsCredentialsRegistry { + private static final Logger logger = Logger.getLogger(XdsCredentialsRegistry.class.getName()); + private static XdsCredentialsRegistry instance; + + @GuardedBy("this") + private final LinkedHashSet allProviders = new LinkedHashSet<>(); + + /** + * Generated from {@code allProviders}. Is mapping from scheme key to the + * highest priority {@link XdsCredentialsProvider}. + * Is replaced instead of mutating. + */ + @GuardedBy("this") + private ImmutableMap effectiveProviders = ImmutableMap.of(); + + /** + * Register a provider. + * + *

    If the provider's {@link XdsCredentialsProvider#isAvailable isAvailable()} + * returns {@code false}, this method will throw {@link IllegalArgumentException}. + * + *

    Providers will be used in priority order. In case of ties, providers are used + * in registration order. + */ + public synchronized void register(XdsCredentialsProvider provider) { + addProvider(provider); + refreshProviders(); + } + + private synchronized void addProvider(XdsCredentialsProvider provider) { + checkArgument(provider.isAvailable(), "isAvailable() returned false"); + allProviders.add(provider); + } + + /** + * Deregisters a provider. No-op if the provider is not in the registry. + * + * @param provider the provider that was added to the register via + * {@link #register}. + */ + public synchronized void deregister(XdsCredentialsProvider provider) { + allProviders.remove(provider); + refreshProviders(); + } + + private synchronized void refreshProviders() { + Map refreshedProviders = new HashMap<>(); + int maxPriority = Integer.MIN_VALUE; + // We prefer first-registered providers. + for (XdsCredentialsProvider provider : allProviders) { + String credsName = provider.getName(); + XdsCredentialsProvider existing = refreshedProviders.get(credsName); + if (existing == null || existing.priority() < provider.priority()) { + refreshedProviders.put(credsName, provider); + } + if (maxPriority < provider.priority()) { + maxPriority = provider.priority(); + } + } + effectiveProviders = ImmutableMap.copyOf(refreshedProviders); + } + + /** + * Returns the default registry that loads providers via the Java service loader + * mechanism. + */ + public static synchronized XdsCredentialsRegistry getDefaultRegistry() { + if (instance == null) { + List providerList = InternalServiceProviders.loadAll( + XdsCredentialsProvider.class, + getHardCodedClasses(), + XdsCredentialsProvider.class.getClassLoader(), + new XdsCredentialsProviderPriorityAccessor()); + if (providerList.isEmpty()) { + logger.warning("No XdsCredsRegistry found via ServiceLoader, including for GoogleDefault, " + + "TLS and Insecure. This is probably due to a broken build."); + } + instance = new XdsCredentialsRegistry(); + for (XdsCredentialsProvider provider : providerList) { + logger.fine("Service loader found " + provider); + if (provider.isAvailable()) { + instance.addProvider(provider); + } + } + instance.refreshProviders(); + } + return instance; + } + + /** + * Returns effective providers map from scheme to the highest priority + * XdsCredsProvider of that scheme. + */ + @VisibleForTesting + synchronized Map providers() { + return effectiveProviders; + } + + /** + * Returns the effective provider for the given xds credential name, or {@code null} if no + * suitable provider can be found. + * Each provider declares its name via {@link XdsCredentialsProvider#getName}. + */ + @Nullable + public synchronized XdsCredentialsProvider getProvider(String name) { + return effectiveProviders.get(checkNotNull(name, "name")); + } + + @VisibleForTesting + static List> getHardCodedClasses() { + // Class.forName(String) is used to remove the need for ProGuard configuration. Note that + // ProGuard does not detect usages of Class.forName(String, boolean, ClassLoader): + // https://sourceforge.net/p/proguard/bugs/418/ + ArrayList> list = new ArrayList<>(); + try { + list.add(Class.forName("io.grpc.xds.internal.GoogleDefaultXdsCredentialsProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.WARNING, "Unable to find GoogleDefaultXdsCredentialsProvider", e); + } + + try { + list.add(Class.forName("io.grpc.xds.internal.InsecureXdsCredentialsProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.WARNING, "Unable to find InsecureXdsCredentialsProvider", e); + } + + try { + list.add(Class.forName("io.grpc.xds.internal.TlsXdsCredentialsProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.WARNING, "Unable to find TlsXdsCredentialsProvider", e); + } + + return Collections.unmodifiableList(list); + } + + private static final class XdsCredentialsProviderPriorityAccessor + implements InternalServiceProviders.PriorityAccessor { + @Override + public boolean isAvailable(XdsCredentialsProvider provider) { + return provider.isAvailable(); + } + + @Override + public int getPriority(XdsCredentialsProvider provider) { + return provider.priority(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java new file mode 100644 index 00000000000..39caa9a8597 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java @@ -0,0 +1,248 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.grpc.EquivalentAddressGroup; +import io.grpc.xds.Endpoints.DropOverload; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import javax.annotation.Nullable; + +class XdsEndpointResource extends XdsResourceType { + static final String ADS_TYPE_URL_EDS = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + + private static final XdsEndpointResource instance = new XdsEndpointResource(); + + public static XdsEndpointResource getInstance() { + return instance; + } + + @Override + @Nullable + String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof ClusterLoadAssignment)) { + return null; + } + return ((ClusterLoadAssignment) unpackedResource).getClusterName(); + } + + @Override + String typeName() { + return "EDS"; + } + + @Override + String typeUrl() { + return ADS_TYPE_URL_EDS; + } + + @Override + boolean isFullStateOfTheWorld() { + return false; + } + + @Override + Class unpackedClassName() { + return ClusterLoadAssignment.class; + } + + @Override + EdsUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + if (!(unpackedMessage instanceof ClusterLoadAssignment)) { + throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); + } + return processClusterLoadAssignment((ClusterLoadAssignment) unpackedMessage); + } + + private static EdsUpdate processClusterLoadAssignment(ClusterLoadAssignment assignment) + throws ResourceInvalidException { + Map> priorities = new HashMap<>(); + Map localityLbEndpointsMap = new LinkedHashMap<>(); + List dropOverloads = new ArrayList<>(); + int maxPriority = -1; + for (io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints localityLbEndpointsProto + : assignment.getEndpointsList()) { + StructOrError structOrError = + parseLocalityLbEndpoints(localityLbEndpointsProto); + if (structOrError == null) { + continue; + } + if (structOrError.getErrorDetail() != null) { + throw new ResourceInvalidException(structOrError.getErrorDetail()); + } + + LocalityLbEndpoints localityLbEndpoints = structOrError.getStruct(); + int priority = localityLbEndpoints.priority(); + maxPriority = Math.max(maxPriority, priority); + // Note endpoints with health status other than HEALTHY and UNKNOWN are still + // handed over to watching parties. It is watching parties' responsibility to + // filter out unhealthy endpoints. See EnvoyProtoData.LbEndpoint#isHealthy(). + Locality locality = parseLocality(localityLbEndpointsProto.getLocality()); + localityLbEndpointsMap.put(locality, localityLbEndpoints); + if (!priorities.containsKey(priority)) { + priorities.put(priority, new HashSet<>()); + } + if (!priorities.get(priority).add(locality)) { + throw new ResourceInvalidException("ClusterLoadAssignment has duplicate locality:" + + locality + " for priority:" + priority); + } + } + if (priorities.size() != maxPriority + 1) { + throw new ResourceInvalidException("ClusterLoadAssignment has sparse priorities"); + } + + for (ClusterLoadAssignment.Policy.DropOverload dropOverloadProto + : assignment.getPolicy().getDropOverloadsList()) { + dropOverloads.add(parseDropOverload(dropOverloadProto)); + } + return new EdsUpdate(assignment.getClusterName(), localityLbEndpointsMap, dropOverloads); + } + + private static Locality parseLocality(io.envoyproxy.envoy.config.core.v3.Locality proto) { + return Locality.create(proto.getRegion(), proto.getZone(), proto.getSubZone()); + } + + private static DropOverload parseDropOverload( + io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment.Policy.DropOverload proto) { + return DropOverload.create(proto.getCategory(), getRatePerMillion(proto.getDropPercentage())); + } + + private static int getRatePerMillion(FractionalPercent percent) { + int numerator = percent.getNumerator(); + FractionalPercent.DenominatorType type = percent.getDenominator(); + switch (type) { + case TEN_THOUSAND: + numerator *= 100; + break; + case HUNDRED: + numerator *= 10_000; + break; + case MILLION: + break; + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type of " + percent); + } + + if (numerator > 1_000_000 || numerator < 0) { + numerator = 1_000_000; + } + return numerator; + } + + + @VisibleForTesting + @Nullable + static StructOrError parseLocalityLbEndpoints( + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) { + // Filter out localities without or with 0 weight. + if (!proto.hasLoadBalancingWeight() || proto.getLoadBalancingWeight().getValue() < 1) { + return null; + } + if (proto.getPriority() < 0) { + return StructOrError.fromError("negative priority"); + } + List endpoints = new ArrayList<>(proto.getLbEndpointsCount()); + for (io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint endpoint : proto.getLbEndpointsList()) { + // The endpoint field of each lb_endpoints must be set. + // Inside of it: the address field must be set. + if (!endpoint.hasEndpoint() || !endpoint.getEndpoint().hasAddress()) { + return StructOrError.fromError("LbEndpoint with no endpoint/address"); + } + io.envoyproxy.envoy.config.core.v3.SocketAddress socketAddress = + endpoint.getEndpoint().getAddress().getSocketAddress(); + InetSocketAddress addr = + new InetSocketAddress(socketAddress.getAddress(), socketAddress.getPortValue()); + boolean isHealthy = + endpoint.getHealthStatus() == io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY + || endpoint.getHealthStatus() + == io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN; + endpoints.add(Endpoints.LbEndpoint.create( + new EquivalentAddressGroup(ImmutableList.of(addr)), + endpoint.getLoadBalancingWeight().getValue(), isHealthy)); + } + return StructOrError.fromStruct(Endpoints.LocalityLbEndpoints.create( + endpoints, proto.getLoadBalancingWeight().getValue(), proto.getPriority())); + } + + static final class EdsUpdate implements ResourceUpdate { + final String clusterName; + final Map localityLbEndpointsMap; + final List dropPolicies; + + EdsUpdate(String clusterName, Map localityLbEndpoints, + List dropPolicies) { + this.clusterName = checkNotNull(clusterName, "clusterName"); + this.localityLbEndpointsMap = Collections.unmodifiableMap( + new LinkedHashMap<>(checkNotNull(localityLbEndpoints, "localityLbEndpoints"))); + this.dropPolicies = Collections.unmodifiableList( + new ArrayList<>(checkNotNull(dropPolicies, "dropPolicies"))); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EdsUpdate that = (EdsUpdate) o; + return Objects.equals(clusterName, that.clusterName) + && Objects.equals(localityLbEndpointsMap, that.localityLbEndpointsMap) + && Objects.equals(dropPolicies, that.dropPolicies); + } + + @Override + public int hashCode() { + return Objects.hash(clusterName, localityLbEndpointsMap, dropPolicies); + } + + @Override + public String toString() { + return + MoreObjects + .toStringHelper(this) + .add("clusterName", clusterName) + .add("localityLbEndpointsMap", localityLbEndpointsMap) + .add("dropPolicies", dropPolicies) + .toString(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java index b11c7853473..dcca2fbfff3 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java +++ b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java @@ -23,6 +23,7 @@ final class XdsLbPolicies { static final String PRIORITY_POLICY_NAME = "priority_experimental"; static final String CLUSTER_IMPL_POLICY_NAME = "cluster_impl_experimental"; static final String WEIGHTED_TARGET_POLICY_NAME = "weighted_target_experimental"; + static final String WRR_LOCALITY_POLICY_NAME = "wrr_locality_experimental"; private XdsLbPolicies() {} } diff --git a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java new file mode 100644 index 00000000000..789f78ba5b7 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java @@ -0,0 +1,614 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.XdsClient.ResourceUpdate; +import static io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import static io.grpc.xds.XdsClusterResource.validateCommonTlsContext; +import static io.grpc.xds.XdsRouteConfigureResource.extractVirtualHosts; + +import com.github.udpa.udpa.type.v1.TypedStruct; +import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; +import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +class XdsListenerResource extends XdsResourceType { + static final String ADS_TYPE_URL_LDS = + "type.googleapis.com/envoy.config.listener.v3.Listener"; + static final String TYPE_URL_HTTP_CONNECTION_MANAGER = + "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" + + ".HttpConnectionManager"; + private static final String TRANSPORT_SOCKET_NAME_TLS = "envoy.transport_sockets.tls"; + private static final XdsListenerResource instance = new XdsListenerResource(); + + public static XdsListenerResource getInstance() { + return instance; + } + + @Override + @Nullable + String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof Listener)) { + return null; + } + return ((Listener) unpackedResource).getName(); + } + + @Override + String typeName() { + return "LDS"; + } + + @Override + Class unpackedClassName() { + return Listener.class; + } + + @Override + String typeUrl() { + return ADS_TYPE_URL_LDS; + } + + @Override + boolean isFullStateOfTheWorld() { + return true; + } + + @Override + LdsUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + if (!(unpackedMessage instanceof Listener)) { + throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); + } + Listener listener = (Listener) unpackedMessage; + + if (listener.hasApiListener()) { + return processClientSideListener( + listener, args, enableFaultInjection); + } else { + return processServerSideListener( + listener, args, enableRbac); + } + } + + private LdsUpdate processClientSideListener(Listener listener, Args args, boolean parseHttpFilter) + throws ResourceInvalidException { + // Unpack HttpConnectionManager from the Listener. + HttpConnectionManager hcm; + try { + hcm = unpackCompatibleType( + listener.getApiListener().getApiListener(), HttpConnectionManager.class, + TYPE_URL_HTTP_CONNECTION_MANAGER, null); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException( + "Could not parse HttpConnectionManager config from ApiListener", e); + } + return LdsUpdate.forApiListener(parseHttpConnectionManager( + hcm, args.filterRegistry, parseHttpFilter, true /* isForClient */)); + } + + private LdsUpdate processServerSideListener(Listener proto, Args args, boolean parseHttpFilter) + throws ResourceInvalidException { + Set certProviderInstances = null; + if (args.bootstrapInfo != null && args.bootstrapInfo.certProviders() != null) { + certProviderInstances = args.bootstrapInfo.certProviders().keySet(); + } + return LdsUpdate.forTcpListener(parseServerSideListener(proto, args.tlsContextManager, + args.filterRegistry, certProviderInstances, parseHttpFilter)); + } + + @VisibleForTesting + static EnvoyServerProtoData.Listener parseServerSideListener( + Listener proto, TlsContextManager tlsContextManager, + FilterRegistry filterRegistry, Set certProviderInstances, boolean parseHttpFilter) + throws ResourceInvalidException { + if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND) + && !proto.getTrafficDirection().equals(TrafficDirection.UNSPECIFIED)) { + throw new ResourceInvalidException( + "Listener " + proto.getName() + " with invalid traffic direction: " + + proto.getTrafficDirection()); + } + if (!proto.getListenerFiltersList().isEmpty()) { + throw new ResourceInvalidException( + "Listener " + proto.getName() + " cannot have listener_filters"); + } + if (proto.hasUseOriginalDst()) { + throw new ResourceInvalidException( + "Listener " + proto.getName() + " cannot have use_original_dst set to true"); + } + + String address = null; + if (proto.getAddress().hasSocketAddress()) { + SocketAddress socketAddress = proto.getAddress().getSocketAddress(); + address = socketAddress.getAddress(); + switch (socketAddress.getPortSpecifierCase()) { + case NAMED_PORT: + address = address + ":" + socketAddress.getNamedPort(); + break; + case PORT_VALUE: + address = address + ":" + socketAddress.getPortValue(); + break; + default: + // noop + } + } + + ImmutableList.Builder filterChains = ImmutableList.builder(); + Set uniqueSet = new HashSet<>(); + for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) { + filterChains.add( + parseFilterChain(fc, tlsContextManager, filterRegistry, uniqueSet, + certProviderInstances, parseHttpFilter)); + } + FilterChain defaultFilterChain = null; + if (proto.hasDefaultFilterChain()) { + defaultFilterChain = parseFilterChain( + proto.getDefaultFilterChain(), tlsContextManager, filterRegistry, + null, certProviderInstances, parseHttpFilter); + } + + return EnvoyServerProtoData.Listener.create( + proto.getName(), address, filterChains.build(), defaultFilterChain); + } + + @VisibleForTesting + static FilterChain parseFilterChain( + io.envoyproxy.envoy.config.listener.v3.FilterChain proto, + TlsContextManager tlsContextManager, FilterRegistry filterRegistry, + Set uniqueSet, Set certProviderInstances, boolean parseHttpFilters) + throws ResourceInvalidException { + if (proto.getFiltersCount() != 1) { + throw new ResourceInvalidException("FilterChain " + proto.getName() + + " should contain exact one HttpConnectionManager filter"); + } + io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0); + if (!filter.hasTypedConfig()) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " without typed_config"); + } + Any any = filter.getTypedConfig(); + // HttpConnectionManager is the only supported network filter at the moment. + if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " with unsupported typed_config type " + any.getTypeUrl()); + } + HttpConnectionManager hcmProto; + try { + hcmProto = any.unpack(HttpConnectionManager.class); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " + + filter.getName() + " failed to unpack message", e); + } + io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager( + hcmProto, filterRegistry, parseHttpFilters, false /* isForClient */); + + EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null; + if (proto.hasTransportSocket()) { + if (!TRANSPORT_SOCKET_NAME_TLS.equals(proto.getTransportSocket().getName())) { + throw new ResourceInvalidException("transport-socket with name " + + proto.getTransportSocket().getName() + " not supported."); + } + DownstreamTlsContext downstreamTlsContextProto; + try { + downstreamTlsContextProto = + proto.getTransportSocket().getTypedConfig().unpack(DownstreamTlsContext.class); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException("FilterChain " + proto.getName() + + " failed to unpack message", e); + } + downstreamTlsContext = + EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( + validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances)); + } + + FilterChainMatch filterChainMatch = parseFilterChainMatch(proto.getFilterChainMatch()); + checkForUniqueness(uniqueSet, filterChainMatch); + return FilterChain.create( + proto.getName(), + filterChainMatch, + httpConnectionManager, + downstreamTlsContext, + tlsContextManager + ); + } + + @VisibleForTesting + static DownstreamTlsContext validateDownstreamTlsContext( + DownstreamTlsContext downstreamTlsContext, Set certProviderInstances) + throws ResourceInvalidException { + if (downstreamTlsContext.hasCommonTlsContext()) { + validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), certProviderInstances, + true); + } else { + throw new ResourceInvalidException( + "common-tls-context is required in downstream-tls-context"); + } + if (downstreamTlsContext.hasRequireSni()) { + throw new ResourceInvalidException( + "downstream-tls-context with require-sni is not supported"); + } + DownstreamTlsContext.OcspStaplePolicy ocspStaplePolicy = downstreamTlsContext + .getOcspStaplePolicy(); + if (ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.UNRECOGNIZED + && ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.LENIENT_STAPLING) { + throw new ResourceInvalidException( + "downstream-tls-context with ocsp_staple_policy value " + ocspStaplePolicy.name() + + " is not supported"); + } + return downstreamTlsContext; + } + + private static void checkForUniqueness(Set uniqueSet, + FilterChainMatch filterChainMatch) throws ResourceInvalidException { + if (uniqueSet != null) { + List crossProduct = getCrossProduct(filterChainMatch); + for (FilterChainMatch cur : crossProduct) { + if (!uniqueSet.add(cur)) { + throw new ResourceInvalidException("FilterChainMatch must be unique. " + + "Found duplicate: " + cur); + } + } + } + } + + private static List getCrossProduct(FilterChainMatch filterChainMatch) { + // repeating fields to process: + // prefixRanges, applicationProtocols, sourcePrefixRanges, sourcePorts, serverNames + List expandedList = expandOnPrefixRange(filterChainMatch); + expandedList = expandOnApplicationProtocols(expandedList); + expandedList = expandOnSourcePrefixRange(expandedList); + expandedList = expandOnSourcePorts(expandedList); + return expandOnServerNames(expandedList); + } + + private static List expandOnPrefixRange(FilterChainMatch filterChainMatch) { + ArrayList expandedList = new ArrayList<>(); + if (filterChainMatch.prefixRanges().isEmpty()) { + expandedList.add(filterChainMatch); + } else { + for (CidrRange cidrRange : filterChainMatch.prefixRanges()) { + expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(), + ImmutableList.of(cidrRange), + filterChainMatch.applicationProtocols(), + filterChainMatch.sourcePrefixRanges(), + filterChainMatch.connectionSourceType(), + filterChainMatch.sourcePorts(), + filterChainMatch.serverNames(), + filterChainMatch.transportProtocol())); + } + } + return expandedList; + } + + private static List expandOnApplicationProtocols( + Collection set) { + ArrayList expandedList = new ArrayList<>(); + for (FilterChainMatch filterChainMatch : set) { + if (filterChainMatch.applicationProtocols().isEmpty()) { + expandedList.add(filterChainMatch); + } else { + for (String applicationProtocol : filterChainMatch.applicationProtocols()) { + expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(), + filterChainMatch.prefixRanges(), + ImmutableList.of(applicationProtocol), + filterChainMatch.sourcePrefixRanges(), + filterChainMatch.connectionSourceType(), + filterChainMatch.sourcePorts(), + filterChainMatch.serverNames(), + filterChainMatch.transportProtocol())); + } + } + } + return expandedList; + } + + private static List expandOnSourcePrefixRange( + Collection set) { + ArrayList expandedList = new ArrayList<>(); + for (FilterChainMatch filterChainMatch : set) { + if (filterChainMatch.sourcePrefixRanges().isEmpty()) { + expandedList.add(filterChainMatch); + } else { + for (EnvoyServerProtoData.CidrRange cidrRange : filterChainMatch.sourcePrefixRanges()) { + expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(), + filterChainMatch.prefixRanges(), + filterChainMatch.applicationProtocols(), + ImmutableList.of(cidrRange), + filterChainMatch.connectionSourceType(), + filterChainMatch.sourcePorts(), + filterChainMatch.serverNames(), + filterChainMatch.transportProtocol())); + } + } + } + return expandedList; + } + + private static List expandOnSourcePorts(Collection set) { + ArrayList expandedList = new ArrayList<>(); + for (FilterChainMatch filterChainMatch : set) { + if (filterChainMatch.sourcePorts().isEmpty()) { + expandedList.add(filterChainMatch); + } else { + for (Integer sourcePort : filterChainMatch.sourcePorts()) { + expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(), + filterChainMatch.prefixRanges(), + filterChainMatch.applicationProtocols(), + filterChainMatch.sourcePrefixRanges(), + filterChainMatch.connectionSourceType(), + ImmutableList.of(sourcePort), + filterChainMatch.serverNames(), + filterChainMatch.transportProtocol())); + } + } + } + return expandedList; + } + + private static List expandOnServerNames(Collection set) { + ArrayList expandedList = new ArrayList<>(); + for (FilterChainMatch filterChainMatch : set) { + if (filterChainMatch.serverNames().isEmpty()) { + expandedList.add(filterChainMatch); + } else { + for (String serverName : filterChainMatch.serverNames()) { + expandedList.add(FilterChainMatch.create(filterChainMatch.destinationPort(), + filterChainMatch.prefixRanges(), + filterChainMatch.applicationProtocols(), + filterChainMatch.sourcePrefixRanges(), + filterChainMatch.connectionSourceType(), + filterChainMatch.sourcePorts(), + ImmutableList.of(serverName), + filterChainMatch.transportProtocol())); + } + } + } + return expandedList; + } + + private static FilterChainMatch parseFilterChainMatch( + io.envoyproxy.envoy.config.listener.v3.FilterChainMatch proto) + throws ResourceInvalidException { + ImmutableList.Builder prefixRanges = ImmutableList.builder(); + ImmutableList.Builder sourcePrefixRanges = ImmutableList.builder(); + try { + for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getPrefixRangesList()) { + prefixRanges.add( + CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + } + for (io.envoyproxy.envoy.config.core.v3.CidrRange range + : proto.getSourcePrefixRangesList()) { + sourcePrefixRanges.add( + CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + } + } catch (UnknownHostException e) { + throw new ResourceInvalidException("Failed to create CidrRange", e); + } + ConnectionSourceType sourceType; + switch (proto.getSourceType()) { + case ANY: + sourceType = ConnectionSourceType.ANY; + break; + case EXTERNAL: + sourceType = ConnectionSourceType.EXTERNAL; + break; + case SAME_IP_OR_LOOPBACK: + sourceType = ConnectionSourceType.SAME_IP_OR_LOOPBACK; + break; + default: + throw new ResourceInvalidException("Unknown source-type: " + proto.getSourceType()); + } + return FilterChainMatch.create( + proto.getDestinationPort().getValue(), + prefixRanges.build(), + ImmutableList.copyOf(proto.getApplicationProtocolsList()), + sourcePrefixRanges.build(), + sourceType, + ImmutableList.copyOf(proto.getSourcePortsList()), + ImmutableList.copyOf(proto.getServerNamesList()), + proto.getTransportProtocol()); + } + + @VisibleForTesting + static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( + HttpConnectionManager proto, FilterRegistry filterRegistry, + boolean parseHttpFilter, boolean isForClient) throws ResourceInvalidException { + if (enableRbac && proto.getXffNumTrustedHops() != 0) { + throw new ResourceInvalidException( + "HttpConnectionManager with xff_num_trusted_hops unsupported"); + } + if (enableRbac && !proto.getOriginalIpDetectionExtensionsList().isEmpty()) { + throw new ResourceInvalidException("HttpConnectionManager with " + + "original_ip_detection_extensions unsupported"); + } + // Obtain max_stream_duration from Http Protocol Options. + long maxStreamDuration = 0; + if (proto.hasCommonHttpProtocolOptions()) { + HttpProtocolOptions options = proto.getCommonHttpProtocolOptions(); + if (options.hasMaxStreamDuration()) { + maxStreamDuration = Durations.toNanos(options.getMaxStreamDuration()); + } + } + + // Parse http filters. + List filterConfigs = null; + if (parseHttpFilter) { + if (proto.getHttpFiltersList().isEmpty()) { + throw new ResourceInvalidException("Missing HttpFilter in HttpConnectionManager."); + } + filterConfigs = new ArrayList<>(); + Set names = new HashSet<>(); + for (int i = 0; i < proto.getHttpFiltersCount(); i++) { + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter + httpFilter = proto.getHttpFiltersList().get(i); + String filterName = httpFilter.getName(); + if (!names.add(filterName)) { + throw new ResourceInvalidException( + "HttpConnectionManager contains duplicate HttpFilter: " + filterName); + } + StructOrError filterConfig = + parseHttpFilter(httpFilter, filterRegistry, isForClient); + if ((i == proto.getHttpFiltersCount() - 1) + && (filterConfig == null || !isTerminalFilter(filterConfig.getStruct()))) { + throw new ResourceInvalidException("The last HttpFilter must be a terminal filter: " + + filterName); + } + if (filterConfig == null) { + continue; + } + if (filterConfig.getErrorDetail() != null) { + throw new ResourceInvalidException( + "HttpConnectionManager contains invalid HttpFilter: " + + filterConfig.getErrorDetail()); + } + if ((i < proto.getHttpFiltersCount() - 1) && isTerminalFilter(filterConfig.getStruct())) { + throw new ResourceInvalidException("A terminal HttpFilter must be the last filter: " + + filterName); + } + filterConfigs.add(new Filter.NamedFilterConfig(filterName, filterConfig.getStruct())); + } + } + + // Parse inlined RouteConfiguration or RDS. + if (proto.hasRouteConfig()) { + List virtualHosts = extractVirtualHosts( + proto.getRouteConfig(), filterRegistry, parseHttpFilter); + return io.grpc.xds.HttpConnectionManager.forVirtualHosts( + maxStreamDuration, virtualHosts, filterConfigs); + } + if (proto.hasRds()) { + Rds rds = proto.getRds(); + if (!rds.hasConfigSource()) { + throw new ResourceInvalidException( + "HttpConnectionManager contains invalid RDS: missing config_source"); + } + if (!rds.getConfigSource().hasAds() && !rds.getConfigSource().hasSelf()) { + throw new ResourceInvalidException( + "HttpConnectionManager contains invalid RDS: must specify ADS or self ConfigSource"); + } + return io.grpc.xds.HttpConnectionManager.forRdsName( + maxStreamDuration, rds.getRouteConfigName(), filterConfigs); + } + throw new ResourceInvalidException( + "HttpConnectionManager neither has inlined route_config nor RDS"); + } + + // hard-coded: currently router config is the only terminal filter. + private static boolean isTerminalFilter(Filter.FilterConfig filterConfig) { + return RouterFilter.ROUTER_CONFIG.equals(filterConfig); + } + + @VisibleForTesting + @Nullable // Returns null if the filter is optional but not supported. + static StructOrError parseHttpFilter( + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter + httpFilter, FilterRegistry filterRegistry, boolean isForClient) { + String filterName = httpFilter.getName(); + boolean isOptional = httpFilter.getIsOptional(); + if (!httpFilter.hasTypedConfig()) { + if (isOptional) { + return null; + } else { + return StructOrError.fromError( + "HttpFilter [" + filterName + "] is not optional and has no typed config"); + } + } + Message rawConfig = httpFilter.getTypedConfig(); + String typeUrl = httpFilter.getTypedConfig().getTypeUrl(); + + try { + if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA)) { + TypedStruct typedStruct = httpFilter.getTypedConfig().unpack(TypedStruct.class); + typeUrl = typedStruct.getTypeUrl(); + rawConfig = typedStruct.getValue(); + } else if (typeUrl.equals(TYPE_URL_TYPED_STRUCT)) { + com.github.xds.type.v3.TypedStruct newTypedStruct = + httpFilter.getTypedConfig().unpack(com.github.xds.type.v3.TypedStruct.class); + typeUrl = newTypedStruct.getTypeUrl(); + rawConfig = newTypedStruct.getValue(); + } + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError( + "HttpFilter [" + filterName + "] contains invalid proto: " + e); + } + Filter filter = filterRegistry.get(typeUrl); + if ((isForClient && !(filter instanceof Filter.ClientInterceptorBuilder)) + || (!isForClient && !(filter instanceof Filter.ServerInterceptorBuilder))) { + if (isOptional) { + return null; + } else { + return StructOrError.fromError( + "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " + + (isForClient ? "client" : "server")); + } + } + ConfigOrError filterConfig = filter.parseFilterConfig(rawConfig); + if (filterConfig.errorDetail != null) { + return StructOrError.fromError( + "Invalid filter config for HttpFilter [" + filterName + "]: " + filterConfig.errorDetail); + } + return StructOrError.fromStruct(filterConfig.config); + } + + @AutoValue + abstract static class LdsUpdate implements ResourceUpdate { + // Http level api listener configuration. + @Nullable + abstract io.grpc.xds.HttpConnectionManager httpConnectionManager(); + + // Tcp level listener configuration. + @Nullable + abstract EnvoyServerProtoData.Listener listener(); + + static LdsUpdate forApiListener(io.grpc.xds.HttpConnectionManager httpConnectionManager) { + checkNotNull(httpConnectionManager, "httpConnectionManager"); + return new AutoValue_XdsListenerResource_LdsUpdate(httpConnectionManager, null); + } + + static LdsUpdate forTcpListener(EnvoyServerProtoData.Listener listener) { + checkNotNull(listener, "listener"); + return new AutoValue_XdsListenerResource_LdsUpdate(null, listener); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsLogger.java b/xds/src/main/java/io/grpc/xds/XdsLogger.java index 616602536f9..7bcf190bb18 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLogger.java +++ b/xds/src/main/java/io/grpc/xds/XdsLogger.java @@ -81,6 +81,10 @@ private static Level toJavaLogLevel(XdsLogLevel level) { return Level.FINE; case INFO: return Level.FINER; + case FORCE_INFO: + return Level.INFO; + case FORCE_WARNING: + return Level.WARNING; default: return Level.FINEST; } @@ -89,6 +93,11 @@ private static Level toJavaLogLevel(XdsLogLevel level) { /** * Log levels. See the table below for the mapping from the XdsLogger levels to * Java logger levels. + * + *

    NOTE: + * Please use {@code FORCE_} levels with care, only when the message is expected to be + * surfaced to the library user. Normally libraries should minimize the usage + * of highly visible logs. *

        * +---------------------+-------------------+
        * | XdsLogger Level     | Java Logger Level |
    @@ -97,6 +106,8 @@ private static Level toJavaLogLevel(XdsLogLevel level) {
        * | INFO                | FINER             |
        * | WARNING             | FINE              |
        * | ERROR               | FINE              |
    +   * | FORCE_INFO          | INFO              |
    +   * | FORCE_WARNING       | WARNING           |
        * +---------------------+-------------------+
        * 
    */ @@ -104,6 +115,8 @@ enum XdsLogLevel { DEBUG, INFO, WARNING, - ERROR + ERROR, + FORCE_INFO, + FORCE_WARNING, } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index b15af622256..094bb944d85 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -47,7 +47,6 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.AuthorityInfo; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; @@ -63,13 +62,12 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; -import io.grpc.xds.XdsClient.LdsResourceWatcher; -import io.grpc.xds.XdsClient.LdsUpdate; -import io.grpc.xds.XdsClient.RdsResourceWatcher; -import io.grpc.xds.XdsClient.RdsUpdate; +import io.grpc.xds.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.internal.Matchers.FractionMatcher; import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.util.ArrayList; @@ -111,6 +109,7 @@ final class XdsNameResolver extends NameResolver { @Nullable private final String targetAuthority; private final String serviceAuthority; + private final String overrideAuthority; private final ServiceConfigParser serviceConfigParser; private final SynchronizationContext syncContext; private final ScheduledExecutorService scheduler; @@ -122,6 +121,7 @@ final class XdsNameResolver extends NameResolver { // put()/remove() must be called in SyncContext, and get() can be called in any thread. private final ConcurrentMap clusterRefs = new ConcurrentHashMap<>(); private final ConfigSelector configSelector = new ConfigSelector(); + private final long randomChannelId; private volatile RoutingConfig routingConfig = RoutingConfig.empty; private Listener2 listener; @@ -129,24 +129,30 @@ final class XdsNameResolver extends NameResolver { private XdsClient xdsClient; private CallCounterProvider callCounterProvider; private ResolveState resolveState; + // Workaround for https://github.com/grpc/grpc-java/issues/8886 . This should be handled in + // XdsClient instead of here. + private boolean receivedConfig; XdsNameResolver( - @Nullable String targetAuthority, String name, ServiceConfigParser serviceConfigParser, + @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, + ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, @Nullable Map bootstrapOverride) { - this(targetAuthority, name, serviceConfigParser, syncContext, scheduler, + this(targetAuthority, name, overrideAuthority, serviceConfigParser, syncContext, scheduler, SharedXdsClientPoolProvider.getDefaultProvider(), ThreadSafeRandomImpl.instance, FilterRegistry.getDefaultRegistry(), bootstrapOverride); } @VisibleForTesting XdsNameResolver( - @Nullable String targetAuthority, String name, ServiceConfigParser serviceConfigParser, + @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, + ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { this.targetAuthority = targetAuthority; serviceAuthority = GrpcUtil.checkAuthority(checkNotNull(name, "name")); + this.overrideAuthority = overrideAuthority; this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.scheduler = checkNotNull(scheduler, "scheduler"); @@ -155,6 +161,7 @@ final class XdsNameResolver extends NameResolver { this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); + randomChannelId = random.nextLong(); logId = InternalLogId.allocate("xds-resolver", name); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created resolver for {0}", name); @@ -194,8 +201,8 @@ public void start(Listener2 listener) { replacement = XdsClient.percentEncodePath(replacement); } String ldsResourceName = expandPercentS(listenerNameTemplate, replacement); - if (!XdsClient.isResourceNameValid(ldsResourceName, ResourceType.LDS.typeUrl()) - && !XdsClient.isResourceNameValid(ldsResourceName, ResourceType.LDS.typeUrlV2())) { + if (!XdsClient.isResourceNameValid(ldsResourceName, XdsListenerResource.getInstance().typeUrl()) + ) { listener.onError(Status.INVALID_ARGUMENT.withDescription( "invalid listener resource URI for service authority: " + serviceAuthority)); return; @@ -247,14 +254,14 @@ public void shutdown() { rawRetryPolicy.put( "perAttemptRecvTimeout", Durations.toString(retryPolicy.perAttemptRecvTimeout())); } - methodConfig.put("retryPolicy", rawRetryPolicy.build()); + methodConfig.put("retryPolicy", rawRetryPolicy.buildOrThrow()); } if (timeoutNano != null) { String timeout = timeoutNano / 1_000_000_000.0 + "s"; methodConfig.put("timeout", timeout); } return Collections.singletonMap( - "methodConfig", Collections.singletonList(methodConfig.build())); + "methodConfig", Collections.singletonList(methodConfig.buildOrThrow())); } @VisibleForTesting @@ -274,7 +281,8 @@ private void updateResolutionResult() { Map rawServiceConfig = ImmutableMap.of( "loadBalancingConfig", ImmutableList.of(ImmutableMap.of( - "cluster_manager_experimental", ImmutableMap.of("childPolicy", childPolicy.build())))); + XdsLbPolicies.CLUSTER_MANAGER_POLICY_NAME, + ImmutableMap.of("childPolicy", childPolicy.buildOrThrow())))); if (logger.isLoggable(XdsLogLevel.INFO)) { logger.log( @@ -293,6 +301,7 @@ private void updateResolutionResult() { .setServiceConfig(parsedServiceConfig) .build(); listener.onResult(result); + receivedConfig = true; } @VisibleForTesting @@ -577,7 +586,7 @@ private long generateHash(List hashPolicies, Metadata headers) { newHash = hashFunc.hashAsciiString(value); } } else if (policy.type() == HashPolicy.Type.CHANNEL_ID) { - newHash = hashFunc.hashLong(logId.getId()); + newHash = hashFunc.hashLong(randomChannelId); } if (newHash != null ) { // Rotating the old value prevents duplicate hash rules from cancelling each other out @@ -664,14 +673,22 @@ private static String prefixedClusterSpecifierPluginName(String pluginName) { return "cluster_specifier_plugin:" + pluginName; } - private class ResolveState implements LdsResourceWatcher { + private static final class FailingConfigSelector extends InternalConfigSelector { + private final Result result; + + public FailingConfigSelector(Status error) { + this.result = Result.forError(error); + } + + @Override + public Result selectConfig(PickSubchannelArgs args) { + return result; + } + } + + private class ResolveState implements ResourceWatcher { private final ConfigOrError emptyServiceConfig = serviceConfigParser.parseServiceConfig(Collections.emptyMap()); - private final ResolutionResult emptyResult = - ResolutionResult.newBuilder() - .setServiceConfig(emptyServiceConfig) - // let channel take action for no config selector - .build(); private final String ldsResourceName; private boolean stopped; @Nullable @@ -704,7 +721,8 @@ public void run() { rdsName, httpConnectionManager.httpMaxStreamDurationNano(), httpConnectionManager.httpFilterConfigs()); logger.log(XdsLogLevel.INFO, "Start watching RDS resource {0}", rdsName); - xdsClient.watchRdsResource(rdsName, routeDiscoveryState); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + rdsName, routeDiscoveryState); } } }); @@ -715,10 +733,12 @@ public void onError(final Status error) { syncContext.execute(new Runnable() { @Override public void run() { - if (stopped) { + if (stopped || receivedConfig) { return; } - listener.onError(error); + listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( + String.format("Unable to load LDS %s. xDS server returned: %s: %s", + ldsResourceName, error.getCode(), error.getDescription()))); } }); } @@ -731,33 +751,35 @@ public void run() { if (stopped) { return; } - logger.log(XdsLogLevel.INFO, "LDS resource {0} unavailable", resourceName); + String error = "LDS resource does not exist: " + resourceName; + logger.log(XdsLogLevel.INFO, error); cleanUpRouteDiscoveryState(); - cleanUpRoutes(); + cleanUpRoutes(error); } }); } private void start() { logger.log(XdsLogLevel.INFO, "Start watching LDS resource {0}", ldsResourceName); - xdsClient.watchLdsResource(ldsResourceName, this); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceName, this); } private void stop() { logger.log(XdsLogLevel.INFO, "Stop watching LDS resource {0}", ldsResourceName); stopped = true; cleanUpRouteDiscoveryState(); - xdsClient.cancelLdsResourceWatch(ldsResourceName, this); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), ldsResourceName, this); } // called in syncContext private void updateRoutes(List virtualHosts, long httpMaxStreamDurationNano, @Nullable List filterConfigs) { - VirtualHost virtualHost = findVirtualHostForHostName(virtualHosts, ldsResourceName); + String authority = overrideAuthority != null ? overrideAuthority : ldsResourceName; + VirtualHost virtualHost = findVirtualHostForHostName(virtualHosts, authority); if (virtualHost == null) { - logger.log(XdsLogLevel.WARNING, - "Failed to find virtual host matching hostname {0}", ldsResourceName); - cleanUpRoutes(); + String error = "Failed to find virtual host matching hostname: " + authority; + logger.log(XdsLogLevel.WARNING, error); + cleanUpRoutes(error); return; } @@ -801,7 +823,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura existingClusters == null ? clusters : Sets.difference(clusters, existingClusters); Set deletedClusters = existingClusters == null - ? Collections.emptySet() : Sets.difference(existingClusters, clusters); + ? Collections.emptySet() : Sets.difference(existingClusters, clusters); existingClusters = clusters; for (String cluster : addedClusters) { if (clusterRefs.containsKey(cluster)) { @@ -853,7 +875,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura } } - private void cleanUpRoutes() { + private void cleanUpRoutes(String error) { if (existingClusters != null) { for (String cluster : existingClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); @@ -864,14 +886,26 @@ private void cleanUpRoutes() { existingClusters = null; } routingConfig = RoutingConfig.empty; - listener.onResult(emptyResult); + // Without addresses the default LB (normally pick_first) should become TRANSIENT_FAILURE, and + // the config selector handles the error message itself. Once the LB API allows providing + // failure information for addresses yet still providing a service config, the config seector + // could be avoided. + listener.onResult(ResolutionResult.newBuilder() + .setAttributes(Attributes.newBuilder() + .set(InternalConfigSelector.KEY, + new FailingConfigSelector(Status.UNAVAILABLE.withDescription(error))) + .build()) + .setServiceConfig(emptyServiceConfig) + .build()); + receivedConfig = true; } private void cleanUpRouteDiscoveryState() { if (routeDiscoveryState != null) { String rdsName = routeDiscoveryState.resourceName; logger.log(XdsLogLevel.INFO, "Stop watching RDS resource {0}", rdsName); - xdsClient.cancelRdsResourceWatch(rdsName, routeDiscoveryState); + xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), rdsName, + routeDiscoveryState); routeDiscoveryState = null; } } @@ -880,7 +914,7 @@ private void cleanUpRouteDiscoveryState() { * Discovery state for RouteConfiguration resource. One instance for each Listener resource * update. */ - private class RouteDiscoveryState implements RdsResourceWatcher { + private class RouteDiscoveryState implements ResourceWatcher { private final String resourceName; private final long httpMaxStreamDurationNano; @Nullable @@ -902,7 +936,8 @@ public void run() { return; } logger.log(XdsLogLevel.INFO, "Received RDS resource update: {0}", update); - updateRoutes(update.virtualHosts, httpMaxStreamDurationNano, filterConfigs); + updateRoutes(update.virtualHosts, httpMaxStreamDurationNano, + filterConfigs); } }); } @@ -912,10 +947,12 @@ public void onError(final Status error) { syncContext.execute(new Runnable() { @Override public void run() { - if (RouteDiscoveryState.this != routeDiscoveryState) { + if (RouteDiscoveryState.this != routeDiscoveryState || receivedConfig) { return; } - listener.onError(error); + listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( + String.format("Unable to load RDS %s. xDS server returned: %s: %s", + resourceName, error.getCode(), error.getDescription()))); } }); } @@ -928,8 +965,9 @@ public void run() { if (RouteDiscoveryState.this != routeDiscoveryState) { return; } - logger.log(XdsLogLevel.INFO, "RDS resource {0} unavailable", resourceName); - cleanUpRoutes(); + String error = "RDS resource does not exist: " + resourceName; + logger.log(XdsLogLevel.INFO, error); + cleanUpRoutes(error); } }); } @@ -947,7 +985,7 @@ private static class RoutingConfig { final Map virtualHostOverrideConfig; private static RoutingConfig empty = new RoutingConfig( - 0L, Collections.emptyList(), null, Collections.emptyMap()); + 0, Collections.emptyList(), null, Collections.emptyMap()); private RoutingConfig( long fallbackTimeoutNano, List routes, @Nullable List filterChain, @@ -979,15 +1017,17 @@ private ClusterRefState( private Map toLbPolicy() { if (traditionalCluster != null) { - return ImmutableMap.of("cds_experimental", ImmutableMap.of("cluster", traditionalCluster)); + return ImmutableMap.of( + XdsLbPolicies.CDS_POLICY_NAME, + ImmutableMap.of("cluster", traditionalCluster)); } else { ImmutableMap rlsConfig = new ImmutableMap.Builder() .put("routeLookupConfig", rlsPluginConfig.config()) .put( "childPolicy", - ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of()))) + ImmutableList.of(ImmutableMap.of(XdsLbPolicies.CDS_POLICY_NAME, ImmutableMap.of()))) .put("childPolicyConfigTargetFieldName", "cluster") - .build(); + .buildOrThrow(); return ImmutableMap.of("rls_experimental", rlsConfig); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index a02e27c37c7..4875a85ea63 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -23,7 +23,11 @@ import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.ObjectPool; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; @@ -75,8 +79,9 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetUri); String name = targetPath.substring(1); return new XdsNameResolver( - targetUri.getAuthority(), name, args.getServiceConfigParser(), - args.getSynchronizationContext(), args.getScheduledExecutorService(), + targetUri.getAuthority(), name, args.getOverrideAuthority(), + args.getServiceConfigParser(), args.getSynchronizationContext(), + args.getScheduledExecutorService(), bootstrapOverride); } return null; @@ -99,6 +104,11 @@ protected int priority() { return 4; } + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + interface XdsClientPoolFactory { void setBootstrapOverride(Map bootstrap); diff --git a/xds/src/main/java/io/grpc/xds/XdsResourceType.java b/xds/src/main/java/io/grpc/xds/XdsResourceType.java new file mode 100644 index 00000000000..1302f5a59e1 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsResourceType.java @@ -0,0 +1,295 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.Bootstrapper.ServerInfo; +import static io.grpc.xds.XdsClient.ResourceUpdate; +import static io.grpc.xds.XdsClient.canonifyResourceName; +import static io.grpc.xds.XdsClient.isResourceNameValid; +import static io.grpc.xds.XdsClientImpl.ResourceInvalidException; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.service.discovery.v3.Resource; +import io.grpc.LoadBalancerRegistry; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; + +abstract class XdsResourceType { + static final String TYPE_URL_RESOURCE = + "type.googleapis.com/envoy.service.discovery.v3.Resource"; + static final String TRANSPORT_SOCKET_NAME_TLS = "envoy.transport_sockets.tls"; + @VisibleForTesting + static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate"; + @VisibleForTesting + static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id"; + @VisibleForTesting + static boolean enableFaultInjection = getFlag("GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION", true); + @VisibleForTesting + static boolean enableRetry = getFlag("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY", true); + @VisibleForTesting + static boolean enableRbac = getFlag("GRPC_XDS_EXPERIMENTAL_RBAC", true); + @VisibleForTesting + static boolean enableRouteLookup = getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", false); + @VisibleForTesting + static boolean enableLeastRequest = + !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); + @VisibleForTesting + static boolean enableCustomLbConfig = getFlag("GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG", true); + @VisibleForTesting + static boolean enableOutlierDetection = getFlag("GRPC_EXPERIMENTAL_ENABLE_OUTLIER_DETECTION", + true); + static final String TYPE_URL_CLUSTER_CONFIG = + "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig"; + static final String TYPE_URL_TYPED_STRUCT_UDPA = + "type.googleapis.com/udpa.type.v1.TypedStruct"; + static final String TYPE_URL_TYPED_STRUCT = + "type.googleapis.com/xds.type.v3.TypedStruct"; + + @Nullable + abstract String extractResourceName(Message unpackedResource); + + abstract Class unpackedClassName(); + + abstract String typeName(); + + abstract String typeUrl(); + + // Do not confuse with the SotW approach: it is the mechanism in which the client must specify all + // resource names it is interested in with each request. Different resource types may behave + // differently in this approach. For LDS and CDS resources, the server must return all resources + // that the client has subscribed to in each request. For RDS and EDS, the server may only return + // the resources that need an update. + abstract boolean isFullStateOfTheWorld(); + + static class Args { + final ServerInfo serverInfo; + final String versionInfo; + final String nonce; + final Bootstrapper.BootstrapInfo bootstrapInfo; + final FilterRegistry filterRegistry; + final LoadBalancerRegistry loadBalancerRegistry; + final TlsContextManager tlsContextManager; + // Management server is required to always send newly requested resources, even if they + // may have been sent previously (proactively). Thus, client does not need to cache + // unrequested resources. + // Only resources in the set needs to be parsed. Null means parse everything. + final @Nullable Set subscribedResources; + + public Args(ServerInfo serverInfo, String versionInfo, String nonce, + Bootstrapper.BootstrapInfo bootstrapInfo, + FilterRegistry filterRegistry, + LoadBalancerRegistry loadBalancerRegistry, + TlsContextManager tlsContextManager, + @Nullable Set subscribedResources) { + this.serverInfo = serverInfo; + this.versionInfo = versionInfo; + this.nonce = nonce; + this.bootstrapInfo = bootstrapInfo; + this.filterRegistry = filterRegistry; + this.loadBalancerRegistry = loadBalancerRegistry; + this.tlsContextManager = tlsContextManager; + this.subscribedResources = subscribedResources; + } + } + + ValidatedResourceUpdate parse(Args args, List resources) { + Map> parsedResources = new HashMap<>(resources.size()); + Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); + List errors = new ArrayList<>(); + + for (int i = 0; i < resources.size(); i++) { + Any resource = resources.get(i); + + Message unpackedMessage; + try { + resource = maybeUnwrapResources(resource); + unpackedMessage = unpackCompatibleType(resource, unpackedClassName(), typeUrl(), null); + } catch (InvalidProtocolBufferException e) { + errors.add(String.format("%s response Resource index %d - can't decode %s: %s", + typeName(), i, unpackedClassName().getSimpleName(), e.getMessage())); + continue; + } + String name = extractResourceName(unpackedMessage); + if (name == null || !isResourceNameValid(name, resource.getTypeUrl())) { + errors.add( + "Unsupported resource name: " + name + " for type: " + typeName()); + continue; + } + String cname = canonifyResourceName(name); + if (args.subscribedResources != null && !args.subscribedResources.contains(name)) { + continue; + } + unpackedResources.add(cname); + + T resourceUpdate; + try { + resourceUpdate = doParse(args, unpackedMessage); + } catch (XdsClientImpl.ResourceInvalidException e) { + errors.add(String.format("%s response %s '%s' validation error: %s", + typeName(), unpackedClassName().getSimpleName(), cname, e.getMessage())); + invalidResources.add(cname); + continue; + } + + // Resource parsed successfully. + parsedResources.put(cname, new ParsedResource(resourceUpdate, resource)); + } + return new ValidatedResourceUpdate(parsedResources, unpackedResources, invalidResources, + errors); + + } + + abstract T doParse(Args args, Message unpackedMessage) throws ResourceInvalidException; + + /** + * Helper method to unpack serialized {@link com.google.protobuf.Any} message, while replacing + * Type URL {@code compatibleTypeUrl} with {@code typeUrl}. + * + * @param The type of unpacked message + * @param any serialized message to unpack + * @param clazz the class to unpack the message to + * @param typeUrl type URL to replace message Type URL, when it's compatible + * @param compatibleTypeUrl compatible Type URL to be replaced with {@code typeUrl} + * @return Unpacked message + * @throws InvalidProtocolBufferException if the message couldn't be unpacked + */ + static T unpackCompatibleType( + Any any, Class clazz, String typeUrl, String compatibleTypeUrl) + throws InvalidProtocolBufferException { + if (any.getTypeUrl().equals(compatibleTypeUrl)) { + any = any.toBuilder().setTypeUrl(typeUrl).build(); + } + return any.unpack(clazz); + } + + private Any maybeUnwrapResources(Any resource) + throws InvalidProtocolBufferException { + if (resource.getTypeUrl().equals(TYPE_URL_RESOURCE)) { + return unpackCompatibleType(resource, Resource.class, TYPE_URL_RESOURCE, + null).getResource(); + } else { + return resource; + } + } + + static final class ParsedResource { + private final T resourceUpdate; + private final Any rawResource; + + public ParsedResource(T resourceUpdate, Any rawResource) { + this.resourceUpdate = checkNotNull(resourceUpdate, "resourceUpdate"); + this.rawResource = checkNotNull(rawResource, "rawResource"); + } + + T getResourceUpdate() { + return resourceUpdate; + } + + Any getRawResource() { + return rawResource; + } + } + + static final class ValidatedResourceUpdate { + Map> parsedResources; + Set unpackedResources; + Set invalidResources; + List errors; + + // validated resource update + public ValidatedResourceUpdate(Map> parsedResources, + Set unpackedResources, + Set invalidResources, + List errors) { + this.parsedResources = parsedResources; + this.unpackedResources = unpackedResources; + this.invalidResources = invalidResources; + this.errors = errors; + } + } + + private static boolean getFlag(String envVarName, boolean enableByDefault) { + String envVar = System.getenv(envVarName); + if (enableByDefault) { + return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); + } else { + return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar); + } + } + + @VisibleForTesting + static final class StructOrError { + + /** + * Returns a {@link StructOrError} for the successfully converted data object. + */ + static StructOrError fromStruct(T struct) { + return new StructOrError<>(struct); + } + + /** + * Returns a {@link StructOrError} for the failure to convert the data object. + */ + static StructOrError fromError(String errorDetail) { + return new StructOrError<>(errorDetail); + } + + private final String errorDetail; + private final T struct; + + private StructOrError(T struct) { + this.struct = checkNotNull(struct, "struct"); + this.errorDetail = null; + } + + private StructOrError(String errorDetail) { + this.struct = null; + this.errorDetail = checkNotNull(errorDetail, "errorDetail"); + } + + /** + * Returns struct if exists, otherwise null. + */ + @VisibleForTesting + @Nullable + T getStruct() { + return struct; + } + + /** + * Returns error detail if exists, otherwise null. + */ + @VisibleForTesting + @Nullable + String getErrorDetail() { + return errorDetail; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java new file mode 100644 index 00000000000..ed109fd694b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -0,0 +1,673 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.github.udpa.udpa.type.v1.TypedStruct; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.Any; +import com.google.protobuf.Duration; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.util.Durations; +import com.google.re2j.Pattern; +import com.google.re2j.PatternSyntaxException; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin; +import io.envoyproxy.envoy.config.route.v3.RetryPolicy.RetryBackOff; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.grpc.Status; +import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig; +import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.VirtualHost.Route.RouteAction; +import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; +import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; +import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import javax.annotation.Nullable; + +class XdsRouteConfigureResource extends XdsResourceType { + static final String ADS_TYPE_URL_RDS = + "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; + private static final String TYPE_URL_FILTER_CONFIG = + "type.googleapis.com/envoy.config.route.v3.FilterConfig"; + // TODO(zdapeng): need to discuss how to handle unsupported values. + private static final Set SUPPORTED_RETRYABLE_CODES = + Collections.unmodifiableSet(EnumSet.of( + Status.Code.CANCELLED, Status.Code.DEADLINE_EXCEEDED, Status.Code.INTERNAL, + Status.Code.RESOURCE_EXHAUSTED, Status.Code.UNAVAILABLE)); + + private static final XdsRouteConfigureResource instance = new XdsRouteConfigureResource(); + + public static XdsRouteConfigureResource getInstance() { + return instance; + } + + @Override + @Nullable + String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof RouteConfiguration)) { + return null; + } + return ((RouteConfiguration) unpackedResource).getName(); + } + + @Override + String typeName() { + return "RDS"; + } + + @Override + String typeUrl() { + return ADS_TYPE_URL_RDS; + } + + @Override + boolean isFullStateOfTheWorld() { + return false; + } + + @Override + Class unpackedClassName() { + return RouteConfiguration.class; + } + + @Override + RdsUpdate doParse(XdsResourceType.Args args, Message unpackedMessage) + throws ResourceInvalidException { + if (!(unpackedMessage instanceof RouteConfiguration)) { + throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); + } + return processRouteConfiguration((RouteConfiguration) unpackedMessage, + args.filterRegistry, enableFaultInjection); + } + + private static RdsUpdate processRouteConfiguration( + RouteConfiguration routeConfig, FilterRegistry filterRegistry, boolean parseHttpFilter) + throws ResourceInvalidException { + return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry, parseHttpFilter)); + } + + static List extractVirtualHosts( + RouteConfiguration routeConfig, FilterRegistry filterRegistry, boolean parseHttpFilter) + throws ResourceInvalidException { + Map pluginConfigMap = new HashMap<>(); + ImmutableSet.Builder optionalPlugins = ImmutableSet.builder(); + + if (enableRouteLookup) { + List plugins = routeConfig.getClusterSpecifierPluginsList(); + for (ClusterSpecifierPlugin plugin : plugins) { + String pluginName = plugin.getExtension().getName(); + PluginConfig pluginConfig = parseClusterSpecifierPlugin(plugin); + if (pluginConfig != null) { + if (pluginConfigMap.put(pluginName, pluginConfig) != null) { + throw new ResourceInvalidException( + "Multiple ClusterSpecifierPlugins with the same name: " + pluginName); + } + } else { + // The plugin parsed successfully, and it's not supported, but it's marked as optional. + optionalPlugins.add(pluginName); + } + } + } + List virtualHosts = new ArrayList<>(routeConfig.getVirtualHostsCount()); + for (io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHostProto + : routeConfig.getVirtualHostsList()) { + StructOrError virtualHost = + parseVirtualHost(virtualHostProto, filterRegistry, parseHttpFilter, pluginConfigMap, + optionalPlugins.build()); + if (virtualHost.getErrorDetail() != null) { + throw new ResourceInvalidException( + "RouteConfiguration contains invalid virtual host: " + virtualHost.getErrorDetail()); + } + virtualHosts.add(virtualHost.getStruct()); + } + return virtualHosts; + } + + private static StructOrError parseVirtualHost( + io.envoyproxy.envoy.config.route.v3.VirtualHost proto, FilterRegistry filterRegistry, + boolean parseHttpFilter, Map pluginConfigMap, + Set optionalPlugins) { + String name = proto.getName(); + List routes = new ArrayList<>(proto.getRoutesCount()); + for (io.envoyproxy.envoy.config.route.v3.Route routeProto : proto.getRoutesList()) { + StructOrError route = parseRoute( + routeProto, filterRegistry, parseHttpFilter, pluginConfigMap, optionalPlugins); + if (route == null) { + continue; + } + if (route.getErrorDetail() != null) { + return StructOrError.fromError( + "Virtual host [" + name + "] contains invalid route : " + route.getErrorDetail()); + } + routes.add(route.getStruct()); + } + if (!parseHttpFilter) { + return StructOrError.fromStruct(VirtualHost.create( + name, proto.getDomainsList(), routes, new HashMap())); + } + StructOrError> overrideConfigs = + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + if (overrideConfigs.getErrorDetail() != null) { + return StructOrError.fromError( + "VirtualHost [" + proto.getName() + "] contains invalid HttpFilter config: " + + overrideConfigs.getErrorDetail()); + } + return StructOrError.fromStruct(VirtualHost.create( + name, proto.getDomainsList(), routes, overrideConfigs.getStruct())); + } + + @VisibleForTesting + static StructOrError> parseOverrideFilterConfigs( + Map rawFilterConfigMap, FilterRegistry filterRegistry) { + Map overrideConfigs = new HashMap<>(); + for (String name : rawFilterConfigMap.keySet()) { + Any anyConfig = rawFilterConfigMap.get(name); + String typeUrl = anyConfig.getTypeUrl(); + boolean isOptional = false; + if (typeUrl.equals(TYPE_URL_FILTER_CONFIG)) { + io.envoyproxy.envoy.config.route.v3.FilterConfig filterConfig; + try { + filterConfig = + anyConfig.unpack(io.envoyproxy.envoy.config.route.v3.FilterConfig.class); + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError( + "FilterConfig [" + name + "] contains invalid proto: " + e); + } + isOptional = filterConfig.getIsOptional(); + anyConfig = filterConfig.getConfig(); + typeUrl = anyConfig.getTypeUrl(); + } + Message rawConfig = anyConfig; + try { + if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA)) { + TypedStruct typedStruct = anyConfig.unpack(TypedStruct.class); + typeUrl = typedStruct.getTypeUrl(); + rawConfig = typedStruct.getValue(); + } else if (typeUrl.equals(TYPE_URL_TYPED_STRUCT)) { + com.github.xds.type.v3.TypedStruct newTypedStruct = + anyConfig.unpack(com.github.xds.type.v3.TypedStruct.class); + typeUrl = newTypedStruct.getTypeUrl(); + rawConfig = newTypedStruct.getValue(); + } + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError( + "FilterConfig [" + name + "] contains invalid proto: " + e); + } + Filter filter = filterRegistry.get(typeUrl); + if (filter == null) { + if (isOptional) { + continue; + } + return StructOrError.fromError( + "HttpFilter [" + name + "](" + typeUrl + ") is required but unsupported"); + } + ConfigOrError filterConfig = + filter.parseFilterConfigOverride(rawConfig); + if (filterConfig.errorDetail != null) { + return StructOrError.fromError( + "Invalid filter config for HttpFilter [" + name + "]: " + filterConfig.errorDetail); + } + overrideConfigs.put(name, filterConfig.config); + } + return StructOrError.fromStruct(overrideConfigs); + } + + @VisibleForTesting + @Nullable + static StructOrError parseRoute( + io.envoyproxy.envoy.config.route.v3.Route proto, FilterRegistry filterRegistry, + boolean parseHttpFilter, Map pluginConfigMap, + Set optionalPlugins) { + StructOrError routeMatch = parseRouteMatch(proto.getMatch()); + if (routeMatch == null) { + return null; + } + if (routeMatch.getErrorDetail() != null) { + return StructOrError.fromError( + "Route [" + proto.getName() + "] contains invalid RouteMatch: " + + routeMatch.getErrorDetail()); + } + + Map overrideConfigs = Collections.emptyMap(); + if (parseHttpFilter) { + StructOrError> overrideConfigsOrError = + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + if (overrideConfigsOrError.getErrorDetail() != null) { + return StructOrError.fromError( + "Route [" + proto.getName() + "] contains invalid HttpFilter config: " + + overrideConfigsOrError.getErrorDetail()); + } + overrideConfigs = overrideConfigsOrError.getStruct(); + } + + switch (proto.getActionCase()) { + case ROUTE: + StructOrError routeAction = + parseRouteAction(proto.getRoute(), filterRegistry, parseHttpFilter, pluginConfigMap, + optionalPlugins); + if (routeAction == null) { + return null; + } + if (routeAction.getErrorDetail() != null) { + return StructOrError.fromError( + "Route [" + proto.getName() + "] contains invalid RouteAction: " + + routeAction.getErrorDetail()); + } + return StructOrError.fromStruct( + Route.forAction(routeMatch.getStruct(), routeAction.getStruct(), overrideConfigs)); + case NON_FORWARDING_ACTION: + return StructOrError.fromStruct( + Route.forNonForwardingAction(routeMatch.getStruct(), overrideConfigs)); + case REDIRECT: + case DIRECT_RESPONSE: + case FILTER_ACTION: + case ACTION_NOT_SET: + default: + return StructOrError.fromError( + "Route [" + proto.getName() + "] with unknown action type: " + proto.getActionCase()); + } + } + + @VisibleForTesting + @Nullable + static StructOrError parseRouteMatch( + io.envoyproxy.envoy.config.route.v3.RouteMatch proto) { + if (proto.getQueryParametersCount() != 0) { + return null; + } + StructOrError pathMatch = parsePathMatcher(proto); + if (pathMatch.getErrorDetail() != null) { + return StructOrError.fromError(pathMatch.getErrorDetail()); + } + + FractionMatcher fractionMatch = null; + if (proto.hasRuntimeFraction()) { + StructOrError parsedFraction = + parseFractionMatcher(proto.getRuntimeFraction().getDefaultValue()); + if (parsedFraction.getErrorDetail() != null) { + return StructOrError.fromError(parsedFraction.getErrorDetail()); + } + fractionMatch = parsedFraction.getStruct(); + } + + List headerMatchers = new ArrayList<>(); + for (io.envoyproxy.envoy.config.route.v3.HeaderMatcher hmProto : proto.getHeadersList()) { + StructOrError headerMatcher = parseHeaderMatcher(hmProto); + if (headerMatcher.getErrorDetail() != null) { + return StructOrError.fromError(headerMatcher.getErrorDetail()); + } + headerMatchers.add(headerMatcher.getStruct()); + } + + return StructOrError.fromStruct(RouteMatch.create( + pathMatch.getStruct(), headerMatchers, fractionMatch)); + } + + @VisibleForTesting + static StructOrError parsePathMatcher( + io.envoyproxy.envoy.config.route.v3.RouteMatch proto) { + boolean caseSensitive = proto.getCaseSensitive().getValue(); + switch (proto.getPathSpecifierCase()) { + case PREFIX: + return StructOrError.fromStruct( + PathMatcher.fromPrefix(proto.getPrefix(), caseSensitive)); + case PATH: + return StructOrError.fromStruct(PathMatcher.fromPath(proto.getPath(), caseSensitive)); + case SAFE_REGEX: + String rawPattern = proto.getSafeRegex().getRegex(); + Pattern safeRegEx; + try { + safeRegEx = Pattern.compile(rawPattern); + } catch (PatternSyntaxException e) { + return StructOrError.fromError("Malformed safe regex pattern: " + e.getMessage()); + } + return StructOrError.fromStruct(PathMatcher.fromRegEx(safeRegEx)); + case PATHSPECIFIER_NOT_SET: + default: + return StructOrError.fromError("Unknown path match type"); + } + } + + private static StructOrError parseFractionMatcher(FractionalPercent proto) { + int numerator = proto.getNumerator(); + int denominator = 0; + switch (proto.getDenominator()) { + case HUNDRED: + denominator = 100; + break; + case TEN_THOUSAND: + denominator = 10_000; + break; + case MILLION: + denominator = 1_000_000; + break; + case UNRECOGNIZED: + default: + return StructOrError.fromError( + "Unrecognized fractional percent denominator: " + proto.getDenominator()); + } + return StructOrError.fromStruct(FractionMatcher.create(numerator, denominator)); + } + + @VisibleForTesting + static StructOrError parseHeaderMatcher( + io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { + try { + Matchers.HeaderMatcher headerMatcher = MatcherParser.parseHeaderMatcher(proto); + return StructOrError.fromStruct(headerMatcher); + } catch (IllegalArgumentException e) { + return StructOrError.fromError(e.getMessage()); + } + } + + /** + * Parses the RouteAction config. The returned result may contain a (parsed form) + * {@link RouteAction} or an error message. Returns {@code null} if the RouteAction + * should be ignored. + */ + @VisibleForTesting + @Nullable + static StructOrError parseRouteAction( + io.envoyproxy.envoy.config.route.v3.RouteAction proto, FilterRegistry filterRegistry, + boolean parseHttpFilter, Map pluginConfigMap, + Set optionalPlugins) { + Long timeoutNano = null; + if (proto.hasMaxStreamDuration()) { + io.envoyproxy.envoy.config.route.v3.RouteAction.MaxStreamDuration maxStreamDuration + = proto.getMaxStreamDuration(); + if (maxStreamDuration.hasGrpcTimeoutHeaderMax()) { + timeoutNano = Durations.toNanos(maxStreamDuration.getGrpcTimeoutHeaderMax()); + } else if (maxStreamDuration.hasMaxStreamDuration()) { + timeoutNano = Durations.toNanos(maxStreamDuration.getMaxStreamDuration()); + } + } + RetryPolicy retryPolicy = null; + if (enableRetry && proto.hasRetryPolicy()) { + StructOrError retryPolicyOrError = parseRetryPolicy(proto.getRetryPolicy()); + if (retryPolicyOrError != null) { + if (retryPolicyOrError.getErrorDetail() != null) { + return StructOrError.fromError(retryPolicyOrError.getErrorDetail()); + } + retryPolicy = retryPolicyOrError.getStruct(); + } + } + List hashPolicies = new ArrayList<>(); + for (io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy config + : proto.getHashPolicyList()) { + HashPolicy policy = null; + boolean terminal = config.getTerminal(); + switch (config.getPolicySpecifierCase()) { + case HEADER: + io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.Header headerCfg = + config.getHeader(); + Pattern regEx = null; + String regExSubstitute = null; + if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern() + && headerCfg.getRegexRewrite().getPattern().hasGoogleRe2()) { + regEx = Pattern.compile(headerCfg.getRegexRewrite().getPattern().getRegex()); + regExSubstitute = headerCfg.getRegexRewrite().getSubstitution(); + } + policy = HashPolicy.forHeader( + terminal, headerCfg.getHeaderName(), regEx, regExSubstitute); + break; + case FILTER_STATE: + if (config.getFilterState().getKey().equals(HASH_POLICY_FILTER_STATE_KEY)) { + policy = HashPolicy.forChannelId(terminal); + } + break; + default: + // Ignore + } + if (policy != null) { + hashPolicies.add(policy); + } + } + + switch (proto.getClusterSpecifierCase()) { + case CLUSTER: + return StructOrError.fromStruct(RouteAction.forCluster( + proto.getCluster(), hashPolicies, timeoutNano, retryPolicy)); + case CLUSTER_HEADER: + return null; + case WEIGHTED_CLUSTERS: + List clusterWeights + = proto.getWeightedClusters().getClustersList(); + if (clusterWeights.isEmpty()) { + return StructOrError.fromError("No cluster found in weighted cluster list"); + } + List weightedClusters = new ArrayList<>(); + int clusterWeightSum = 0; + for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight + : clusterWeights) { + StructOrError clusterWeightOrError = + parseClusterWeight(clusterWeight, filterRegistry, parseHttpFilter); + if (clusterWeightOrError.getErrorDetail() != null) { + return StructOrError.fromError("RouteAction contains invalid ClusterWeight: " + + clusterWeightOrError.getErrorDetail()); + } + clusterWeightSum += clusterWeight.getWeight().getValue(); + weightedClusters.add(clusterWeightOrError.getStruct()); + } + if (clusterWeightSum <= 0) { + return StructOrError.fromError("Sum of cluster weights should be above 0."); + } + return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters( + weightedClusters, hashPolicies, timeoutNano, retryPolicy)); + case CLUSTER_SPECIFIER_PLUGIN: + if (enableRouteLookup) { + String pluginName = proto.getClusterSpecifierPlugin(); + PluginConfig pluginConfig = pluginConfigMap.get(pluginName); + if (pluginConfig == null) { + // Skip route if the plugin is not registered, but it's optional. + if (optionalPlugins.contains(pluginName)) { + return null; + } + return StructOrError.fromError( + "ClusterSpecifierPlugin for [" + pluginName + "] not found"); + } + NamedPluginConfig namedPluginConfig = NamedPluginConfig.create(pluginName, pluginConfig); + return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forClusterSpecifierPlugin( + namedPluginConfig, hashPolicies, timeoutNano, retryPolicy)); + } else { + return null; + } + case CLUSTERSPECIFIER_NOT_SET: + default: + return null; + } + } + + @Nullable // Return null if we ignore the given policy. + private static StructOrError parseRetryPolicy( + io.envoyproxy.envoy.config.route.v3.RetryPolicy retryPolicyProto) { + int maxAttempts = 2; + if (retryPolicyProto.hasNumRetries()) { + maxAttempts = retryPolicyProto.getNumRetries().getValue() + 1; + } + Duration initialBackoff = Durations.fromMillis(25); + Duration maxBackoff = Durations.fromMillis(250); + if (retryPolicyProto.hasRetryBackOff()) { + RetryBackOff retryBackOff = retryPolicyProto.getRetryBackOff(); + if (!retryBackOff.hasBaseInterval()) { + return StructOrError.fromError("No base_interval specified in retry_backoff"); + } + Duration originalInitialBackoff = initialBackoff = retryBackOff.getBaseInterval(); + if (Durations.compare(initialBackoff, Durations.ZERO) <= 0) { + return StructOrError.fromError("base_interval in retry_backoff must be positive"); + } + if (Durations.compare(initialBackoff, Durations.fromMillis(1)) < 0) { + initialBackoff = Durations.fromMillis(1); + } + if (retryBackOff.hasMaxInterval()) { + maxBackoff = retryPolicyProto.getRetryBackOff().getMaxInterval(); + if (Durations.compare(maxBackoff, originalInitialBackoff) < 0) { + return StructOrError.fromError( + "max_interval in retry_backoff cannot be less than base_interval"); + } + if (Durations.compare(maxBackoff, Durations.fromMillis(1)) < 0) { + maxBackoff = Durations.fromMillis(1); + } + } else { + maxBackoff = Durations.fromNanos(Durations.toNanos(initialBackoff) * 10); + } + } + Iterable retryOns = + Splitter.on(',').omitEmptyStrings().trimResults().split(retryPolicyProto.getRetryOn()); + ImmutableList.Builder retryableStatusCodesBuilder = ImmutableList.builder(); + for (String retryOn : retryOns) { + Status.Code code; + try { + code = Status.Code.valueOf(retryOn.toUpperCase(Locale.US).replace('-', '_')); + } catch (IllegalArgumentException e) { + // unsupported value, such as "5xx" + continue; + } + if (!SUPPORTED_RETRYABLE_CODES.contains(code)) { + // unsupported value + continue; + } + retryableStatusCodesBuilder.add(code); + } + List retryableStatusCodes = retryableStatusCodesBuilder.build(); + return StructOrError.fromStruct( + VirtualHost.Route.RouteAction.RetryPolicy.create( + maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff, + /* perAttemptRecvTimeout= */ null)); + } + + @VisibleForTesting + static StructOrError parseClusterWeight( + io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto, + FilterRegistry filterRegistry, boolean parseHttpFilter) { + if (!parseHttpFilter) { + return StructOrError.fromStruct(ClusterWeight.create(proto.getName(), + proto.getWeight().getValue(), new HashMap())); + } + StructOrError> overrideConfigs = + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + if (overrideConfigs.getErrorDetail() != null) { + return StructOrError.fromError( + "ClusterWeight [" + proto.getName() + "] contains invalid HttpFilter config: " + + overrideConfigs.getErrorDetail()); + } + return StructOrError.fromStruct(VirtualHost.Route.RouteAction.ClusterWeight.create( + proto.getName(), proto.getWeight().getValue(), overrideConfigs.getStruct())); + } + + @Nullable // null if the plugin is not supported, but it's marked as optional. + private static PluginConfig parseClusterSpecifierPlugin(ClusterSpecifierPlugin pluginProto) + throws ResourceInvalidException { + return parseClusterSpecifierPlugin( + pluginProto, ClusterSpecifierPluginRegistry.getDefaultRegistry()); + } + + @Nullable // null if the plugin is not supported, but it's marked as optional. + @VisibleForTesting + static PluginConfig parseClusterSpecifierPlugin( + ClusterSpecifierPlugin pluginProto, ClusterSpecifierPluginRegistry registry) + throws ResourceInvalidException { + TypedExtensionConfig extension = pluginProto.getExtension(); + String pluginName = extension.getName(); + Any anyConfig = extension.getTypedConfig(); + String typeUrl = anyConfig.getTypeUrl(); + Message rawConfig = anyConfig; + if (typeUrl.equals(TYPE_URL_TYPED_STRUCT_UDPA) || typeUrl.equals(TYPE_URL_TYPED_STRUCT)) { + try { + TypedStruct typedStruct = unpackCompatibleType( + anyConfig, TypedStruct.class, TYPE_URL_TYPED_STRUCT_UDPA, TYPE_URL_TYPED_STRUCT); + typeUrl = typedStruct.getTypeUrl(); + rawConfig = typedStruct.getValue(); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException( + "ClusterSpecifierPlugin [" + pluginName + "] contains invalid proto", e); + } + } + io.grpc.xds.ClusterSpecifierPlugin plugin = registry.get(typeUrl); + if (plugin == null) { + if (!pluginProto.getIsOptional()) { + throw new ResourceInvalidException("Unsupported ClusterSpecifierPlugin type: " + typeUrl); + } + return null; + } + ConfigOrError pluginConfigOrError = plugin.parsePlugin(rawConfig); + if (pluginConfigOrError.errorDetail != null) { + throw new ResourceInvalidException(pluginConfigOrError.errorDetail); + } + return pluginConfigOrError.config; + } + + static final class RdsUpdate implements ResourceUpdate { + // The list virtual hosts that make up the route table. + final List virtualHosts; + + RdsUpdate(List virtualHosts) { + this.virtualHosts = Collections.unmodifiableList( + new ArrayList<>(checkNotNull(virtualHosts, "virtualHosts"))); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("virtualHosts", virtualHosts) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(virtualHosts); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RdsUpdate that = (RdsUpdate) o; + return Objects.equals(virtualHosts, that.virtualHosts); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index d4df317a7e9..e6dac5d25af 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -176,19 +176,18 @@ public interface XdsServingStatusListener { private static class DefaultListener implements XdsServingStatusListener { private final Logger logger; private final String prefix; - boolean notServing; + boolean notServingDueToError; DefaultListener(String prefix) { logger = Logger.getLogger(DefaultListener.class.getName()); this.prefix = prefix; - notServing = true; } /** Log calls to onServing() following a call to onNotServing() at WARNING level. */ @Override public void onServing() { - if (notServing) { - notServing = false; + if (notServingDueToError) { + notServingDueToError = false; logger.warning("[" + prefix + "] Entering serving state."); } } @@ -196,7 +195,7 @@ public void onServing() { @Override public void onNotServing(Throwable throwable) { logger.warning("[" + prefix + "] " + throwable.getMessage()); - notServing = true; + notServingDueToError = true; } } } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerCredentials.java b/xds/src/main/java/io/grpc/xds/XdsServerCredentials.java index e6e78f319c7..2212e7a1855 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerCredentials.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerCredentials.java @@ -22,7 +22,7 @@ import io.grpc.ServerCredentials; import io.grpc.netty.InternalNettyServerCredentials; import io.grpc.netty.InternalProtocolNegotiator; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7514") public class XdsServerCredentials { @@ -40,6 +40,6 @@ public static ServerCredentials create(ServerCredentials fallback) { InternalProtocolNegotiator.ServerFactory fallbackNegotiator = InternalNettyServerCredentials.toNegotiator(checkNotNull(fallback, "fallback")); return InternalNettyServerCredentials.create( - SdsProtocolNegotiators.serverProtocolNegotiatorFactory(fallbackNegotiator)); + SecurityProtocolNegotiators.serverProtocolNegotiatorFactory(fallbackNegotiator)); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index eef4a9c2fb3..b3bbe005825 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -50,18 +50,16 @@ import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.VirtualHost.Route; -import io.grpc.xds.XdsClient.LdsResourceWatcher; -import io.grpc.xds.XdsClient.LdsUpdate; -import io.grpc.xds.XdsClient.RdsResourceWatcher; -import io.grpc.xds.XdsClient.RdsUpdate; +import io.grpc.xds.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collections; -import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -182,9 +180,8 @@ private void internalStart() { return; } xdsClient = xdsClientPool.getObject(); - boolean useProtocolV3 = xdsClient.getBootstrapInfo().servers().get(0).useProtocolV3(); String listenerTemplate = xdsClient.getBootstrapInfo().serverListenerResourceNameTemplate(); - if (!useProtocolV3 || listenerTemplate == null) { + if (listenerTemplate == null) { StatusException statusException = Status.UNAVAILABLE.withDescription( "Can only support xDS v3 with listener resource name template").asException(); @@ -330,6 +327,8 @@ private void startDelegateServer() { if (!initialStarted) { initialStarted = true; initialStartFuture.set(e); + } else { + listener.onNotServing(e); } restartTimer = syncContext.schedule( new RestartTask(), RETRY_DELAY_NANOS, TimeUnit.NANOSECONDS, timeService); @@ -343,7 +342,7 @@ public void run() { } } - private final class DiscoveryState implements LdsResourceWatcher { + private final class DiscoveryState implements ResourceWatcher { private final String resourceName; // RDS resource name is the key. private final Map routeDiscoveryStates = new HashMap<>(); @@ -367,7 +366,7 @@ public Listener interceptCall(ServerCall call, private DiscoveryState(String resourceName) { this.resourceName = checkNotNull(resourceName, "resourceName"); - xdsClient.watchLdsResource(resourceName, this); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), resourceName, this); } @Override @@ -401,7 +400,8 @@ public void run() { if (rdsState == null) { rdsState = new RouteDiscoveryState(hcm.rdsName()); routeDiscoveryStates.put(hcm.rdsName(), rdsState); - xdsClient.watchRdsResource(hcm.rdsName(), rdsState); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + hcm.rdsName(), rdsState); } if (rdsState.isPending) { pendingRds.add(hcm.rdsName()); @@ -411,7 +411,8 @@ public void run() { } for (Map.Entry entry: routeDiscoveryStates.entrySet()) { if (!allRds.contains(entry.getKey())) { - xdsClient.cancelRdsResourceWatch(entry.getKey(), entry.getValue()); + xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), + entry.getKey(), entry.getValue()); } } routeDiscoveryStates.keySet().retainAll(allRds); @@ -445,12 +446,8 @@ public void run() { if (stopped) { return; } - boolean isPermanentError = isPermanentError(error); - logger.log(Level.FINE, "{0} error from XdsClient: {1}", - new Object[]{isPermanentError ? "Permanent" : "Transient", error}); - if (isPermanentError) { - handleConfigNotFound(error.asException()); - } else if (!isServing) { + logger.log(Level.FINE, "Error from XdsClient", error); + if (!isServing) { listener.onNotServing(error.asException()); } } @@ -461,7 +458,7 @@ private void shutdown() { stopped = true; cleanUpRouteDiscoveryStates(); logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); - xdsClient.cancelLdsResourceWatch(resourceName, this); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), resourceName, this); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { @@ -546,7 +543,7 @@ private ImmutableMap generatePerRouteInterceptors( perRouteInterceptors.put(route, interceptor); } } - return perRouteInterceptors.build(); + return perRouteInterceptors.buildOrThrow(); } private ServerInterceptor combineInterceptors(final List interceptors) { @@ -591,7 +588,8 @@ private void cleanUpRouteDiscoveryStates() { for (RouteDiscoveryState rdsState : routeDiscoveryStates.values()) { String rdsName = rdsState.resourceName; logger.log(Level.FINE, "Stop watching RDS resource {0}", rdsName); - xdsClient.cancelRdsResourceWatch(rdsName, rdsState); + xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), rdsName, + rdsState); } routeDiscoveryStates.clear(); savedRdsRoutingConfigRef.clear(); @@ -629,7 +627,7 @@ private void releaseSuppliersInFlight() { } } - private final class RouteDiscoveryState implements RdsResourceWatcher { + private final class RouteDiscoveryState implements ResourceWatcher { private final String resourceName; private ImmutableList savedVirtualHosts; private boolean isPending = true; @@ -719,16 +717,6 @@ private void maybeUpdateSelector() { } } } - - private boolean isPermanentError(Status error) { - return EnumSet.of( - Status.Code.INTERNAL, - Status.Code.INVALID_ARGUMENT, - Status.Code.FAILED_PRECONDITION, - Status.Code.PERMISSION_DENIED, - Status.Code.UNAUTHENTICATED) - .contains(error.getCode()); - } } @VisibleForTesting diff --git a/xds/src/main/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProvider.java b/xds/src/main/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProvider.java new file mode 100644 index 00000000000..383c19b6665 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProvider.java @@ -0,0 +1,50 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.ChannelCredentials; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.xds.XdsCredentialsProvider; +import java.util.Map; + +/** + * A wrapper class that supports {@link GoogleDefaultChannelCredentials} for + * Xds by implementing {@link XdsCredentialsProvider}. + */ +public final class GoogleDefaultXdsCredentialsProvider extends XdsCredentialsProvider { + private static final String CREDS_NAME = "google_default"; + + @Override + protected ChannelCredentials newChannelCredentials(Map jsonConfig) { + return GoogleDefaultChannelCredentials.create(); + } + + @Override + protected String getName() { + return CREDS_NAME; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int priority() { + return 5; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/InsecureXdsCredentialsProvider.java b/xds/src/main/java/io/grpc/xds/internal/InsecureXdsCredentialsProvider.java new file mode 100644 index 00000000000..d57cfe2f238 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/InsecureXdsCredentialsProvider.java @@ -0,0 +1,51 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.xds.XdsCredentialsProvider; +import java.util.Map; + +/** + * A wrapper class that supports {@link InsecureChannelCredentials} for Xds + * by implementing {@link XdsCredentialsProvider}. + */ +public final class InsecureXdsCredentialsProvider extends XdsCredentialsProvider { + private static final String CREDS_NAME = "insecure"; + + @Override + protected ChannelCredentials newChannelCredentials(Map jsonConfig) { + return InsecureChannelCredentials.create(); + } + + @Override + protected String getName() { + return CREDS_NAME; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int priority() { + return 5; + } + +} diff --git a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java index 0a971655df1..39b80bbcc03 100644 --- a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java +++ b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java @@ -54,6 +54,12 @@ public static Matchers.HeaderMatcher parseHeaderMatcher( case SUFFIX_MATCH: return Matchers.HeaderMatcher.forSuffix( proto.getName(), proto.getSuffixMatch(), proto.getInvertMatch()); + case CONTAINS_MATCH: + return Matchers.HeaderMatcher.forContains( + proto.getName(), proto.getContainsMatch(), proto.getInvertMatch()); + case STRING_MATCH: + return Matchers.HeaderMatcher.forString( + proto.getName(), parseStringMatcher(proto.getStringMatch()), proto.getInvertMatch()); case HEADERMATCHSPECIFIER_NOT_SET: default: throw new IllegalArgumentException( diff --git a/xds/src/main/java/io/grpc/xds/internal/Matchers.java b/xds/src/main/java/io/grpc/xds/internal/Matchers.java index 3bf7b7723e2..f833fd2e480 100644 --- a/xds/src/main/java/io/grpc/xds/internal/Matchers.java +++ b/xds/src/main/java/io/grpc/xds/internal/Matchers.java @@ -62,6 +62,14 @@ public abstract static class HeaderMatcher { @Nullable public abstract String suffix(); + // Matches header value with the substring. + @Nullable + public abstract String contains(); + + // Matches header value with the string matcher. + @Nullable + public abstract StringMatcher stringMatcher(); + // Whether the matching semantics is inverted. E.g., present && !inverted -> !present public abstract boolean inverted(); @@ -69,50 +77,71 @@ public abstract static class HeaderMatcher { public static HeaderMatcher forExactValue(String name, String exactValue, boolean inverted) { checkNotNull(name, "name"); checkNotNull(exactValue, "exactValue"); - return HeaderMatcher.create(name, exactValue, null, null, null, null, null, inverted); + return HeaderMatcher.create( + name, exactValue, null, null, null, null, null, null, null, inverted); } /** The request header value should match the regular expression pattern. */ public static HeaderMatcher forSafeRegEx(String name, Pattern safeRegEx, boolean inverted) { checkNotNull(name, "name"); checkNotNull(safeRegEx, "safeRegEx"); - return HeaderMatcher.create(name, null, safeRegEx, null, null, null, null, inverted); + return HeaderMatcher.create( + name, null, safeRegEx, null, null, null, null, null, null, inverted); } /** The request header value should be within the range. */ public static HeaderMatcher forRange(String name, Range range, boolean inverted) { checkNotNull(name, "name"); checkNotNull(range, "range"); - return HeaderMatcher.create(name, null, null, range, null, null, null, inverted); + return HeaderMatcher.create(name, null, null, range, null, null, null, null, null, inverted); } /** The request header value should exist. */ public static HeaderMatcher forPresent(String name, boolean present, boolean inverted) { checkNotNull(name, "name"); - return HeaderMatcher.create(name, null, null, null, present, null, null, inverted); + return HeaderMatcher.create( + name, null, null, null, present, null, null, null, null, inverted); } /** The request header value should have this prefix. */ public static HeaderMatcher forPrefix(String name, String prefix, boolean inverted) { checkNotNull(name, "name"); checkNotNull(prefix, "prefix"); - return HeaderMatcher.create(name, null, null, null, null, prefix, null, inverted); + return HeaderMatcher.create(name, null, null, null, null, prefix, null, null, null, inverted); } /** The request header value should have this suffix. */ public static HeaderMatcher forSuffix(String name, String suffix, boolean inverted) { checkNotNull(name, "name"); checkNotNull(suffix, "suffix"); - return HeaderMatcher.create(name, null, null, null, null, null, suffix, inverted); + return HeaderMatcher.create(name, null, null, null, null, null, suffix, null, null, inverted); + } + + /** The request header value should have this substring. */ + public static HeaderMatcher forContains(String name, String contains, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(contains, "contains"); + return HeaderMatcher.create( + name, null, null, null, null, null, null, contains, null, inverted); + } + + /** The request header value should match this stringMatcher. */ + public static HeaderMatcher forString( + String name, StringMatcher stringMatcher, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(stringMatcher, "stringMatcher"); + return HeaderMatcher.create( + name, null, null, null, null, null, null, null, stringMatcher, inverted); } private static HeaderMatcher create(String name, @Nullable String exactValue, @Nullable Pattern safeRegEx, @Nullable Range range, @Nullable Boolean present, @Nullable String prefix, - @Nullable String suffix, boolean inverted) { + @Nullable String suffix, @Nullable String contains, + @Nullable StringMatcher stringMatcher, boolean inverted) { checkNotNull(name, "name"); return new AutoValue_Matchers_HeaderMatcher(name, exactValue, safeRegEx, range, present, - prefix, suffix, inverted); + prefix, suffix, contains, stringMatcher, inverted); } /** Returns the matching result. */ @@ -138,8 +167,12 @@ public boolean matches(@Nullable String value) { baseMatch = value.startsWith(prefix()); } else if (present() != null) { baseMatch = present(); - } else { + } else if (suffix() != null) { baseMatch = value.endsWith(suffix()); + } else if (contains() != null) { + baseMatch = value.contains(contains()); + } else { + baseMatch = stringMatcher().matches(value); } return baseMatch != inverted(); } diff --git a/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java b/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java new file mode 100644 index 00000000000..f4d26a83795 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/TlsXdsCredentialsProvider.java @@ -0,0 +1,51 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.ChannelCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.xds.XdsCredentialsProvider; +import java.util.Map; + +/** + * A wrapper class that supports {@link TlsChannelCredentials} for Xds + * by implementing {@link XdsCredentialsProvider}. + */ +public final class TlsXdsCredentialsProvider extends XdsCredentialsProvider { + private static final String CREDS_NAME = "tls"; + + @Override + protected ChannelCredentials newChannelCredentials(Map jsonConfig) { + return TlsChannelCredentials.create(); + } + + @Override + protected String getName() { + return CREDS_NAME; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int priority() { + return 5; + } + +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java similarity index 84% rename from xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java rename to xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index bb7c0314bb4..4bf11fba3ff 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -14,29 +14,29 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider; -import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ final class ClientSslContextProviderFactory implements ValueFactory { private BootstrapInfo bootstrapInfo; - private final CertProviderClientSslContextProvider.Factory + private final CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; ClientSslContextProviderFactory(BootstrapInfo bootstrapInfo) { - this(bootstrapInfo, CertProviderClientSslContextProvider.Factory.getInstance()); + this(bootstrapInfo, CertProviderClientSslContextProviderFactory.getInstance()); } ClientSslContextProviderFactory( - BootstrapInfo bootstrapInfo, CertProviderClientSslContextProvider.Factory factory) { + BootstrapInfo bootstrapInfo, CertProviderClientSslContextProviderFactory factory) { this.bootstrapInfo = bootstrapInfo; this.certProviderClientSslContextProviderFactory = factory; } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java b/xds/src/main/java/io/grpc/xds/internal/security/Closeable.java similarity index 92% rename from xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java rename to xds/src/main/java/io/grpc/xds/internal/security/Closeable.java index c3695cecaf3..c78714abfbf 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/Closeable.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/Closeable.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; public interface Closeable extends java.io.Closeable { @Override - public void close(); + void close(); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java similarity index 98% rename from xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java rename to xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java index df09e8bb247..d3003b4a792 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java similarity index 98% rename from xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index c75347c1f5e..6bf66d022ff 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -14,13 +14,14 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.collect.ImmutableList; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.Internal; import io.grpc.Status; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.netty.handler.ssl.ApplicationProtocolConfig; @@ -34,6 +35,7 @@ import javax.annotation.Nullable; /** Base class for dynamic {@link SslContextProvider}s. */ +@Internal public abstract class DynamicSslContextProvider extends SslContextProvider { protected final List pendingCallbacks = new ArrayList<>(); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java similarity index 99% rename from xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java rename to xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java index 6a3a03a2870..b7f56492fa5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ReferenceCountingMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java similarity index 96% rename from xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java rename to xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index a032737e647..08f2e86fb69 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; @@ -48,13 +48,14 @@ * context. */ @VisibleForTesting -public final class SdsProtocolNegotiators { +public final class SecurityProtocolNegotiators { // Prevent instantiation. - private SdsProtocolNegotiators() { + private SecurityProtocolNegotiators() { } - private static final Logger logger = Logger.getLogger(SdsProtocolNegotiators.class.getName()); + private static final Logger logger + = Logger.getLogger(SecurityProtocolNegotiators.class.getName()); private static final AsciiString SCHEME = AsciiString.of("http"); @@ -207,10 +208,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { logger.log( Level.FINEST, - "ClientSdsHandler.updateSecret authority={0}, ctx.name={1}", + "ClientSdsHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); @@ -346,7 +347,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { ChannelHandler handler = InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ServerSslContextProviderFactory.java similarity index 85% rename from xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java rename to xds/src/main/java/io/grpc/xds/internal/security/ServerSslContextProviderFactory.java index 590ffdb47c5..6206ccdcfe6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ServerSslContextProviderFactory.java @@ -14,29 +14,29 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; -import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; -import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.certprovider.CertProviderServerSslContextProviderFactory; /** Factory to create server-side SslContextProvider from DownstreamTlsContext. */ final class ServerSslContextProviderFactory implements ValueFactory { private BootstrapInfo bootstrapInfo; - private final CertProviderServerSslContextProvider.Factory + private final CertProviderServerSslContextProviderFactory certProviderServerSslContextProviderFactory; ServerSslContextProviderFactory(BootstrapInfo bootstrapInfo) { - this(bootstrapInfo, CertProviderServerSslContextProvider.Factory.getInstance()); + this(bootstrapInfo, CertProviderServerSslContextProviderFactory.getInstance()); } ServerSslContextProviderFactory( - BootstrapInfo bootstrapInfo, CertProviderServerSslContextProvider.Factory factory) { + BootstrapInfo bootstrapInfo, CertProviderServerSslContextProviderFactory factory) { this.bootstrapInfo = bootstrapInfo; this.certProviderServerSslContextProviderFactory = factory; } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java similarity index 90% rename from xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 6b661715e48..a0c4ed37dfb 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -14,17 +14,18 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.Internal; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; @@ -39,6 +40,7 @@ * stream that is receiving the requested secret(s) or it could represent file-system based * secret(s) that are dynamic. */ +@Internal public abstract class SslContextProvider implements Closeable { protected final BaseTlsContext tlsContext; @@ -55,7 +57,7 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSecret(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -70,11 +72,11 @@ protected CommonTlsContext getCommonTlsContext() { } protected void setClientAuthValues( - SslContextBuilder sslContextBuilder, SdsTrustManagerFactory sdsTrustManagerFactory) + SslContextBuilder sslContextBuilder, XdsTrustManagerFactory xdsTrustManagerFactory) throws CertificateException, IOException, CertStoreException { DownstreamTlsContext downstreamTlsContext = getDownstreamTlsContext(); - if (sdsTrustManagerFactory != null) { - sslContextBuilder.trustManager(sdsTrustManagerFactory); + if (xdsTrustManagerFactory != null) { + sslContextBuilder.trustManager(xdsTrustManagerFactory); sslContextBuilder.clientAuth( downstreamTlsContext.isRequireClientCertificate() ? ClientAuth.REQUIRE @@ -118,7 +120,7 @@ protected final void performCallback( public void run() { try { SslContext sslContext = sslContextGetter.get(); - callback.updateSecret(sslContext); + callback.updateSslContext(sslContext); } catch (Throwable e) { callback.onException(e); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java similarity index 95% rename from xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java rename to xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 664b4881bc2..5f629273179 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; @@ -29,7 +29,7 @@ /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} - * and communicate it to the consumer i.e. {@link SdsProtocolNegotiators} + * and communicate it to the consumer i.e. {@link SecurityProtocolNegotiators} * to lazily evaluate the {@link SslContextProvider}. The supplier prevents credentials leakage in * cases where the user is not using xDS credentials but the client/server contains a non-default * {@link BaseTlsContext}. @@ -66,8 +66,8 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSecret(SslContext sslContext) { - callback.updateSecret(sslContext); + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); releaseSslContextProvider(toRelease); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java similarity index 97% rename from xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java rename to xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 75a5d297d90..8d4fce60350 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.base.Preconditions.checkNotNull; @@ -24,7 +24,7 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; -import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; /** * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java similarity index 56% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index ce9ef3de680..3953fd5c46b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -14,19 +14,17 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; @@ -34,10 +32,9 @@ import javax.annotation.Nullable; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ -@Internal -public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { +final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { - private CertProviderClientSslContextProvider( + CertProviderClientSslContextProvider( Node node, @Nullable Map certProviders, CommonTlsContext.CertificateProviderInstance certInstance, @@ -62,7 +59,7 @@ protected final SslContextBuilder getSslContextBuilder( SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient() .trustManager( - new SdsTrustManagerFactory( + new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContextdationContext)); if (isMtls()) { @@ -71,42 +68,4 @@ protected final SslContextBuilder getSslContextBuilder( return sslContextBuilder; } - /** Creates CertProviderClientSslContextProvider. */ - @Internal - public static final class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory(CertificateProviderStore.getInstance()); - private final CertificateProviderStore certificateProviderStore; - - @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) { - this.certificateProviderStore = certificateProviderStore; - } - - public static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - /** Creates a {@link CertProviderClientSslContextProvider}. */ - public CertProviderClientSslContextProvider getProvider( - UpstreamTlsContext upstreamTlsContext, - Node node, - @Nullable Map certProviders) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - CertificateValidationContext staticCertValidationContext = getStaticValidationContext( - commonTlsContext); - CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( - commonTlsContext); - CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( - commonTlsContext); - return new CertProviderClientSslContextProvider( - node, - certProviders, - certInstance, - rootCertInstance, - staticCertValidationContext, - upstreamTlsContext, - certificateProviderStore); - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java new file mode 100644 index 00000000000..ef91cb56703 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java @@ -0,0 +1,76 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.config.core.v3.Node; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.Internal; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.security.SslContextProvider; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * Creates CertProviderClientSslContextProvider. + */ +@Internal +public final class CertProviderClientSslContextProviderFactory { + + private static final CertProviderClientSslContextProviderFactory DEFAULT_INSTANCE = + new CertProviderClientSslContextProviderFactory(CertificateProviderStore.getInstance()); + private final CertificateProviderStore certificateProviderStore; + + @VisibleForTesting + public CertProviderClientSslContextProviderFactory( + CertificateProviderStore certificateProviderStore) { + this.certificateProviderStore = certificateProviderStore; + } + + public static CertProviderClientSslContextProviderFactory getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates a {@link CertProviderClientSslContextProvider}. + */ + public SslContextProvider getProvider( + UpstreamTlsContext upstreamTlsContext, + Node node, + @Nullable Map certProviders) { + checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); + CertificateValidationContext staticCertValidationContext + = CertProviderSslContextProvider.getStaticValidationContext(commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance + = CertProviderSslContextProvider.getRootCertProviderInstance(commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance + = CertProviderSslContextProvider.getCertProviderInstance(commonTlsContext); + return new CertProviderClientSslContextProvider( + node, + certProviders, + certInstance, + rootCertInstance, + staticCertValidationContext, + upstreamTlsContext, + certificateProviderStore); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java similarity index 50% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index a7f0849d00b..9d936f02dc1 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -14,21 +14,18 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; -import io.grpc.xds.internal.sds.trust.SdsTrustManagerFactory; +import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; - import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; @@ -37,17 +34,16 @@ import javax.annotation.Nullable; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ -@Internal -public final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { +final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { - private CertProviderServerSslContextProvider( - Node node, - @Nullable Map certProviders, - CommonTlsContext.CertificateProviderInstance certInstance, - CommonTlsContext.CertificateProviderInstance rootCertInstance, - CertificateValidationContext staticCertValidationContext, - DownstreamTlsContext downstreamTlsContext, - CertificateProviderStore certificateProviderStore) { + CertProviderServerSslContextProvider( + Node node, + @Nullable Map certProviders, + CommonTlsContext.CertificateProviderInstance certInstance, + CommonTlsContext.CertificateProviderInstance rootCertInstance, + CertificateValidationContext staticCertValidationContext, + DownstreamTlsContext downstreamTlsContext, + CertificateProviderStore certificateProviderStore) { super( node, certProviders, @@ -66,7 +62,7 @@ protected final SslContextBuilder getSslContextBuilder( setClientAuthValues( sslContextBuilder, isMtls() - ? new SdsTrustManagerFactory( + ? new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContextdationContext) : null); @@ -74,42 +70,4 @@ protected final SslContextBuilder getSslContextBuilder( return sslContextBuilder; } - /** Creates CertProviderServerSslContextProvider. */ - @Internal - public static final class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory(CertificateProviderStore.getInstance()); - private final CertificateProviderStore certificateProviderStore; - - @VisibleForTesting public Factory(CertificateProviderStore certificateProviderStore) { - this.certificateProviderStore = certificateProviderStore; - } - - public static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - /** Creates a {@link CertProviderServerSslContextProvider}. */ - public CertProviderServerSslContextProvider getProvider( - DownstreamTlsContext downstreamTlsContext, - Node node, - @Nullable Map certProviders) { - checkNotNull(downstreamTlsContext, "downstreamTlsContext"); - CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - CertificateValidationContext staticCertValidationContext = getStaticValidationContext( - commonTlsContext); - CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( - commonTlsContext); - CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( - commonTlsContext); - return new CertProviderServerSslContextProvider( - node, - certProviders, - certInstance, - rootCertInstance, - staticCertValidationContext, - downstreamTlsContext, - certificateProviderStore); - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderFactory.java new file mode 100644 index 00000000000..3189d49f27b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderFactory.java @@ -0,0 +1,76 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.config.core.v3.Node; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.Internal; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.internal.security.SslContextProvider; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * Creates CertProviderServerSslContextProvider. + */ +@Internal +public final class CertProviderServerSslContextProviderFactory { + + private static final CertProviderServerSslContextProviderFactory DEFAULT_INSTANCE = + new CertProviderServerSslContextProviderFactory(CertificateProviderStore.getInstance()); + private final CertificateProviderStore certificateProviderStore; + + @VisibleForTesting + public CertProviderServerSslContextProviderFactory( + CertificateProviderStore certificateProviderStore) { + this.certificateProviderStore = certificateProviderStore; + } + + public static CertProviderServerSslContextProviderFactory getInstance() { + return DEFAULT_INSTANCE; + } + + /** + * Creates a {@link CertProviderServerSslContextProvider}. + */ + public SslContextProvider getProvider( + DownstreamTlsContext downstreamTlsContext, + Node node, + @Nullable Map certProviders) { + checkNotNull(downstreamTlsContext, "downstreamTlsContext"); + CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); + CertificateValidationContext staticCertValidationContext + = CertProviderSslContextProvider.getStaticValidationContext(commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance + = CertProviderSslContextProvider.getRootCertProviderInstance(commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance + = CertProviderSslContextProvider.getCertProviderInstance(commonTlsContext); + return new CertProviderServerSslContextProvider( + node, + certProviders, + certInstance, + rootCertInstance, + staticCertValidationContext, + downstreamTlsContext, + certificateProviderStore); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java similarity index 97% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 5c4dba99dc3..065501fa53c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -22,8 +22,8 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; -import io.grpc.xds.internal.sds.CommonTlsContextUtil; -import io.grpc.xds.internal.sds.DynamicSslContextProvider; +import io.grpc.xds.internal.security.CommonTlsContextUtil; +import io.grpc.xds.internal.security.DynamicSslContextProvider; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java similarity index 98% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java index 04ed997fa58..a0d5d0fc69f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java @@ -14,13 +14,13 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.grpc.Status; -import io.grpc.xds.internal.sds.Closeable; +import io.grpc.xds.internal.security.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Collections; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderProvider.java similarity index 93% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderProvider.java index a426542eea0..e2e26ead502 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderProvider.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import io.grpc.Internal; -import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher; +import io.grpc.xds.internal.security.certprovider.CertificateProvider.Watcher; /** * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderRegistry.java similarity index 98% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderRegistry.java index 12eb6f6573f..2c320b79964 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderRegistry.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStore.java similarity index 97% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStore.java index 43143ebb3ae..0fe342a36c0 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderStore.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStore.java @@ -14,12 +14,11 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import com.google.common.annotations.VisibleForTesting; -import io.grpc.xds.internal.certprovider.CertificateProvider.Watcher; -import io.grpc.xds.internal.sds.ReferenceCountingMap; - +import io.grpc.xds.internal.security.ReferenceCountingMap; +import io.grpc.xds.internal.security.certprovider.CertificateProvider.Watcher; import java.io.Closeable; import java.util.Objects; import java.util.logging.Level; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java similarity index 98% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index b86de55766e..dd945ce850e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -14,15 +14,14 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.grpc.Status; import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sds.trust.CertificateUtils; - +import io.grpc.xds.internal.security.trust.CertificateUtils; import java.io.ByteArrayInputStream; import java.nio.file.Files; import java.nio.file.Path; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java similarity index 99% rename from xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java rename to xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java index c1b0ce3f508..c4b140442cb 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java similarity index 99% rename from xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java rename to xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index e4ddb99a2b1..6e244a438c0 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds.trust; +package io.grpc.xds.internal.security.trust; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java similarity index 87% rename from xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java rename to xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 479569f1596..26d6bcd81b8 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds.trust; +package io.grpc.xds.internal.security.trust; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -23,7 +23,6 @@ import com.google.common.base.Strings; import io.envoyproxy.envoy.config.core.v3.DataSource.SpecifierCase; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; -import io.grpc.xds.internal.sds.TlsContextManagerImpl; import io.netty.handler.ssl.util.SimpleTrustManagerFactory; import java.io.File; import java.io.IOException; @@ -42,16 +41,15 @@ import javax.net.ssl.X509ExtendedTrustManager; /** - * Factory class used by providers of {@link TlsContextManagerImpl} to provide a - * {@link SdsX509TrustManager} for trust and SAN checks. + * Factory class used to provide a {@link XdsX509TrustManager} for trust and SAN checks. */ -public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory { +public final class XdsTrustManagerFactory extends SimpleTrustManagerFactory { - private static final Logger logger = Logger.getLogger(SdsTrustManagerFactory.class.getName()); - private SdsX509TrustManager sdsX509TrustManager; + private static final Logger logger = Logger.getLogger(XdsTrustManagerFactory.class.getName()); + private XdsX509TrustManager xdsX509TrustManager; /** Constructor constructs from a {@link CertificateValidationContext}. */ - public SdsTrustManagerFactory(CertificateValidationContext certificateValidationContext) + public XdsTrustManagerFactory(CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException { this( getTrustedCaFromCertContext(certificateValidationContext), @@ -59,13 +57,13 @@ public SdsTrustManagerFactory(CertificateValidationContext certificateValidation false); } - public SdsTrustManagerFactory( + public XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) throws CertStoreException { this(certs, staticCertificateValidationContext, true); } - private SdsTrustManagerFactory( + private XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext certificateValidationContext, boolean validationContextIsStatic) @@ -75,7 +73,7 @@ private SdsTrustManagerFactory( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); } - sdsX509TrustManager = createSdsX509TrustManager(certs, certificateValidationContext); + xdsX509TrustManager = createSdsX509TrustManager(certs, certificateValidationContext); } private static X509Certificate[] getTrustedCaFromCertContext( @@ -100,7 +98,7 @@ private static X509Certificate[] getTrustedCaFromCertContext( } @VisibleForTesting - static SdsX509TrustManager createSdsX509TrustManager( + static XdsX509TrustManager createSdsX509TrustManager( X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { TrustManagerFactory tmf = null; try { @@ -133,7 +131,7 @@ static SdsX509TrustManager createSdsX509TrustManager( if (myDelegate == null) { throw new CertStoreException("Native X509 TrustManager not found."); } - return new SdsX509TrustManager(certContext, myDelegate); + return new XdsX509TrustManager(certContext, myDelegate); } @Override @@ -148,6 +146,6 @@ protected void engineInit(ManagerFactoryParameters managerFactoryParameters) thr @Override protected TrustManager[] engineGetTrustManagers() { - return new TrustManager[] {sdsX509TrustManager}; + return new TrustManager[] {xdsX509TrustManager}; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java similarity index 97% rename from xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java rename to xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 3178d2b3e4b..4bb6f0520c4 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds.trust; +package io.grpc.xds.internal.security.trust; import static com.google.common.base.Preconditions.checkNotNull; @@ -41,7 +41,7 @@ * Extension of {@link X509ExtendedTrustManager} that implements verification of * SANs (subject-alternate-names) against the list in CertificateValidationContext. */ -final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509TrustManager { +final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509TrustManager { // ref: io.grpc.okhttp.internal.OkHostnameVerifier and // sun.security.x509.GeneralNameInterface @@ -52,8 +52,8 @@ final class SdsX509TrustManager extends X509ExtendedTrustManager implements X509 private final X509ExtendedTrustManager delegate; private final CertificateValidationContext certContext; - SdsX509TrustManager(@Nullable CertificateValidationContext certContext, - X509ExtendedTrustManager delegate) { + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + X509ExtendedTrustManager delegate) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; diff --git a/xds/src/main/java/io/grpc/xds/OrcaMetricReportingServerInterceptor.java b/xds/src/main/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptor.java similarity index 76% rename from xds/src/main/java/io/grpc/xds/OrcaMetricReportingServerInterceptor.java rename to xds/src/main/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptor.java index 9c79ed11bc3..729277072c2 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaMetricReportingServerInterceptor.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptor.java @@ -14,12 +14,13 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import com.github.xds.data.orca.v3.OrcaLoadReport; import com.google.common.annotations.VisibleForTesting; import io.grpc.Context; import io.grpc.Contexts; +import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -30,7 +31,7 @@ import io.grpc.protobuf.ProtoUtils; import io.grpc.services.CallMetricRecorder; import io.grpc.services.InternalCallMetricRecorder; -import java.util.Map; +import io.grpc.services.MetricReport; /** * A {@link ServerInterceptor} that intercepts a {@link ServerCall} by running server-side RPC @@ -40,7 +41,8 @@ * * @since 1.23.0 */ -final class OrcaMetricReportingServerInterceptor implements ServerInterceptor { +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9127") +public final class OrcaMetricReportingServerInterceptor implements ServerInterceptor { private static final OrcaMetricReportingServerInterceptor INSTANCE = new OrcaMetricReportingServerInterceptor(); @@ -48,7 +50,7 @@ final class OrcaMetricReportingServerInterceptor implements ServerInterceptor { @VisibleForTesting static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = Metadata.Key.of( - "x-endpoint-load-metrics-bin", + "endpoint-load-metrics-bin", ProtoUtils.metadataMarshaller(OrcaLoadReport.getDefaultInstance())); @VisibleForTesting @@ -73,12 +75,9 @@ public Listener interceptCall( new SimpleForwardingServerCall(call) { @Override public void close(Status status, Metadata trailers) { - Map metricValues = - InternalCallMetricRecorder.finalizeAndDump(finalCallMetricRecorder); - // Only attach a metric report if there are some metric values to be reported. - if (!metricValues.isEmpty()) { - OrcaLoadReport report = - OrcaLoadReport.newBuilder().putAllRequestCost(metricValues).build(); + OrcaLoadReport report = fromInternalReport( + InternalCallMetricRecorder.finalizeAndDump2(finalCallMetricRecorder)); + if (!report.equals(OrcaLoadReport.getDefaultInstance())) { trailers.put(ORCA_ENDPOINT_LOAD_METRICS_KEY, report); } super.close(status, trailers); @@ -90,4 +89,13 @@ public void close(Status status, Metadata trailers) { headers, next); } + + private static OrcaLoadReport fromInternalReport(MetricReport internalReport) { + return OrcaLoadReport.newBuilder() + .setCpuUtilization(internalReport.getCpuUtilization()) + .setMemUtilization(internalReport.getMemoryUtilization()) + .putAllUtilization(internalReport.getUtilizationMetrics()) + .putAllRequestCost(internalReport.getRequestCostMetrics()) + .build(); + } } diff --git a/xds/src/main/java/io/grpc/xds/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java similarity index 72% rename from xds/src/main/java/io/grpc/xds/OrcaOobUtil.java rename to xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index 8970a68bf65..016c4ba0eb5 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -14,13 +14,12 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.service.orca.v3.OpenRcaServiceGrpc; @@ -31,12 +30,14 @@ import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; import com.google.protobuf.util.Durations; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; import io.grpc.ConnectivityStateInfo; +import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -50,14 +51,11 @@ import io.grpc.internal.BackoffPolicy; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; +import io.grpc.services.MetricReport; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; -import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.logging.Level; @@ -66,36 +64,17 @@ /** * Utility class that provides method for {@link LoadBalancer} to install listeners to receive - * out-of-band backend cost metrics in the format of Open Request Cost Aggregation (ORCA). + * out-of-band backend metrics in the format of Open Request Cost Aggregation (ORCA). */ -abstract class OrcaOobUtil { - +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9129") +public final class OrcaOobUtil { private static final Logger logger = Logger.getLogger(OrcaPerRequestUtil.class.getName()); - private static final OrcaOobUtil DEFAULT_INSTANCE = - new OrcaOobUtil() { - - @Override - public OrcaReportingHelperWrapper newOrcaReportingHelperWrapper( - LoadBalancer.Helper delegate, - OrcaOobReportListener listener) { - return newOrcaReportingHelperWrapper( - delegate, - listener, - new ExponentialBackoffPolicy.Provider(), - GrpcUtil.STOPWATCH_SUPPLIER); - } - }; - /** - * Gets an {@code OrcaOobUtil} instance that provides actual implementation of - * {@link #newOrcaReportingHelperWrapper}. - */ - public static OrcaOobUtil getInstance() { - return DEFAULT_INSTANCE; - } + private OrcaOobUtil() {} /** - * Creates a new {@link LoadBalancer.Helper} with provided {@link OrcaOobReportListener} installed + * Creates a new {@link io.grpc.LoadBalancer.Helper} with provided + * {@link OrcaOobReportListener} installed * to receive callback when an out-of-band ORCA report is received. * *

    Example usages: @@ -109,12 +88,14 @@ public static OrcaOobUtil getInstance() { * * public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { * // listener implements the logic for WRR's usage of backend metrics. - * OrcaReportingHelperWrapper orcaWrapper = - * OrcaOobUtil.getInstance().newOrcaReportingHelperWrapper(originHelper, listener); - * orcaWrapper.setReportingConfig( - * OrcaRerportingConfig.newBuilder().setReportInterval(30, SECOND).build()); + * OrcaReportingHelper orcaHelper = + * OrcaOobUtil.newOrcaReportingHelper(originHelper); * Subchannel subchannel = - * orcaWrapper.asHelper().createSubchannel(CreateSubchannelArgs.newBuilder()...); + * orcaHelper.createSubchannel(CreateSubchannelArgs.newBuilder()...); + * OrcaOobUtil.setListener( + * subchannel, + * listener, + * OrcaRerportingConfig.newBuilder().setReportInterval(30, SECOND).build()); * ... * } * } @@ -125,8 +106,11 @@ public static OrcaOobUtil getInstance() { *

        *       {@code
        *       class XdsLoadBalancer extends LoadBalancer {
    -   *         private final Helper originHelper;  // the original Helper
    +   *         private final Helper orcaHelper;  // the original Helper
        *
    +   *         public XdsLoadBalancer(LoadBalancer.Helper helper) {
    +   *           this.orcaHelper = OrcaUtil.newOrcaReportingHelper(helper);
    +   *         }
        *         private void createChildPolicy(
        *             Locality locality, LoadBalancerProvider childPolicyProvider) {
        *           // Each Locality has a child policy, and the parent does per-locality aggregation by
    @@ -134,11 +118,18 @@ public static OrcaOobUtil getInstance() {
        *
        *           // Create an OrcaReportingHelperWrapper for each Locality.
        *           // listener implements the logic for locality-level backend metric aggregation.
    -   *           OrcaReportingHelperWrapper orcaWrapper =
    -   *               OrcaOobUtil.getInstance().newOrcaReportingHelperWrapper(originHelper, listener);
    -   *           orcaWrapper.setReportingConfig(
    -   *               OrcaRerportingConfig.newBuilder().setReportInterval(30, SECOND).build());
    -   *           LoadBalancer childLb = childPolicyProvider.newLoadBalancer(orcaWrapper.asHelper());
    +   *           LoadBalancer childLb = childPolicyProvider.newLoadBalancer(
    +   *             new ForwardingLoadBalancerHelper() {
    +   *               public Subchannel createSubchannel(CreateSubchannelArgs args) {
    +   *                 Subchannel subchannel = super.createSubchannel(args);
    +   *                 OrcaOobUtil.setListener(subchannel, listener,
    +   *                 OrcaReportingConfig.newBuilder().setReportInterval(30, SECOND).build());
    +   *                 return subchannel;
    +   *               }
    +   *               public LoadBalancer.Helper delegate() {
    +   *                 return orcaHelper;
    +   *               }
    +   *             });
        *         }
        *       }
        *       }
    @@ -148,33 +139,20 @@ public static OrcaOobUtil getInstance() {
        *
        * @param delegate the delegate helper that provides essentials for establishing subchannels to
        *     backends.
    -   * @param listener contains the callback to be invoked when an out-of-band ORCA report is
    -   *     received.
        */
    -  public abstract OrcaReportingHelperWrapper newOrcaReportingHelperWrapper(
    -      LoadBalancer.Helper delegate,
    -      OrcaOobReportListener listener);
    +  public static LoadBalancer.Helper newOrcaReportingHelper(LoadBalancer.Helper delegate) {
    +    return newOrcaReportingHelper(
    +        delegate,
    +        new ExponentialBackoffPolicy.Provider(),
    +        GrpcUtil.STOPWATCH_SUPPLIER);
    +  }
     
       @VisibleForTesting
    -  static OrcaReportingHelperWrapper newOrcaReportingHelperWrapper(
    +  static LoadBalancer.Helper newOrcaReportingHelper(
           LoadBalancer.Helper delegate,
    -      OrcaOobReportListener listener,
           BackoffPolicy.Provider backoffPolicyProvider,
           Supplier stopwatchSupplier) {
    -    final OrcaReportingHelper orcaHelper =
    -        new OrcaReportingHelper(delegate, listener, backoffPolicyProvider, stopwatchSupplier);
    -
    -    return new OrcaReportingHelperWrapper() {
    -      @Override
    -      public void setReportingConfig(OrcaReportingConfig config) {
    -        orcaHelper.setReportingConfig(config);
    -      }
    -
    -      @Override
    -      public Helper asHelper() {
    -        return orcaHelper;
    -      }
    -    };
    +    return new OrcaReportingHelper(delegate, backoffPolicyProvider, stopwatchSupplier);
       }
     
       /**
    @@ -191,66 +169,62 @@ public interface OrcaOobReportListener {
          * 

    Note this callback will be invoked from the {@link SynchronizationContext} of the * delegated helper, implementations should not block. * - * @param report load report in the format of ORCA protocol. + * @param report load report in the format of grpc {@link MetricReport}. */ - void onLoadReport(OrcaLoadReport report); + void onLoadReport(MetricReport report); } + static final Attributes.Key ORCA_REPORTING_STATE_KEY = + Attributes.Key.create("internal-orca-reporting-state"); + /** - * Blueprint for the wrapper that wraps a {@link LoadBalancer.Helper} with the capability of - * allowing {@link LoadBalancer}s interested in receiving out-of-band ORCA reports to update the - * reporting configuration such as reporting interval. + * Update {@link OrcaOobReportListener} to receive Out-of-Band metrics report for the + * particular subchannel connection, and set the configuration of receiving ORCA reports, + * such as the interval of receiving reports. + * + *

    This method needs to be called from the SynchronizationContext returned by the wrapped + * helper's {@link Helper#getSynchronizationContext()}. + * + *

    Each load balancing policy must call this method to configure the backend load reporting. + * Otherwise, it will not receive ORCA reports. + * + *

    If multiple load balancing policies configure reporting with different intervals, reports + * come with the minimum of those intervals. + * + * @param subchannel the server connected by this subchannel to receive the metrics. + * + * @param listener the callback upon receiving backend metrics from the Out-Of-Band stream. + * + * @param config the configuration to be set. + * */ - public abstract static class OrcaReportingHelperWrapper { - - /** - * Sets the configuration of receiving ORCA reports, such as the interval of receiving reports. - * - *

    This method needs to be called from the SynchronizationContext returned by the wrapped - * helper's {@link Helper#getSynchronizationContext()}. - * - *

    Each load balancing policy must call this method to configure the backend load reporting. - * Otherwise, it will not receive ORCA reports. - * - *

    If multiple load balancing policies configure reporting with different intervals, reports - * come with the minimum of those intervals. - * - * @param config the configuration to be set. - */ - public abstract void setReportingConfig(OrcaReportingConfig config); - - /** - * Returns a wrapped {@link LoadBalancer.Helper}. Subchannels created through it will retrieve - * ORCA load reports if the server supports it. - */ - public abstract LoadBalancer.Helper asHelper(); + public static void setListener(Subchannel subchannel, OrcaOobReportListener listener, + OrcaReportingConfig config) { + SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + if (orcaSubchannel == null) { + throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled." + + " Try to use a subchannel created by OrcaOobUtil.OrcaHelper."); + } + orcaSubchannel.orcaState.setListener(orcaSubchannel, listener, config); } /** * An {@link OrcaReportingHelper} wraps a delegated {@link LoadBalancer.Helper} with additional * functionality to manage RPCs for out-of-band ORCA reporting for each backend it establishes - * connection to. + * connection to. Subchannels created through it will retrieve ORCA load reports if the server + * supports it. */ - private static final class OrcaReportingHelper extends ForwardingLoadBalancerHelper - implements OrcaOobReportListener { - - private static final CreateSubchannelArgs.Key ORCA_REPORTING_STATE_KEY = - CreateSubchannelArgs.Key.create("internal-orca-reporting-state"); + static final class OrcaReportingHelper extends ForwardingLoadBalancerHelper { private final LoadBalancer.Helper delegate; - private final OrcaOobReportListener listener; private final SynchronizationContext syncContext; private final BackoffPolicy.Provider backoffPolicyProvider; private final Supplier stopwatchSupplier; - private final Set orcaStates = new HashSet<>(); - @Nullable private OrcaReportingConfig orcaConfig; OrcaReportingHelper( LoadBalancer.Helper delegate, - OrcaOobReportListener listener, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier) { this.delegate = checkNotNull(delegate, "delegate"); - this.listener = checkNotNull(listener, "listener"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); syncContext = checkNotNull(delegate.getSynchronizationContext(), "syncContext"); @@ -264,42 +238,17 @@ protected Helper delegate() { @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); - OrcaReportingState orcaState = args.getOption(ORCA_REPORTING_STATE_KEY); - boolean augmented = false; - if (orcaState == null) { + Subchannel subchannel = super.createSubchannel(args); + SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + OrcaReportingState orcaState; + if (orcaSubchannel == null) { // Only the first load balancing policy requesting ORCA reports instantiates an // OrcaReportingState. - orcaState = new OrcaReportingState(this, syncContext, - delegate().getScheduledExecutorService()); - args = args.toBuilder().addOption(ORCA_REPORTING_STATE_KEY, orcaState).build(); - augmented = true; - } - orcaStates.add(orcaState); - orcaState.listeners.add(this); - Subchannel subchannel = super.createSubchannel(args); - if (augmented) { - subchannel = new SubchannelImpl(subchannel, orcaState); - } - if (orcaConfig != null) { - orcaState.setReportingConfig(this, orcaConfig); - } - return subchannel; - } - - void setReportingConfig(final OrcaReportingConfig config) { - syncContext.throwIfNotInThisSynchronizationContext(); - orcaConfig = config; - for (OrcaReportingState state : orcaStates) { - state.setReportingConfig(OrcaReportingHelper.this, config); - } - } - - @Override - public void onLoadReport(OrcaLoadReport report) { - syncContext.throwIfNotInThisSynchronizationContext(); - if (orcaConfig != null) { - listener.onLoadReport(report); + orcaState = new OrcaReportingState(syncContext, delegate().getScheduledExecutorService()); + } else { + orcaState = orcaSubchannel.orcaState; } + return new SubchannelImpl(subchannel, orcaState); } /** @@ -309,11 +258,9 @@ public void onLoadReport(OrcaLoadReport report) { */ private final class OrcaReportingState implements SubchannelStateListener { - private final OrcaReportingHelper orcaHelper; private final SynchronizationContext syncContext; private final ScheduledExecutorService timeService; - private final List listeners = new ArrayList<>(); - private final Map configs = new HashMap<>(); + private final Map configs = new HashMap<>(); @Nullable private Subchannel subchannel; @Nullable private ChannelLogger subchannelLogger; @Nullable @@ -332,12 +279,11 @@ public void run() { private ConnectivityStateInfo state = ConnectivityStateInfo.forNonError(IDLE); // True if server returned UNIMPLEMENTED. private boolean disabled; + private boolean started; OrcaReportingState( - OrcaReportingHelper orcaHelper, SynchronizationContext syncContext, ScheduledExecutorService timeService) { - this.orcaHelper = checkNotNull(orcaHelper, "orcaHelper"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.timeService = checkNotNull(timeService, "timeService"); } @@ -347,11 +293,27 @@ void init(Subchannel subchannel, SubchannelStateListener stateListener) { this.subchannel = checkNotNull(subchannel, "subchannel"); this.subchannelLogger = checkNotNull(subchannel.getChannelLogger(), "subchannelLogger"); this.stateListener = checkNotNull(stateListener, "stateListener"); + started = true; } - void setReportingConfig(OrcaReportingHelper helper, OrcaReportingConfig config) { + void setListener(SubchannelImpl orcaSubchannel, OrcaOobReportListener listener, + OrcaReportingConfig config) { + syncContext.execute(new Runnable() { + @Override + public void run() { + OrcaOobReportListener oldListener = orcaSubchannel.reportListener; + if (oldListener != null) { + configs.remove(oldListener); + } + orcaSubchannel.reportListener = listener; + setReportingConfig(listener, config); + } + }); + } + + private void setReportingConfig(OrcaOobReportListener listener, OrcaReportingConfig config) { boolean reconfigured = false; - configs.put(helper, config); + configs.put(listener, config); // Real reporting interval is the minimum of intervals requested by all participating // helpers. if (overallConfig == null) { @@ -383,9 +345,6 @@ public void onSubchannelState(ConnectivityStateInfo newState) { // may be available on the new connection. disabled = false; } - if (Objects.equal(newState.getState(), SHUTDOWN)) { - orcaHelper.orcaStates.remove(this); - } state = newState; adjustOrcaReporting(); // Propagate subchannel state update to downstream listeners. @@ -492,8 +451,9 @@ void handleResponse(OrcaLoadReport response) { callHasResponded = true; backoffPolicy = null; subchannelLogger.log(ChannelLogLevel.DEBUG, "Received an ORCA report: {0}", response); - for (OrcaOobReportListener listener : listeners) { - listener.onLoadReport(response); + MetricReport metricReport = OrcaPerRequestUtil.fromOrcaLoadReport(response); + for (OrcaOobReportListener listener : configs.keySet()) { + listener.onLoadReport(metricReport); } call.request(1); } @@ -547,9 +507,9 @@ public String toString() { @VisibleForTesting static final class SubchannelImpl extends ForwardingSubchannel { - private final Subchannel delegate; private final OrcaReportingHelper.OrcaReportingState orcaState; + @Nullable private OrcaOobReportListener reportListener; SubchannelImpl(Subchannel delegate, OrcaReportingHelper.OrcaReportingState orcaState) { this.delegate = checkNotNull(delegate, "delegate"); @@ -563,8 +523,17 @@ protected Subchannel delegate() { @Override public void start(SubchannelStateListener listener) { - orcaState.init(this, listener); - super.start(orcaState); + if (!orcaState.started) { + orcaState.init(this, listener); + super.start(orcaState); + } else { + super.start(listener); + } + } + + @Override + public Attributes getAttributes() { + return super.getAttributes().toBuilder().set(ORCA_REPORTING_STATE_KEY, this).build(); } } diff --git a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaPerRequestUtil.java similarity index 90% rename from xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java rename to xds/src/main/java/io/grpc/xds/orca/OrcaPerRequestUtil.java index 34589d77d07..0c2c7395b47 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaPerRequestUtil.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import static com.google.common.base.Preconditions.checkNotNull; @@ -23,10 +23,13 @@ import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.protobuf.ProtoUtils; +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricReport; import java.util.ArrayList; import java.util.List; @@ -34,7 +37,8 @@ * Utility class that provides method for {@link LoadBalancer} to install listeners to receive * per-request backend cost metrics in the format of Open Request Cost Aggregation (ORCA). */ -abstract class OrcaPerRequestUtil { +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9128") +public abstract class OrcaPerRequestUtil { private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER = new ClientStreamTracer() {}; private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY = new ClientStreamTracer.Factory() { @@ -67,7 +71,7 @@ public static OrcaPerRequestUtil getInstance() { } /** - * Creates a new {@link ClientStreamTracer.Factory} with provided {@link + * Creates a new {@link io.grpc.ClientStreamTracer.Factory} with provided {@link * OrcaPerRequestReportListener} installed to receive callback when a per-request ORCA report is * received. * @@ -93,7 +97,7 @@ public abstract ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( OrcaPerRequestReportListener listener); /** - * Creates a new {@link ClientStreamTracer.Factory} with provided {@link + * Creates a new {@link io.grpc.ClientStreamTracer.Factory} with provided {@link * OrcaPerRequestReportListener} installed to receive callback when a per-request ORCA report is * received. * @@ -173,14 +177,14 @@ public abstract ClientStreamTracer.Factory newOrcaClientStreamTracerFactory( public interface OrcaPerRequestReportListener { /** - * Invoked when an per-request ORCA report is received. + * Invoked when a per-request ORCA report is received. * *

    Note this callback will be invoked from the network thread as the RPC finishes, * implementations should not block. * - * @param report load report in the format of ORCA format. + * @param report load report in the format of grpc {@link MetricReport}. */ - void onLoadReport(OrcaLoadReport report); + void onLoadReport(MetricReport report); } /** @@ -195,7 +199,7 @@ static final class OrcaReportingTracerFactory extends @VisibleForTesting static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = Metadata.Key.of( - "x-endpoint-load-metrics-bin", + "endpoint-load-metrics-bin", ProtoUtils.metadataMarshaller(OrcaLoadReport.getDefaultInstance())); private static final CallOptions.Key ORCA_REPORT_BROKER_KEY = @@ -248,6 +252,12 @@ public void inboundTrailers(Metadata trailers) { } } + static MetricReport fromOrcaLoadReport(OrcaLoadReport loadReport) { + return InternalCallMetricRecorder.createMetricReport(loadReport.getCpuUtilization(), + loadReport.getMemUtilization(), loadReport.getRequestCostMap(), + loadReport.getUtilizationMap()); + } + /** * A container class to hold registered {@link OrcaPerRequestReportListener}s and invoke all of * them when an {@link OrcaLoadReport} is received. @@ -261,8 +271,9 @@ void addListener(OrcaPerRequestReportListener listener) { } void onReport(OrcaLoadReport report) { + MetricReport metricReport = fromOrcaLoadReport(report); for (OrcaPerRequestReportListener listener : listeners) { - listener.onLoadReport(report); + listener.onLoadReport(metricReport); } } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaServiceImpl.java b/xds/src/main/java/io/grpc/xds/orca/OrcaServiceImpl.java new file mode 100644 index 00000000000..1ea64f70bf2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaServiceImpl.java @@ -0,0 +1,156 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.github.xds.data.orca.v3.OrcaLoadReport; +import com.github.xds.service.orca.v3.OpenRcaServiceGrpc; +import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.util.Durations; +import io.grpc.BindableService; +import io.grpc.ServerServiceDefinition; +import io.grpc.SynchronizationContext; +import io.grpc.services.InternalMetricRecorder; +import io.grpc.services.MetricRecorder; +import io.grpc.services.MetricReport; +import io.grpc.stub.ServerCallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Implements a {@link BindableService} that generates Out-Of-Band server metrics. + * Register the returned service to the server, then a client can request for periodic load reports. + */ +public final class OrcaServiceImpl implements BindableService { + private static final Logger logger = Logger.getLogger(OrcaServiceImpl.class.getName()); + + /** + * Empty or invalid (non-positive) minInterval config in will be treated to this default value. + */ + public static final long DEFAULT_MIN_REPORT_INTERVAL_NANOS = TimeUnit.SECONDS.toNanos(30); + + private final long minReportIntervalNanos; + private final ScheduledExecutorService timeService; + @VisibleForTesting + final AtomicInteger clientCount = new AtomicInteger(0); + private MetricRecorder metricRecorder; + private final RealOrcaServiceImpl delegate = new RealOrcaServiceImpl(); + + /** + * Constructs a service to report server metrics. Config the report interval lower bound, the + * executor to run the timer, and a {@link MetricRecorder} that contains metrics data. + * + * @param minInterval configures the minimum metrics reporting interval for the + * service. Bad configuration (non-positive) will be overridden to service default (30s). + * Minimum metrics reporting interval means, if the setting in the client's + * request is invalid (non-positive) or below this value, they will be treated + * as this value. + */ + public static BindableService createService(ScheduledExecutorService timeService, + MetricRecorder metricsRecorder, + long minInterval, TimeUnit timeUnit) { + return new OrcaServiceImpl(minInterval, timeUnit, timeService, metricsRecorder); + } + + public static BindableService createService(ScheduledExecutorService timeService, + MetricRecorder metricRecorder) { + return new OrcaServiceImpl(DEFAULT_MIN_REPORT_INTERVAL_NANOS, TimeUnit.NANOSECONDS, + timeService, metricRecorder); + } + + private OrcaServiceImpl(long minInterval, TimeUnit timeUnit, ScheduledExecutorService timeService, + MetricRecorder orcaMetrics) { + this.minReportIntervalNanos = minInterval > 0 ? timeUnit.toNanos(minInterval) + : DEFAULT_MIN_REPORT_INTERVAL_NANOS; + this.timeService = checkNotNull(timeService, "timeService"); + this.metricRecorder = checkNotNull(orcaMetrics, "orcaMetrics"); + } + + @Override + public ServerServiceDefinition bindService() { + return delegate.bindService(); + } + + private final class RealOrcaServiceImpl extends OpenRcaServiceGrpc.OpenRcaServiceImplBase { + @Override + public void streamCoreMetrics( + OrcaLoadReportRequest request, StreamObserver responseObserver) { + OrcaClient client = new OrcaClient(request, responseObserver); + client.run(); + clientCount.getAndIncrement(); + } + } + + private final class OrcaClient implements Runnable { + final ServerCallStreamObserver responseObserver; + SynchronizationContext.ScheduledHandle periodicReportTimer; + final long reportIntervalNanos; + final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log(Level.SEVERE, "Exception!" + e); + } + }); + + OrcaClient(OrcaLoadReportRequest request, StreamObserver responseObserver) { + this.reportIntervalNanos = Math.max(Durations.toNanos( + checkNotNull(request).getReportInterval()), minReportIntervalNanos); + this.responseObserver = (ServerCallStreamObserver) responseObserver; + this.responseObserver.setOnCancelHandler(new Runnable() { + @Override + public void run() { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (periodicReportTimer != null) { + periodicReportTimer.cancel(); + } + clientCount.getAndDecrement(); + } + }); + } + }); + } + + @Override + public void run() { + if (periodicReportTimer != null && periodicReportTimer.isPending()) { + return; + } + OrcaLoadReport report = generateMetricsReport(); + responseObserver.onNext(report); + periodicReportTimer = syncContext.schedule(OrcaClient.this, reportIntervalNanos, + TimeUnit.NANOSECONDS, timeService); + } + } + + private OrcaLoadReport generateMetricsReport() { + MetricReport internalReport = + InternalMetricRecorder.getMetricReport(metricRecorder); + return OrcaLoadReport.newBuilder().setCpuUtilization(internalReport.getCpuUtilization()) + .setMemUtilization(internalReport.getMemoryUtilization()) + .putAllUtilization(internalReport.getUtilizationMetrics()) + .build(); + } +} diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index 8e5c2dd1c6a..6b6e3a392a9 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -6,3 +6,4 @@ io.grpc.xds.ClusterResolverLoadBalancerProvider io.grpc.xds.ClusterImplLoadBalancerProvider io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider +io.grpc.xds.WrrLocalityLoadBalancerProvider diff --git a/xds/src/main/resources/META-INF/services/io.grpc.NameResolverProvider b/xds/src/main/resources/META-INF/services/io.grpc.NameResolverProvider index c1f2c40e7ee..269cdd38801 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.NameResolverProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.NameResolverProvider @@ -1,2 +1 @@ io.grpc.xds.XdsNameResolverProvider -io.grpc.xds.GoogleCloudToProdNameResolverProvider diff --git a/xds/src/main/resources/META-INF/services/io.grpc.xds.XdsCredentialsProvider b/xds/src/main/resources/META-INF/services/io.grpc.xds.XdsCredentialsProvider new file mode 100644 index 00000000000..a51cd114737 --- /dev/null +++ b/xds/src/main/resources/META-INF/services/io.grpc.xds.XdsCredentialsProvider @@ -0,0 +1,3 @@ +io.grpc.xds.internal.GoogleDefaultXdsCredentialsProvider +io.grpc.xds.internal.InsecureXdsCredentialsProvider +io.grpc.xds.internal.TlsXdsCredentialsProvider \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/BootstrapperImplTest.java b/xds/src/test/java/io/grpc/xds/BootstrapperImplTest.java index 53b52a7bc02..7b263b27f20 100644 --- a/xds/src/test/java/io/grpc/xds/BootstrapperImplTest.java +++ b/xds/src/test/java/io/grpc/xds/BootstrapperImplTest.java @@ -578,7 +578,7 @@ public void useV2ProtocolByDefault() throws XdsInitializationException { ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); assertThat(serverInfo.target()).isEqualTo(SERVER_URI); assertThat(serverInfo.channelCredentials()).isInstanceOf(InsecureChannelCredentials.class); - assertThat(serverInfo.useProtocolV3()).isFalse(); + assertThat(serverInfo.ignoreResourceDeletion()).isFalse(); } @Test @@ -600,7 +600,53 @@ public void useV3ProtocolIfV3FeaturePresent() throws XdsInitializationException ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); assertThat(serverInfo.target()).isEqualTo(SERVER_URI); assertThat(serverInfo.channelCredentials()).isInstanceOf(InsecureChannelCredentials.class); - assertThat(serverInfo.useProtocolV3()).isTrue(); + assertThat(serverInfo.ignoreResourceDeletion()).isFalse(); + } + + @Test + public void serverFeatureIgnoreResourceDeletion() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"ignore_resource_deletion\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.channelCredentials()).isInstanceOf(InsecureChannelCredentials.class); + // Only ignore_resource_deletion feature enabled: confirm it's on, and xds_v3 is off. + assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); + } + + @Test + public void serverFeatureIgnoreResourceDeletion_xdsV3() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"xds_v3\", \"ignore_resource_deletion\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.channelCredentials()).isInstanceOf(InsecureChannelCredentials.class); + // ignore_resource_deletion features enabled: confirm both are on. + assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); } @Test @@ -827,6 +873,7 @@ private static Node.Builder getNodeBuilder() { .setBuildVersion(buildVersion.toString()) .setUserAgentName(buildVersion.getUserAgent()) .setUserAgentVersion(buildVersion.getImplementationVersion()) - .addClientFeatures(BootstrapperImpl.CLIENT_FEATURE_DISABLE_OVERPROVISIONING); + .addClientFeatures(BootstrapperImpl.CLIENT_FEATURE_DISABLE_OVERPROVISIONING) + .addClientFeatures(BootstrapperImpl.CLIENT_FEATURE_RESOURCE_IN_SOTW); } } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 78e6d6473ca..1818f39dd80 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -25,6 +26,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.Attributes; import io.grpc.ConnectivityState; @@ -39,6 +41,7 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; @@ -47,11 +50,13 @@ import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; -import io.grpc.xds.XdsClient.CdsUpdate; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -79,15 +84,19 @@ public class CdsLoadBalancer2Test { private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create(), true); + ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); private final UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + private final OutlierDetection outlierDetection = OutlierDetection.create( + null, null, null, null, SuccessRateEjection.create(null, null, null, null), null); - private final SynchronizationContext syncContext = new SynchronizationContext( + + private static final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @Override public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); + throw new RuntimeException(e); + //throw new AssertionError(e); } }); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); @@ -121,10 +130,12 @@ public void setUp() { when(helper.getSynchronizationContext()).thenReturn(syncContext); lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_RESOLVER_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider("round_robin")); - lbRegistry.register(new FakeLoadBalancerProvider("ring_hash_experimental")); - lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental")); + lbRegistry.register( + new FakeLoadBalancerProvider("ring_hash_experimental", new RingHashLoadBalancerProvider())); + lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", + new LeastRequestLoadBalancerProvider())); loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes( @@ -148,7 +159,8 @@ public void tearDown() { @Test public void discoverTopLevelEdsCluster() { CdsUpdate update = - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) + CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, + outlierDetection) .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); @@ -158,7 +170,7 @@ public void discoverTopLevelEdsCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext); + null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); } @@ -175,7 +187,7 @@ public void discoverTopLevelLogicalDnsCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext); + DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, null); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) .isEqualTo("least_request_experimental"); assertThat(((LeastRequestConfig) childLbConfig.lbPolicy.getConfig()).choiceCount).isEqualTo(3); @@ -195,7 +207,7 @@ public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { @Test public void nonAggregateCluster_resourceUpdate() { CdsUpdate update = - CdsUpdate.forEds(CLUSTER, null, null, 100L, upstreamTlsContext) + CdsUpdate.forEds(CLUSTER, null, null, 100L, upstreamTlsContext, outlierDetection) .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); @@ -203,15 +215,15 @@ public void nonAggregateCluster_resourceUpdate() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, null, - 100L, upstreamTlsContext); + 100L, upstreamTlsContext, outlierDetection); - update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, null) - .roundRobinLbPolicy().build(); + update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, null, + outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); childLbConfig = (ClusterResolverConfig) childBalancer.config; instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 200L, null); + null, LRS_SERVER_INFO, 200L, null, outlierDetection); } @Test @@ -225,7 +237,7 @@ public void nonAggregateCluster_resourceRevoked() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, null, 100L, upstreamTlsContext); + DNS_HOST_NAME, null, 100L, upstreamTlsContext, null); xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); @@ -259,9 +271,8 @@ public void discoverAggregateCluster() { assertThat(xdsClient.watchers.keySet()).containsExactly( CLUSTER, cluster1, cluster2, cluster3, cluster4); assertThat(childBalancers).isEmpty(); - CdsUpdate update3 = - CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster3, update3); assertThat(childBalancers).isEmpty(); CdsUpdate update2 = @@ -270,7 +281,7 @@ public void discoverAggregateCluster() { xdsClient.deliverCdsUpdate(cluster2, update2); assertThat(childBalancers).isEmpty(); CdsUpdate update4 = - CdsUpdate.forEds(cluster4, null, LRS_SERVER_INFO, 300L, null) + CdsUpdate.forEds(cluster4, null, LRS_SERVER_INFO, 300L, null, outlierDetection) .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster4, update4); assertThat(childBalancers).hasSize(1); // all non-aggregate clusters discovered @@ -280,12 +291,12 @@ public void discoverAggregateCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(3); // Clusters on higher level has higher priority: [cluster2, cluster3, cluster4] assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null, null); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext); + upstreamTlsContext, outlierDetection); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(2), cluster4, - DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_INFO, 300L, null); + DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_INFO, 300L, null, outlierDetection); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) .isEqualTo("ring_hash_experimental"); // dominated by top-level cluster's config assertThat(((RingHashConfig) childLbConfig.lbPolicy.getConfig()).minRingSize).isEqualTo(100L); @@ -320,9 +331,8 @@ public void aggregateCluster_descendantClustersRevoked() { .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = - CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); CdsUpdate update2 = CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) @@ -333,9 +343,10 @@ public void aggregateCluster_descendantClustersRevoked() { assertThat(childLbConfig.discoveryMechanisms).hasSize(2); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext); + upstreamTlsContext, outlierDetection); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, + null); // Revoke cluster1, should still be able to proceed with cluster2. xdsClient.deliverResourceNotExist(cluster1); @@ -343,7 +354,8 @@ public void aggregateCluster_descendantClustersRevoked() { childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(1); assertDiscoveryMechanism(Iterables.getOnlyElement(childLbConfig.discoveryMechanisms), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, + null); verify(helper, never()).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); @@ -368,9 +380,8 @@ public void aggregateCluster_rootClusterRevoked() { .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = - CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); CdsUpdate update2 = CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) @@ -381,9 +392,10 @@ public void aggregateCluster_rootClusterRevoked() { assertThat(childLbConfig.discoveryMechanisms).hasSize(2); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext); + upstreamTlsContext, outlierDetection); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, + null); xdsClient.deliverResourceNotExist(CLUSTER); assertThat(xdsClient.watchers.keySet()) @@ -422,16 +434,15 @@ public void aggregateCluster_intermediateClusterChanges() { .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster2, update2); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2, cluster3); - CdsUpdate update3 = - CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster3, update3); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext); + null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); // cluster2 revoked xdsClient.deliverResourceNotExist(cluster2); @@ -459,7 +470,10 @@ public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicke xdsClient.deliverError(error); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), error, null); + Status expectedError = Status.UNAVAILABLE.withDescription( + "Unable to load CDS cluster-foo.googleapis.com. xDS server returned: " + + "RESOURCE_EXHAUSTED: OOM"); + assertPicker(pickerCaptor.getValue(), expectedError, null); assertThat(childBalancers).isEmpty(); } @@ -481,7 +495,8 @@ public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildL Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); xdsClient.deliverError(error); - assertThat(childLb.upstreamError).isEqualTo(error); + assertThat(childLb.upstreamError.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(childLb.upstreamError.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); assertThat(childLb.shutdown).isFalse(); // child LB may choose to keep working } @@ -496,9 +511,8 @@ public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErr @Test public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - CdsUpdate update = - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.shutdown).isFalse(); @@ -509,6 +523,35 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug any(ConnectivityState.class), any(SubchannelPicker.class)); } + @Test + public void unknownLbProvider() { + try { + xdsClient.deliverCdsUpdate(CLUSTER, + CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, + outlierDetection) + .lbPolicyConfig(ImmutableMap.of("unknown", ImmutableMap.of("foo", "bar"))).build()); + } catch (Exception e) { + assertThat(e).hasCauseThat().hasMessageThat().contains("No provider available"); + return; + } + fail("Expected the unknown LB to cause an exception"); + } + + @Test + public void invalidLbConfig() { + try { + xdsClient.deliverCdsUpdate(CLUSTER, + CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, + outlierDetection).lbPolicyConfig( + ImmutableMap.of("ring_hash_experimental", ImmutableMap.of("minRingSize", "-1"))) + .build()); + } catch (Exception e) { + assertThat(e).hasCauseThat().hasMessageThat().contains("Unable to parse"); + return; + } + fail("Expected the invalid config to casue an exception"); + } + private static void assertPicker(SubchannelPicker picker, Status expectedStatus, @Nullable Subchannel expectedSubchannel) { PickResult result = picker.pickSubchannel(mock(PickSubchannelArgs.class)); @@ -523,7 +566,7 @@ private static void assertPicker(SubchannelPicker picker, Status expectedStatus, private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String name, DiscoveryMechanism.Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { + @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { assertThat(instance.cluster).isEqualTo(name); assertThat(instance.type).isEqualTo(type); assertThat(instance.edsServiceName).isEqualTo(edsServiceName); @@ -531,18 +574,25 @@ private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String assertThat(instance.lrsServerInfo).isEqualTo(lrsServerInfo); assertThat(instance.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); assertThat(instance.tlsContext).isEqualTo(tlsContext); + assertThat(instance.outlierDetection).isEqualTo(outlierDetection); } private final class FakeLoadBalancerProvider extends LoadBalancerProvider { private final String policyName; + private final LoadBalancerProvider configParsingDelegate; FakeLoadBalancerProvider(String policyName) { + this(policyName, null); + } + + FakeLoadBalancerProvider(String policyName, LoadBalancerProvider configParsingDelegate) { this.policyName = policyName; + this.configParsingDelegate = configParsingDelegate; } @Override public LoadBalancer newLoadBalancer(Helper helper) { - FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper); + FakeLoadBalancer balancer = new FakeLoadBalancer(policyName); childBalancers.add(balancer); return balancer; } @@ -561,18 +611,25 @@ public int getPriority() { public String getPolicyName() { return policyName; } + + @Override + public NameResolver.ConfigOrError parseLoadBalancingPolicyConfig( + Map rawLoadBalancingPolicyConfig) { + if (configParsingDelegate != null) { + return configParsingDelegate.parseLoadBalancingPolicyConfig(rawLoadBalancingPolicyConfig); + } + return super.parseLoadBalancingPolicyConfig(rawLoadBalancingPolicyConfig); + } } private final class FakeLoadBalancer extends LoadBalancer { private final String name; - private final Helper helper; private Object config; private Status upstreamError; private boolean shutdown; - FakeLoadBalancer(String name, Helper helper) { + FakeLoadBalancer(String name) { this.name = name; - this.helper = helper; } @Override @@ -590,29 +647,26 @@ public void shutdown() { shutdown = true; childBalancers.remove(this); } - - void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { - SubchannelPicker picker = new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel); - } - }; - helper.updateBalancingState(state, picker); - } } - private static final class FakeXdsClient extends XdsClient { - private final Map watchers = new HashMap<>(); + private final class FakeXdsClient extends XdsClient { + private final Map> watchers = new HashMap<>(); @Override - void watchCdsResource(String resourceName, CdsResourceWatcher watcher) { + @SuppressWarnings("unchecked") + void watchXdsResource(XdsResourceType type, String resourceName, + ResourceWatcher watcher) { + assertThat(type.typeName()).isEqualTo("CDS"); assertThat(watchers).doesNotContainKey(resourceName); - watchers.put(resourceName, watcher); + watchers.put(resourceName, (ResourceWatcher)watcher); } @Override - void cancelCdsResourceWatch(String resourceName, CdsResourceWatcher watcher) { + @SuppressWarnings("unchecked") + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { + assertThat(type.typeName()).isEqualTo("CDS"); assertThat(watchers).containsKey(resourceName); watchers.remove(resourceName); } @@ -630,7 +684,7 @@ private void deliverResourceNotExist(String clusterName) { } private void deliverError(Status error) { - for (CdsResourceWatcher watcher : watchers.values()) { + for (ResourceWatcher watcher : watchers.values()) { watcher.onError(error); } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java deleted file mode 100644 index 29c7fdc4c01..00000000000 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ /dev/null @@ -1,746 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.protobuf.Any; -import com.google.protobuf.Message; -import com.google.protobuf.UInt32Value; -import com.google.protobuf.UInt64Value; -import com.google.protobuf.util.Durations; -import com.google.rpc.Code; -import io.envoyproxy.envoy.api.v2.Cluster; -import io.envoyproxy.envoy.api.v2.Cluster.CustomClusterType; -import io.envoyproxy.envoy.api.v2.Cluster.DiscoveryType; -import io.envoyproxy.envoy.api.v2.Cluster.EdsClusterConfig; -import io.envoyproxy.envoy.api.v2.Cluster.LbPolicy; -import io.envoyproxy.envoy.api.v2.Cluster.LeastRequestLbConfig; -import io.envoyproxy.envoy.api.v2.Cluster.RingHashLbConfig; -import io.envoyproxy.envoy.api.v2.Cluster.RingHashLbConfig.HashFunction; -import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; -import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment.Policy; -import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment.Policy.DropOverload; -import io.envoyproxy.envoy.api.v2.DiscoveryRequest; -import io.envoyproxy.envoy.api.v2.DiscoveryResponse; -import io.envoyproxy.envoy.api.v2.Listener; -import io.envoyproxy.envoy.api.v2.RouteConfiguration; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; -import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; -import io.envoyproxy.envoy.api.v2.cluster.CircuitBreakers; -import io.envoyproxy.envoy.api.v2.cluster.CircuitBreakers.Thresholds; -import io.envoyproxy.envoy.api.v2.core.Address; -import io.envoyproxy.envoy.api.v2.core.AggregatedConfigSource; -import io.envoyproxy.envoy.api.v2.core.ApiConfigSource; -import io.envoyproxy.envoy.api.v2.core.ConfigSource; -import io.envoyproxy.envoy.api.v2.core.GrpcService; -import io.envoyproxy.envoy.api.v2.core.GrpcService.GoogleGrpc; -import io.envoyproxy.envoy.api.v2.core.HealthStatus; -import io.envoyproxy.envoy.api.v2.core.Locality; -import io.envoyproxy.envoy.api.v2.core.Node; -import io.envoyproxy.envoy.api.v2.core.RoutingPriority; -import io.envoyproxy.envoy.api.v2.core.SelfConfigSource; -import io.envoyproxy.envoy.api.v2.core.SocketAddress; -import io.envoyproxy.envoy.api.v2.core.TransportSocket; -import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats; -import io.envoyproxy.envoy.api.v2.endpoint.Endpoint; -import io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint; -import io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints; -import io.envoyproxy.envoy.api.v2.listener.FilterChain; -import io.envoyproxy.envoy.api.v2.route.Route; -import io.envoyproxy.envoy.api.v2.route.RouteAction; -import io.envoyproxy.envoy.api.v2.route.RouteMatch; -import io.envoyproxy.envoy.api.v2.route.VirtualHost; -import io.envoyproxy.envoy.config.cluster.aggregate.v2alpha.ClusterConfig; -import io.envoyproxy.envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager; -import io.envoyproxy.envoy.config.filter.network.http_connection_manager.v2.HttpFilter; -import io.envoyproxy.envoy.config.filter.network.http_connection_manager.v2.Rds; -import io.envoyproxy.envoy.config.listener.v2.ApiListener; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; -import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceImplBase; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse; -import io.envoyproxy.envoy.type.FractionalPercent; -import io.envoyproxy.envoy.type.FractionalPercent.DenominatorType; -import io.envoyproxy.envoy.type.matcher.RegexMatcher; -import io.grpc.BindableService; -import io.grpc.Context; -import io.grpc.Context.CancellationListener; -import io.grpc.Status; -import io.grpc.stub.StreamObserver; -import io.grpc.xds.AbstractXdsClient.ResourceType; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import javax.annotation.Nullable; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentMatcher; -import org.mockito.InOrder; - -/** - * Tests for {@link ClientXdsClient} with protocol version v2. - */ -@RunWith(JUnit4.class) -public class ClientXdsClientV2Test extends ClientXdsClientTestBase { - - @Override - protected BindableService createAdsService() { - return new AggregatedDiscoveryServiceImplBase() { - @Override - public StreamObserver streamAggregatedResources( - final StreamObserver responseObserver) { - assertThat(adsEnded.get()).isTrue(); // ensure previous call was ended - adsEnded.set(false); - @SuppressWarnings("unchecked") - StreamObserver requestObserver = mock(StreamObserver.class); - DiscoveryRpcCall call = new DiscoveryRpcCallV2(requestObserver, responseObserver); - resourceDiscoveryCalls.offer(call); - Context.current().addListener( - new CancellationListener() { - @Override - public void cancelled(Context context) { - adsEnded.set(true); - } - }, MoreExecutors.directExecutor()); - return requestObserver; - } - }; - } - - @Override - protected BindableService createLrsService() { - return new LoadReportingServiceImplBase() { - @Override - public StreamObserver streamLoadStats( - StreamObserver responseObserver) { - assertThat(lrsEnded.get()).isTrue(); - lrsEnded.set(false); - @SuppressWarnings("unchecked") - StreamObserver requestObserver = mock(StreamObserver.class); - LrsRpcCall call = new LrsRpcCallV2(requestObserver, responseObserver); - Context.current().addListener( - new CancellationListener() { - @Override - public void cancelled(Context context) { - lrsEnded.set(true); - } - }, MoreExecutors.directExecutor()); - loadReportCalls.offer(call); - return requestObserver; - } - }; - } - - @Override - protected MessageFactory createMessageFactory() { - return new MessageFactoryV2(); - } - - @Override - protected boolean useProtocolV3() { - return false; - } - - private static class DiscoveryRpcCallV2 extends DiscoveryRpcCall { - StreamObserver requestObserver; - StreamObserver responseObserver; - - private DiscoveryRpcCallV2(StreamObserver requestObserver, - StreamObserver responseObserver) { - this.requestObserver = requestObserver; - this.responseObserver = responseObserver; - } - - @Override - protected void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, - EnvoyProtoData.Node node) { - verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( - node.toEnvoyProtoNodeV2(), versionInfo, resources, type.typeUrlV2(), nonce, null, null))); - } - - @Override - protected void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, - EnvoyProtoData.Node node, List errorMessages) { - verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( - node.toEnvoyProtoNodeV2(), versionInfo, resources, type.typeUrlV2(), nonce, - Code.INVALID_ARGUMENT_VALUE, errorMessages))); - } - - @Override - protected void verifyNoMoreRequest() { - verifyNoMoreInteractions(requestObserver); - } - - @Override - protected void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce) { - DiscoveryResponse response = - DiscoveryResponse.newBuilder() - .setVersionInfo(versionInfo) - .addAllResources(resources) - .setTypeUrl(type.typeUrl()) - .setNonce(nonce) - .build(); - responseObserver.onNext(response); - } - - @Override - protected void sendError(Throwable t) { - responseObserver.onError(t); - } - - @Override - protected void sendCompleted() { - responseObserver.onCompleted(); - } - } - - private static class LrsRpcCallV2 extends LrsRpcCall { - private final StreamObserver requestObserver; - private final StreamObserver responseObserver; - private final InOrder inOrder; - - private LrsRpcCallV2(StreamObserver requestObserver, - StreamObserver responseObserver) { - this.requestObserver = requestObserver; - this.responseObserver = responseObserver; - inOrder = inOrder(requestObserver); - } - - @Override - protected void verifyNextReportClusters(List clusters) { - inOrder.verify(requestObserver).onNext(argThat(new LrsRequestMatcher(clusters))); - } - - @Override - protected void sendResponse(List clusters, long loadReportIntervalNano) { - LoadStatsResponse response = - LoadStatsResponse.newBuilder() - .addAllClusters(clusters) - .setLoadReportingInterval(Durations.fromNanos(loadReportIntervalNano)) - .build(); - responseObserver.onNext(response); - } - } - - private static class MessageFactoryV2 extends MessageFactory { - - @SuppressWarnings("unchecked") - @Override - protected Message buildListenerWithApiListener( - String name, Message routeConfiguration, List httpFilters) { - return Listener.newBuilder() - .setName(name) - .setAddress(Address.getDefaultInstance()) - .addFilterChains(FilterChain.getDefaultInstance()) - .setApiListener( - ApiListener.newBuilder().setApiListener(Any.pack( - HttpConnectionManager.newBuilder() - .setRouteConfig((RouteConfiguration) routeConfiguration) - .addAllHttpFilters((List) httpFilters) - .build()))) - .build(); - } - - @Override - protected Message buildListenerWithApiListenerForRds(String name, String rdsResourceName) { - return Listener.newBuilder() - .setName(name) - .setAddress(Address.getDefaultInstance()) - .addFilterChains(FilterChain.getDefaultInstance()) - .setApiListener( - ApiListener.newBuilder().setApiListener(Any.pack( - HttpConnectionManager.newBuilder() - .setRds( - Rds.newBuilder() - .setRouteConfigName(rdsResourceName) - .setConfigSource( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.getDefaultInstance()))) - .build()))) - .build(); - } - - @Override - protected Message buildListenerWithApiListenerInvalid(String name) { - return Listener.newBuilder() - .setName(name) - .setAddress(Address.getDefaultInstance()) - .setApiListener(ApiListener.newBuilder().setApiListener(FAILING_ANY)) - .build(); - } - - @Override - protected Message buildHttpFilter(String name, @Nullable Any typedConfig, boolean isOptional) { - throw new UnsupportedOperationException(); - } - - @Override - protected Any buildHttpFaultTypedConfig( - @Nullable Long delayNanos, @Nullable Integer delayRate, String upstreamCluster, - List downstreamNodes, @Nullable Integer maxActiveFaults, @Nullable Status status, - @Nullable Integer httpCode, @Nullable Integer abortRate) { - throw new UnsupportedOperationException(); - } - - @Override - protected Message buildRouteConfiguration(String name, List virtualHostList) { - RouteConfiguration.Builder builder = RouteConfiguration.newBuilder(); - builder.setName(name); - for (Message virtualHost : virtualHostList) { - builder.addVirtualHosts((VirtualHost) virtualHost); - } - return builder.build(); - } - - @Override - protected Message buildRouteConfigurationInvalid(String name) { - // Invalid Path matcher: Pattern.compile() will throw PatternSyntaxException - // when attempting to process SAFE_REGEX RouteMatch malformed safe regex pattern. - // I wish there was a simpler way. - return RouteConfiguration.newBuilder() - .setName(name) - .addVirtualHosts( - VirtualHost.newBuilder() - .setName("do not care") - .addDomains("do not care") - .addRoutes( - Route.newBuilder() - .setRoute(RouteAction.newBuilder().setCluster("do not care")) - .setMatch(RouteMatch.newBuilder() - .setSafeRegex(RegexMatcher.newBuilder().setRegex("[z-a]"))))) - .build(); - } - - @Override - protected List buildOpaqueVirtualHosts(int num) { - List virtualHosts = new ArrayList<>(num); - for (int i = 0; i < num; i++) { - VirtualHost virtualHost = - VirtualHost.newBuilder() - .setName(num + ": do not care") - .addDomains("do not care") - .addRoutes( - Route.newBuilder() - .setRoute(RouteAction.newBuilder().setCluster("do not care")) - .setMatch(RouteMatch.newBuilder() - .setPrefix("do not care"))) - .build(); - virtualHosts.add(virtualHost); - } - return virtualHosts; - } - - @SuppressWarnings("unchecked") - @Override - protected Message buildVirtualHost( - List routes, Map typedConfigMap) { - return VirtualHost.newBuilder() - .setName("do not care") - .addDomains("do not care") - .addAllRoutes((List) routes) - .putAllTypedPerFilterConfig(typedConfigMap) - .build(); - } - - @Override - protected List buildOpaqueRoutes(int num) { - List routes = new ArrayList<>(num); - for (int i = 0; i < num; i++) { - Route route = - Route.newBuilder() - .setRoute(RouteAction.newBuilder().setCluster("do not care")) - .setMatch(RouteMatch.newBuilder().setPrefix("do not care")) - .build(); - routes.add(route); - } - return routes; - } - - @Override - protected Message buildClusterInvalid(String name) { - // Unspecified cluster discovery type - return Cluster.newBuilder().setName(name).build(); - } - - @Override - protected Message buildEdsCluster(String clusterName, @Nullable String edsServiceName, - String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, - boolean enableLrs, - @Nullable Message upstreamTlsContext, String transportSocketName, - @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder( - clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, - enableLrs, upstreamTlsContext, circuitBreakers); - builder.setType(DiscoveryType.EDS); - EdsClusterConfig.Builder edsClusterConfigBuilder = EdsClusterConfig.newBuilder(); - edsClusterConfigBuilder.setEdsConfig( - ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance())); // ADS - if (edsServiceName != null) { - edsClusterConfigBuilder.setServiceName(edsServiceName); - } - builder.setEdsClusterConfig(edsClusterConfigBuilder); - return builder.build(); - } - - @Override - protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, - int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, - @Nullable Message leastRequestLbConfig, boolean enableLrs, - @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder( - clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, - enableLrs, upstreamTlsContext, circuitBreakers); - builder.setType(DiscoveryType.LOGICAL_DNS); - builder.setLoadAssignment( - ClusterLoadAssignment.newBuilder().addEndpoints( - LocalityLbEndpoints.newBuilder().addLbEndpoints( - LbEndpoint.newBuilder().setEndpoint( - Endpoint.newBuilder().setAddress( - Address.newBuilder().setSocketAddress( - SocketAddress.newBuilder() - .setAddress(dnsHostAddr).setPortValue(dnsHostPort)))))).build()); - return builder.build(); - } - - @Override - protected Message buildAggregateCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, - List clusters) { - ClusterConfig clusterConfig = ClusterConfig.newBuilder().addAllClusters(clusters).build(); - CustomClusterType type = - CustomClusterType.newBuilder() - .setName(ClientXdsClient.AGGREGATE_CLUSTER_TYPE_NAME) - .setTypedConfig(Any.pack(clusterConfig)) - .build(); - Cluster.Builder builder = Cluster.newBuilder().setName(clusterName).setClusterType(type); - if (lbPolicy.equals("round_robin")) { - builder.setLbPolicy(LbPolicy.ROUND_ROBIN); - } else if (lbPolicy.equals("ring_hash_experimental")) { - builder.setLbPolicy(LbPolicy.RING_HASH); - builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); - } else if (lbPolicy.equals("least_request_experimental")) { - builder.setLbPolicy(LbPolicy.LEAST_REQUEST); - builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); - } else { - throw new AssertionError("Invalid LB policy"); - } - return builder.build(); - } - - private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, - boolean enableLrs, @Nullable Message upstreamTlsContext, - @Nullable Message circuitBreakers) { - Cluster.Builder builder = Cluster.newBuilder(); - builder.setName(clusterName); - if (lbPolicy.equals("round_robin")) { - builder.setLbPolicy(LbPolicy.ROUND_ROBIN); - } else if (lbPolicy.equals("ring_hash_experimental")) { - builder.setLbPolicy(LbPolicy.RING_HASH); - builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); - } else if (lbPolicy.equals("least_request_experimental")) { - builder.setLbPolicy(LbPolicy.LEAST_REQUEST); - builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); - } else { - throw new AssertionError("Invalid LB policy"); - } - if (enableLrs) { - builder.setLrsServer( - ConfigSource.newBuilder() - .setSelf(SelfConfigSource.getDefaultInstance())); - } - if (upstreamTlsContext != null) { - builder.setTransportSocket( - TransportSocket.newBuilder() - .setName("envoy.transport_sockets.tls") - .setTypedConfig(Any.pack(upstreamTlsContext))); - } - if (circuitBreakers != null) { - builder.setCircuitBreakers((CircuitBreakers) circuitBreakers); - } - return builder; - } - - @Override - protected Message buildRingHashLbConfig(String hashFunction, long minRingSize, - long maxRingSize) { - RingHashLbConfig.Builder builder = RingHashLbConfig.newBuilder(); - if (hashFunction.equals("xx_hash")) { - builder.setHashFunction(HashFunction.XX_HASH); - } else if (hashFunction.equals("murmur_hash_2")) { - builder.setHashFunction(HashFunction.MURMUR_HASH_2); - } else { - throw new AssertionError("Invalid hash function"); - } - builder.setMinimumRingSize(UInt64Value.newBuilder().setValue(minRingSize).build()); - builder.setMaximumRingSize(UInt64Value.newBuilder().setValue(maxRingSize).build()); - return builder.build(); - } - - @Override - protected Message buildLeastRequestLbConfig(int choiceCount) { - LeastRequestLbConfig.Builder builder = LeastRequestLbConfig.newBuilder(); - builder.setChoiceCount(UInt32Value.newBuilder().setValue(choiceCount)); - return builder.build(); - } - - @Override - protected Message buildUpstreamTlsContext(String instanceName, String certName) { - GrpcService grpcService = - GrpcService.newBuilder() - .setGoogleGrpc(GoogleGrpc.newBuilder().setTargetUri(certName)) - .build(); - ConfigSource sdsConfig = - ConfigSource.newBuilder() - .setApiConfigSource(ApiConfigSource.newBuilder().addGrpcServices(grpcService)) - .build(); - SdsSecretConfig validationContextSdsSecretConfig = - SdsSecretConfig.newBuilder() - .setName(instanceName) - .setSdsConfig(sdsConfig) - .build(); - return UpstreamTlsContext.newBuilder() - .setCommonTlsContext( - CommonTlsContext.newBuilder() - .setValidationContextSdsSecretConfig(validationContextSdsSecretConfig)) - .build(); - } - - @Override - protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { - return buildUpstreamTlsContext(instanceName, certName); - } - - - @Override - protected Message buildCircuitBreakers(int highPriorityMaxRequests, - int defaultPriorityMaxRequests) { - return CircuitBreakers.newBuilder() - .addThresholds( - Thresholds.newBuilder() - .setPriority(RoutingPriority.HIGH) - .setMaxRequests(UInt32Value.newBuilder().setValue(highPriorityMaxRequests))) - .addThresholds( - Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(defaultPriorityMaxRequests))) - .build(); - } - - @Override - protected Message buildClusterLoadAssignment(String cluster, - List localityLbEndpointsList, List dropOverloadList) { - ClusterLoadAssignment.Builder builder = ClusterLoadAssignment.newBuilder(); - builder.setClusterName(cluster); - for (Message localityLbEndpoints : localityLbEndpointsList) { - builder.addEndpoints((LocalityLbEndpoints) localityLbEndpoints); - } - Policy.Builder policyBuilder = Policy.newBuilder(); - for (Message dropOverload : dropOverloadList) { - policyBuilder.addDropOverloads((DropOverload) dropOverload); - } - builder.setPolicy(policyBuilder); - return builder.build(); - } - - @Override - protected Message buildClusterLoadAssignmentInvalid(String cluster) { - // Negative priority LocalityLbEndpoint. - return ClusterLoadAssignment.newBuilder() - .setClusterName(cluster) - .addEndpoints(LocalityLbEndpoints.newBuilder() - .setPriority(-1) - .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(1))) - .build(); - } - - @Override - protected Message buildLocalityLbEndpoints(String region, String zone, String subZone, - List lbEndpointList, int loadBalancingWeight, int priority) { - LocalityLbEndpoints.Builder builder = LocalityLbEndpoints.newBuilder(); - builder.setLocality( - Locality.newBuilder().setRegion(region).setZone(zone).setSubZone(subZone)); - for (Message lbEndpoint : lbEndpointList) { - builder.addLbEndpoints((LbEndpoint) lbEndpoint); - } - builder.setLoadBalancingWeight(UInt32Value.of(loadBalancingWeight)); - builder.setPriority(priority); - return builder.build(); - } - - @Override - protected Message buildLbEndpoint(String address, int port, String healthStatus, - int lbWeight) { - HealthStatus status; - switch (healthStatus) { - case "unknown": - status = HealthStatus.UNKNOWN; - break; - case "healthy": - status = HealthStatus.HEALTHY; - break; - case "unhealthy": - status = HealthStatus.UNHEALTHY; - break; - case "draining": - status = HealthStatus.DRAINING; - break; - case "timeout": - status = HealthStatus.TIMEOUT; - break; - case "degraded": - status = HealthStatus.DEGRADED; - break; - default: - status = HealthStatus.UNRECOGNIZED; - } - return LbEndpoint.newBuilder() - .setEndpoint( - Endpoint.newBuilder().setAddress( - Address.newBuilder().setSocketAddress( - SocketAddress.newBuilder().setAddress(address).setPortValue(port)))) - .setHealthStatus(status) - .setLoadBalancingWeight(UInt32Value.of(lbWeight)) - .build(); - } - - @Override - protected Message buildDropOverload(String category, int dropPerMillion) { - return DropOverload.newBuilder() - .setCategory(category) - .setDropPercentage( - FractionalPercent.newBuilder() - .setNumerator(dropPerMillion) - .setDenominator(DenominatorType.MILLION)) - .build(); - } - - @Override - protected Message buildFilterChain(List alpn, Message tlsContext, - String transportSocketName, Message... filters) { - throw new UnsupportedOperationException(); - } - - @Override - protected Message buildListenerWithFilterChain( - String name, int portValue, String address, Message... filterChains) { - throw new UnsupportedOperationException(); - } - - @Override - protected Message buildHttpConnectionManagerFilter( - @Nullable String rdsName, @Nullable Message routeConfig, List httpFilters) { - throw new UnsupportedOperationException(); - } - - @Override - protected Message buildTerminalFilter() { - throw new UnsupportedOperationException(); - } - } - - /** - * Matches a {@link DiscoveryRequest} with the same node metadata, versionInfo, typeUrl, - * response nonce and collection of resource names regardless of order. - */ - private static class DiscoveryRequestMatcher implements ArgumentMatcher { - private final Node node; - private final String versionInfo; - private final String typeUrl; - private final Set resources; - private final String responseNonce; - @Nullable private final Integer errorCode; - private final List errorMessages; - - private DiscoveryRequestMatcher(Node node, String versionInfo, List resources, - String typeUrl, String responseNonce, @Nullable Integer errorCode, - @Nullable List errorMessages) { - this.node = node; - this.versionInfo = versionInfo; - this.resources = new HashSet<>(resources); - this.typeUrl = typeUrl; - this.responseNonce = responseNonce; - this.errorCode = errorCode; - this.errorMessages = errorMessages != null ? errorMessages : ImmutableList.of(); - } - - @Override - public boolean matches(DiscoveryRequest argument) { - if (!typeUrl.equals(argument.getTypeUrl())) { - return false; - } - if (!versionInfo.equals(argument.getVersionInfo())) { - return false; - } - if (!responseNonce.equals(argument.getResponseNonce())) { - return false; - } - if (!resources.equals(new HashSet<>(argument.getResourceNamesList()))) { - return false; - } - if (errorCode == null && argument.hasErrorDetail()) { - return false; - } - if (errorCode != null - && !matchErrorDetail(argument.getErrorDetail(), errorCode, errorMessages)) { - return false; - } - return node.equals(argument.getNode()); - } - } - - /** - * Matches a {@link LoadStatsRequest} containing a collection of {@link ClusterStats} with - * the same list of clusterName:clusterServiceName pair. - */ - private static class LrsRequestMatcher implements ArgumentMatcher { - private final List expected; - - private LrsRequestMatcher(List clusterNames) { - expected = new ArrayList<>(); - for (String[] pair : clusterNames) { - expected.add(pair[0] + ":" + (pair[1] == null ? "" : pair[1])); - } - Collections.sort(expected); - } - - @Override - public boolean matches(LoadStatsRequest argument) { - List actual = new ArrayList<>(); - for (ClusterStats clusterStats : argument.getClusterStatsList()) { - actual.add(clusterStats.getClusterName() + ":" + clusterStats.getClusterServiceName()); - } - Collections.sort(actual); - return actual.equals(expected); - } - } -} diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 582747aecb7..f553b558c9f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -57,9 +57,9 @@ import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProvider; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SslContextProvider; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; @@ -88,7 +88,7 @@ public class ClusterImplLoadBalancerTest { private static final String CLUSTER = "cluster-foo.googleapis.com"; private static final String EDS_SERVICE_NAME = "service.googleapis.com"; private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("api.google.com", InsecureChannelCredentials.create(), true); + ServerInfo.create("api.google.com", InsecureChannelCredentials.create()); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @Override @@ -282,7 +282,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.singletonList(DropOverload.create("lb", 1_000_000)), new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.singletonList(endpoint)) .setAttributes( @@ -470,6 +470,7 @@ public void endpointAddressesAttachedWithClusterName() { assertThat(downstreamBalancers).hasSize(1); // one leaf balancer FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(leafBalancer.name).isEqualTo("round_robin"); + // Simulates leaf load balancer creating subchannels. CreateSubchannelArgs args = CreateSubchannelArgs.newBuilder() @@ -480,6 +481,13 @@ public void endpointAddressesAttachedWithClusterName() { assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) .isEqualTo(CLUSTER); } + + // An address update should also retain the cluster attribute. + subchannel.updateAddresses(leafBalancer.addresses); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + .isEqualTo(CLUSTER); + } } @Test @@ -571,7 +579,7 @@ private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecuri private void deliverAddressesAndConfig(List addresses, ClusterImplConfig config) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(addresses) .setAttributes( @@ -677,10 +685,11 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addresses = resolvedAddresses.getAddresses(); config = resolvedAddresses.getLoadBalancingPolicyConfig(); attributes = resolvedAddresses.getAttributes(); + return true; } @Override @@ -760,6 +769,10 @@ public List getAllAddresses() { public Attributes getAttributes() { return attrs; } + + @Override + public void updateAddresses(List addrs) { + } } private final class FakeXdsClient extends XdsClient { diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java index 5890f6f9abb..f2b80cfff0b 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java @@ -17,8 +17,10 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; @@ -52,6 +54,7 @@ import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.testing.TestMethodDescriptors; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; +import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -89,7 +92,7 @@ public void uncaughtException(Thread t, Throwable e) { private final Map lbConfigInventory = new HashMap<>(); private final List childBalancers = new ArrayList<>(); - private LoadBalancer clusterManagerLoadBalancer; + private ClusterManagerLoadBalancer clusterManagerLoadBalancer; @Before public void setUp() { @@ -249,21 +252,35 @@ public void handleNameResolutionError_notPropagateToDeactivatedChildLbs() { assertThat(childBalancer2.upstreamError.getDescription()).isEqualTo("unknown error"); } + @Test + public void noDuplicateOverallBalancingStateUpdate() { + deliverResolvedAddresses(ImmutableMap.of("childA", "policy_a", "childB", "policy_b"), true); + + // The test child LBs would have triggered state updates, let's make sure the overall balancing + // state was only updated once but that the new state reflects the state the child LB reported. + verify(helper, times(1)).updateBalancingState( + eq(TRANSIENT_FAILURE), isA(SubchannelPicker.class)); + } + private void deliverResolvedAddresses(final Map childPolicies) { - clusterManagerLoadBalancer.handleResolvedAddresses( + deliverResolvedAddresses(childPolicies, false); + } + + private void deliverResolvedAddresses(final Map childPolicies, boolean failing) { + clusterManagerLoadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) - .setLoadBalancingPolicyConfig(buildConfig(childPolicies)) + .setLoadBalancingPolicyConfig(buildConfig(childPolicies, failing)) .build()); } - private ClusterManagerConfig buildConfig(Map childPolicies) { + private ClusterManagerConfig buildConfig(Map childPolicies, boolean failing) { Map childPolicySelections = new LinkedHashMap<>(); for (String name : childPolicies.keySet()) { String childPolicyName = childPolicies.get(name); Object childConfig = lbConfigInventory.get(name); PolicySelection policy = - new PolicySelection(new FakeLoadBalancerProvider(childPolicyName), childConfig); + new PolicySelection(new FakeLoadBalancerProvider(childPolicyName, failing), childConfig); childPolicySelections.put(name, policy); } return new ClusterManagerConfig(childPolicySelections); @@ -286,14 +303,16 @@ private static PickResult pickSubchannel(SubchannelPicker picker, String cluster private final class FakeLoadBalancerProvider extends LoadBalancerProvider { private final String policyName; + private final boolean failing; - FakeLoadBalancerProvider(String policyName) { + FakeLoadBalancerProvider(String policyName, boolean failing) { this.policyName = policyName; + this.failing = failing; } @Override public LoadBalancer newLoadBalancer(Helper helper) { - FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper); + FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper, failing); childBalancers.add(balancer); return balancer; } @@ -317,18 +336,25 @@ public String getPolicyName() { private final class FakeLoadBalancer extends LoadBalancer { private final String name; private final Helper helper; + private final boolean failing; private Object config; private Status upstreamError; private boolean shutdown; - FakeLoadBalancer(String name, Helper helper) { + FakeLoadBalancer(String name, Helper helper, boolean failing) { this.name = name; this.helper = helper; + this.failing = failing; } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { config = resolvedAddresses.getLoadBalancingPolicyConfig(); + + if (failing) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL)); + } + return true; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 51a7ce5066b..51396511dcc 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -20,6 +20,7 @@ import static io.grpc.xds.XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; +import static io.grpc.xds.XdsLbPolicies.WRR_LOCALITY_POLICY_NAME; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -57,6 +58,8 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; +import io.grpc.util.OutlierDetectionLoadBalancerProvider; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; @@ -64,14 +67,17 @@ import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; -import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; -import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -107,7 +113,7 @@ public class ClusterResolverLoadBalancerTest { private static final String EDS_SERVICE_NAME2 = "backend-service-bar.googleapis.com"; private static final String DNS_HOST_NAME = "dns-service.googleapis.com"; private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create(), true); + ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); private final Locality locality1 = Locality.create("test-region-1", "test-zone-1", "test-subzone-1"); private final Locality locality2 = @@ -116,10 +122,18 @@ public class ClusterResolverLoadBalancerTest { Locality.create("test-region-3", "test-zone-3", "test-subzone-3"); private final UpstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + private final OutlierDetection outlierDetection = OutlierDetection.create( + 100L, 100L, 100L, 100, SuccessRateEjection.create(100, 100, 100, 100), + FailurePercentageEjection.create(100, 100, 100, 100)); private final DiscoveryMechanism edsDiscoveryMechanism1 = - DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext); + DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, + null); private final DiscoveryMechanism edsDiscoveryMechanism2 = - DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, tlsContext); + DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, tlsContext, + null); + private final DiscoveryMechanism edsDiscoveryMechanismWithOutlierDetection = + DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, + outlierDetection); private final DiscoveryMechanism logicalDnsDiscoveryMechanism = DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_INFO, 300L, null); @@ -133,12 +147,15 @@ public void uncaughtException(Thread t, Throwable e) { private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final NameResolverRegistry nsRegistry = new NameResolverRegistry(); - private final PolicySelection roundRobin = - new PolicySelection(new FakeLoadBalancerProvider("round_robin"), null); + private final PolicySelection roundRobin = new PolicySelection( + new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( + new PolicySelection(new FakeLoadBalancerProvider("round_robin"), null))); private final PolicySelection ringHash = new PolicySelection( new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); private final PolicySelection leastRequest = new PolicySelection( - new FakeLoadBalancerProvider("least_request_experimental"), new LeastRequestConfig(3)); + new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( + new PolicySelection(new FakeLoadBalancerProvider("least_request_experimental"), + new LeastRequestConfig(3)))); private final List childBalancers = new ArrayList<>(); private final List resolvers = new ArrayList<>(); private final FakeXdsClient xdsClient = new FakeXdsClient(); @@ -169,7 +186,6 @@ public XdsClient returnObject(Object object) { private int xdsClientRefs; private ClusterResolverLoadBalancer loadBalancer; - @Before public void setUp() throws URISyntaxException { MockitoAnnotations.initMocks(this); @@ -179,6 +195,7 @@ public void setUp() throws URISyntaxException { lbRegistry.register(new FakeLoadBalancerProvider(WEIGHTED_TARGET_POLICY_NAME)); lbRegistry.register( new FakeLoadBalancerProvider("pick_first")); // needed by logical_dns + lbRegistry.register(new OutlierDetectionLoadBalancerProvider()); NameResolver.Args args = NameResolver.Args.newBuilder() .setDefaultPort(8080) .setProxyDetector(GrpcUtil.NOOP_PROXY_DETECTOR) @@ -254,7 +271,7 @@ public void edsClustersWithRingHashEndpointLbPolicy() { .isEqualTo(50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[priority1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); assertThat(priorityChildConfig.ignoreReresolution).isTrue(); @@ -295,7 +312,7 @@ public void edsClustersWithLeastRequestEndpointLbPolicy() { assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[priority1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) @@ -303,12 +320,101 @@ public void edsClustersWithLeastRequestEndpointLbPolicy() { ClusterImplConfig clusterImplConfig = (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - WeightedTargetConfig weightedTargetConfig = - (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality1.toString()); + tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + WrrLocalityConfig wrrLocalityConfig = + (WrrLocalityConfig) clusterImplConfig.childPolicy.getConfig(); + assertThat(wrrLocalityConfig.childPolicy.getProvider().getPolicyName()).isEqualTo( + "least_request_experimental"); + + assertThat( + childBalancer.addresses.get(0).getAttributes() + .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); } + @Test + public void edsClustersWithOutlierDetection() { + ClusterResolverConfig config = new ClusterResolverConfig( + Collections.singletonList(edsDiscoveryMechanismWithOutlierDetection), leastRequest); + deliverLbConfig(config); + assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); + assertThat(childBalancers).isEmpty(); + + // Simple case with one priority and one locality + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); + LocalityLbEndpoints localityLbEndpoints = + LocalityLbEndpoints.create( + Arrays.asList( + LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), + 100 /* localityWeight */, 1 /* priority */); + xdsClient.deliverClusterLoadAssignment( + EDS_SERVICE_NAME1, + ImmutableMap.of(locality1, localityLbEndpoints)); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + assertThat(childBalancer.addresses).hasSize(1); + EquivalentAddressGroup addr = childBalancer.addresses.get(0); + assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + + // The child config for priority should be outlier detection. + assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + .isEqualTo("outlier_detection_experimental"); + OutlierDetectionLoadBalancerConfig outlierDetectionConfig = + (OutlierDetectionLoadBalancerConfig) priorityChildConfig.policySelection.getConfig(); + + // The outlier detection config should faithfully represent what came down from xDS. + assertThat(outlierDetectionConfig.intervalNanos).isEqualTo(outlierDetection.intervalNanos()); + assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( + outlierDetection.baseEjectionTimeNanos()); + assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( + outlierDetection.baseEjectionTimeNanos()); + assertThat(outlierDetectionConfig.maxEjectionTimeNanos).isEqualTo( + outlierDetection.maxEjectionTimeNanos()); + assertThat(outlierDetectionConfig.maxEjectionPercent).isEqualTo( + outlierDetection.maxEjectionPercent()); + + OutlierDetectionLoadBalancerConfig.SuccessRateEjection successRateEjection + = outlierDetectionConfig.successRateEjection; + assertThat(successRateEjection.stdevFactor).isEqualTo( + outlierDetection.successRateEjection().stdevFactor()); + assertThat(successRateEjection.enforcementPercentage).isEqualTo( + outlierDetection.successRateEjection().enforcementPercentage()); + assertThat(successRateEjection.minimumHosts).isEqualTo( + outlierDetection.successRateEjection().minimumHosts()); + assertThat(successRateEjection.requestVolume).isEqualTo( + outlierDetection.successRateEjection().requestVolume()); + + OutlierDetectionLoadBalancerConfig.FailurePercentageEjection failurePercentageEjection + = outlierDetectionConfig.failurePercentageEjection; + assertThat(failurePercentageEjection.threshold).isEqualTo( + outlierDetection.failurePercentageEjection().threshold()); + assertThat(failurePercentageEjection.enforcementPercentage).isEqualTo( + outlierDetection.failurePercentageEjection().enforcementPercentage()); + assertThat(failurePercentageEjection.minimumHosts).isEqualTo( + outlierDetection.failurePercentageEjection().minimumHosts()); + assertThat(failurePercentageEjection.requestVolume).isEqualTo( + outlierDetection.failurePercentageEjection().requestVolume()); + + // The wrapped configuration should not have been tampered with. + ClusterImplConfig clusterImplConfig = + (ClusterImplConfig) outlierDetectionConfig.childPolicy.getConfig(); + assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, + tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + WrrLocalityConfig wrrLocalityConfig = + (WrrLocalityConfig) clusterImplConfig.childPolicy.getConfig(); + assertThat(wrrLocalityConfig.childPolicy.getProvider().getPolicyName()).isEqualTo( + "least_request_experimental"); + + assertThat( + childBalancer.addresses.get(0).getAttributes() + .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); + } + + @Test public void onlyEdsClusters_receivedEndpoints() { ClusterResolverConfig config = new ClusterResolverConfig( @@ -337,9 +443,9 @@ public void onlyEdsClusters_receivedEndpoints() { LocalityLbEndpoints.create( Collections.singletonList(LbEndpoint.create(endpoint4, 100, true)), 20 /* localityWeight */, 2 /* priority */); - String priority1 = CLUSTER2 + "[priority1]"; - String priority2 = CLUSTER2 + "[priority2]"; - String priority3 = CLUSTER1 + "[priority1]"; + String priority1 = CLUSTER2 + "[child1]"; + String priority2 = CLUSTER2 + "[child2]"; + String priority3 = CLUSTER1 + "[child1]"; // CLUSTER2: locality1 with priority 1 and locality3 with priority 2. xdsClient.deliverClusterLoadAssignment( @@ -366,13 +472,12 @@ public void onlyEdsClusters_receivedEndpoints() { ClusterImplConfig clusterImplConfig1 = (ClusterImplConfig) priorityChildConfig1.policySelection.getConfig(); assertClusterImplConfig(clusterImplConfig1, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - WeightedTargetConfig weightedTargetConfig1 = - (WeightedTargetConfig) clusterImplConfig1.childPolicy.getConfig(); - assertThat(weightedTargetConfig1.targets.keySet()).containsExactly(locality1.toString()); - WeightedPolicySelection target1 = weightedTargetConfig1.targets.get(locality1.toString()); - assertThat(target1.weight).isEqualTo(70); - assertThat(target1.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); + tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + assertThat(clusterImplConfig1.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); + WrrLocalityConfig wrrLocalityConfig1 = + (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); + assertThat(wrrLocalityConfig1.childPolicy.getProvider().getPolicyName()).isEqualTo( + "round_robin"); PriorityChildConfig priorityChildConfig2 = priorityLbConfig.childConfigs.get(priority2); assertThat(priorityChildConfig2.ignoreReresolution).isTrue(); @@ -381,21 +486,12 @@ public void onlyEdsClusters_receivedEndpoints() { ClusterImplConfig clusterImplConfig2 = (ClusterImplConfig) priorityChildConfig2.policySelection.getConfig(); assertClusterImplConfig(clusterImplConfig2, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - WeightedTargetConfig weightedTargetConfig2 = - (WeightedTargetConfig) clusterImplConfig2.childPolicy.getConfig(); - assertThat(weightedTargetConfig2.targets.keySet()).containsExactly(locality3.toString()); - WeightedPolicySelection target2 = weightedTargetConfig2.targets.get(locality3.toString()); - assertThat(target2.weight).isEqualTo(20); - assertThat(target2.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); - List priorityAddrs1 = - AddressFilter.filter(childBalancer.addresses, priority1); - assertThat(priorityAddrs1).hasSize(2); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), priorityAddrs1); - List priorityAddrs2 = - AddressFilter.filter(childBalancer.addresses, priority2); - assertThat(priorityAddrs2).hasSize(1); - assertAddressesEqual(Collections.singletonList(endpoint4), priorityAddrs2); + tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + assertThat(clusterImplConfig2.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); + WrrLocalityConfig wrrLocalityConfig2 = + (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); + assertThat(wrrLocalityConfig2.childPolicy.getProvider().getPolicyName()).isEqualTo( + "round_robin"); PriorityChildConfig priorityChildConfig3 = priorityLbConfig.childConfigs.get(priority3); assertThat(priorityChildConfig3.ignoreReresolution).isTrue(); @@ -404,17 +500,102 @@ public void onlyEdsClusters_receivedEndpoints() { ClusterImplConfig clusterImplConfig3 = (ClusterImplConfig) priorityChildConfig3.policySelection.getConfig(); assertClusterImplConfig(clusterImplConfig3, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - WeightedTargetConfig weightedTargetConfig3 = - (WeightedTargetConfig) clusterImplConfig3.childPolicy.getConfig(); - assertThat(weightedTargetConfig3.targets.keySet()).containsExactly(locality2.toString()); - WeightedPolicySelection target3 = weightedTargetConfig3.targets.get(locality2.toString()); - assertThat(target3.weight).isEqualTo(10); - assertThat(target3.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); - List priorityAddrs3 = - AddressFilter.filter(childBalancer.addresses, priority3); - assertThat(priorityAddrs3).hasSize(1); - assertAddressesEqual(Collections.singletonList(endpoint3), priorityAddrs3); + tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + assertThat(clusterImplConfig3.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); + WrrLocalityConfig wrrLocalityConfig3 = + (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); + assertThat(wrrLocalityConfig3.childPolicy.getProvider().getPolicyName()).isEqualTo( + "round_robin"); + + for (EquivalentAddressGroup eag : childBalancer.addresses) { + if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality1) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) + .isEqualTo(70); + } + if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality2) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) + .isEqualTo(10); + } + if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality3) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) + .isEqualTo(20); + } + } + } + + @SuppressWarnings("unchecked") + private void verifyEdsPriorityNames(List want, + Map... updates) { + ClusterResolverConfig config = new ClusterResolverConfig( + Arrays.asList(edsDiscoveryMechanism2), roundRobin); + deliverLbConfig(config); + assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME2); + assertThat(childBalancers).isEmpty(); + + for (Map update: updates) { + xdsClient.deliverClusterLoadAssignment( + EDS_SERVICE_NAME2, + update); + } + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + assertThat(priorityLbConfig.priorities).isEqualTo(want); + } + + @Test + @SuppressWarnings("unchecked") + public void edsUpdatePriorityName_twoPriorities() { + verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child1]", CLUSTER2 + "[child2]"), + ImmutableMap.of(locality1, createEndpoints(1), + locality2, createEndpoints(2) + )); + } + + @Test + @SuppressWarnings("unchecked") + public void edsUpdatePriorityName_addOnePriority() { + verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]"), + ImmutableMap.of(locality1, createEndpoints(1)), + ImmutableMap.of(locality2, createEndpoints(1) + )); + } + + @Test + @SuppressWarnings("unchecked") + public void edsUpdatePriorityName_swapTwoPriorities() { + verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]", CLUSTER2 + "[child1]", + CLUSTER2 + "[child3]"), + ImmutableMap.of(locality1, createEndpoints(1), + locality2, createEndpoints(2), + locality3, createEndpoints(3) + ), + ImmutableMap.of(locality1, createEndpoints(2), + locality2, createEndpoints(1), + locality3, createEndpoints(3)) + ); + } + + @Test + @SuppressWarnings("unchecked") + public void edsUpdatePriorityName_mergeTwoPriorities() { + verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child3]", CLUSTER2 + "[child1]"), + ImmutableMap.of(locality1, createEndpoints(1), + locality3, createEndpoints(3), + locality2, createEndpoints(2)), + ImmutableMap.of(locality1, createEndpoints(2), + locality3, createEndpoints(1), + locality2, createEndpoints(1) + )); + } + + private LocalityLbEndpoints createEndpoints(int priority) { + return LocalityLbEndpoints.create( + Arrays.asList( + LbEndpoint.create(makeAddress("endpoint-addr-1"), 100, true), + LbEndpoint.create(makeAddress("endpoint-addr-2"), 100, true)), + 70 /* localityWeight */, priority /* priority */); } @Test @@ -510,19 +691,14 @@ public void handleEdsResource_ignoreLocalitiesWithNoHealthyEndpoints() { LocalityLbEndpoints.create( Collections.singletonList(LbEndpoint.create(endpoint2, 100, true /* isHealthy */)), 10 /* localityWeight */, 1 /* priority */); - String priority = CLUSTER1 + "[priority1]"; xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - PriorityChildConfig priorityChildConfig = priorityLbConfig.childConfigs.get(priority); - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - WeightedTargetConfig weightedTargetConfig = - (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality2.toString()); + for (EquivalentAddressGroup eag : childBalancer.addresses) { + assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY)).isEqualTo(locality2); + } } @Test @@ -540,7 +716,7 @@ public void handleEdsResource_ignorePrioritiesWithNoHealthyEndpoints() { LocalityLbEndpoints.create( Collections.singletonList(LbEndpoint.create(endpoint2, 200, true /* isHealthy */)), 10 /* localityWeight */, 2 /* priority */); - String priority2 = CLUSTER1 + "[priority2]"; + String priority2 = CLUSTER1 + "[child2]"; xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); @@ -610,7 +786,6 @@ public void onlyLogicalDnsCluster_handleRefreshNameResolution() { EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); assertThat(resolver.refreshCount).isEqualTo(0); - verify(helper).ignoreRefreshNameResolutionCheck(); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); childBalancer.helper.refreshNameResolution(); assertThat(resolver.refreshCount).isEqualTo(1); @@ -676,7 +851,6 @@ public void onlyLogicalDnsCluster_refreshNameResolutionRaceWithResolutionError() FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); assertThat(resolver.refreshCount).isEqualTo(0); - verify(helper).ignoreRefreshNameResolutionCheck(); childBalancer.helper.refreshNameResolution(); assertThat(resolver.refreshCount).isEqualTo(1); @@ -725,14 +899,14 @@ public void edsClustersAndLogicalDnsCluster_receivedEndpoints() { assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[priority1]", CLUSTER_DNS + "[priority0]").inOrder(); + .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]").inOrder(); assertAddressesEqual(Arrays.asList(endpoint3, endpoint1, endpoint2), childBalancer.addresses); // ordered by cluster then addresses assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER1 + "[priority1]"), locality1.toString()), + childBalancer.addresses, CLUSTER1 + "[child1]"), locality1.toString()), Collections.singletonList(endpoint3)); assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER_DNS + "[priority0]"), + childBalancer.addresses, CLUSTER_DNS + "[child0]"), Locality.create("", "", "").toString()), Arrays.asList(endpoint1, endpoint2)); } @@ -757,7 +931,7 @@ public void noEdsResourceExists_useDnsResolutionResults() { FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); String priority = Iterables.getOnlyElement( ((PriorityLbConfig) childBalancer.config).priorities); - assertThat(priority).isEqualTo(CLUSTER_DNS + "[priority0]"); + assertThat(priority).isEqualTo(CLUSTER_DNS + "[child0]"); assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); } @@ -781,7 +955,7 @@ public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturn assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[priority1]"); + .containsExactly(CLUSTER1 + "[child1]"); assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); assertThat(childBalancer.shutdown).isFalse(); xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); @@ -850,6 +1024,24 @@ public void resolutionErrorBeforeChildLbCreated_returnErrorPickerIfAllClustersEn null); } + @Test + public void resolutionErrorBeforeChildLbCreated_edsOnly_returnErrorPicker() { + ClusterResolverConfig config = new ClusterResolverConfig( + Arrays.asList(edsDiscoveryMechanism1), roundRobin); + deliverLbConfig(config); + assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); + assertThat(childBalancers).isEmpty(); + reset(helper); + xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("OOM")); + assertThat(childBalancers).isEmpty(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); + } + @Test public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { ClusterResolverConfig config = new ClusterResolverConfig( @@ -887,7 +1079,7 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[priority1]", CLUSTER_DNS + "[priority0]"); + .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]"); assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); @@ -898,7 +1090,7 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug } private void deliverLbConfig(ClusterResolverConfig config) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes( @@ -985,16 +1177,24 @@ public String toString() { } private static final class FakeXdsClient extends XdsClient { - private final Map watchers = new HashMap<>(); + private final Map> watchers = new HashMap<>(); + @Override - void watchEdsResource(String resourceName, EdsResourceWatcher watcher) { + @SuppressWarnings("unchecked") + void watchXdsResource(XdsResourceType type, String resourceName, + ResourceWatcher watcher) { + assertThat(type.typeName()).isEqualTo("EDS"); assertThat(watchers).doesNotContainKey(resourceName); - watchers.put(resourceName, watcher); + watchers.put(resourceName, (ResourceWatcher) watcher); } @Override - void cancelEdsResourceWatch(String resourceName, EdsResourceWatcher watcher) { + @SuppressWarnings("unchecked") + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { + assertThat(type.typeName()).isEqualTo("EDS"); assertThat(watchers).containsKey(resourceName); watchers.remove(resourceName); } @@ -1009,7 +1209,7 @@ void deliverClusterLoadAssignment(String resource, List dropOverlo Map localityLbEndpointsMap) { if (watchers.containsKey(resource)) { watchers.get(resource).onChanged( - new EdsUpdate(resource, localityLbEndpointsMap, dropOverloads)); + new XdsEndpointResource.EdsUpdate(resource, localityLbEndpointsMap, dropOverloads)); } } @@ -1020,7 +1220,7 @@ void deliverResourceNotFound(String resource) { } void deliverError(Status error) { - for (EdsResourceWatcher watcher : watchers.values()) { + for (ResourceWatcher watcher : watchers.values()) { watcher.onError(error); } } @@ -1133,9 +1333,10 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addresses = resolvedAddresses.getAddresses(); config = resolvedAddresses.getLoadBalancingPolicyConfig(); + return true; } @Override diff --git a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java b/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java index 73632c8addb..4319700b7f5 100644 --- a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java @@ -20,7 +20,7 @@ import com.google.common.collect.ImmutableMap; import io.grpc.internal.JsonParser; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.io.IOException; import java.util.HashMap; import java.util.Map; diff --git a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java new file mode 100644 index 00000000000..ea764e67e40 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java @@ -0,0 +1,301 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import com.google.protobuf.UInt32Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.listener.v3.ApiListener; +import io.envoyproxy.envoy.config.listener.v3.Filter; +import io.envoyproxy.envoy.config.listener.v3.FilterChain; +import io.envoyproxy.envoy.config.listener.v3.FilterChainMatch; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.NonForwardingAction; +import io.envoyproxy.envoy.config.route.v3.Route; +import io.envoyproxy.envoy.config.route.v3.RouteAction; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.config.route.v3.RouteMatch; +import io.envoyproxy.envoy.config.route.v3.VirtualHost; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.grpc.NameResolverRegistry; +import io.grpc.Server; +import io.grpc.netty.NettyServerBuilder; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; + +/** + * Starts a control plane server and sets up the test to use it. Initialized with a default + * configuration, but also provides methods for updating the configuration. + */ +public class ControlPlaneRule extends TestWatcher { + private static final Logger logger = Logger.getLogger(ControlPlaneRule.class.getName()); + + private static final String SCHEME = "test-xds"; + private static final String RDS_NAME = "route-config.googleapis.com"; + private static final String CLUSTER_NAME = "cluster0"; + private static final String EDS_NAME = "eds-service-0"; + private static final String SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT = + "grpc/server?udpa.resource.listening_address="; + private static final String SERVER_HOST_NAME = "test-server"; + private static final String HTTP_CONNECTION_MANAGER_TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" + + ".HttpConnectionManager"; + + private Server server; + private XdsTestControlPlaneService controlPlaneService; + private XdsNameResolverProvider nameResolverProvider; + + /** + * Returns the test control plane service interface. + */ + public XdsTestControlPlaneService getService() { + return controlPlaneService; + } + + /** + * Returns the server instance. + */ + public Server getServer() { + return server; + } + + @Override protected void starting(Description description) { + // Start the control plane server. + try { + controlPlaneService = new XdsTestControlPlaneService(); + NettyServerBuilder controlPlaneServerBuilder = NettyServerBuilder.forPort(0) + .addService(controlPlaneService); + server = controlPlaneServerBuilder.build().start(); + } catch (Exception e) { + throw new AssertionError("unable to start the control plane server", e); + } + + // Configure and register an xDS name resolver so that gRPC knows how to connect to the server. + nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME, + defaultBootstrapOverride()); + NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider); + } + + @Override protected void finished(Description description) { + if (server != null) { + server.shutdownNow(); + try { + if (!server.awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down control plane server", e); + } + } + NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); + } + + /** + * For test purpose, use boostrapOverride to programmatically provide bootstrap info. + */ + public Map defaultBootstrapOverride() { + return ImmutableMap.of( + "node", ImmutableMap.of( + "id", UUID.randomUUID().toString(), + "cluster", "cluster0"), + "xds_servers", Collections.singletonList( + + ImmutableMap.of( + "server_uri", "localhost:" + server.getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ) + ), + "server_listener_resource_name_template", SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT + ); + } + + void setLdsConfig(Listener serverListener, Listener clientListener) { + getService().setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT, serverListener, + SERVER_HOST_NAME, clientListener)); + } + + void setRdsConfig(RouteConfiguration routeConfiguration) { + getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(RDS_NAME, routeConfiguration)); + } + + void setCdsConfig(Cluster cluster) { + getService().setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + } + + void setEdsConfig(ClusterLoadAssignment clusterLoadAssignment) { + getService().setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(EDS_NAME, clusterLoadAssignment)); + } + + /** + * Builds a new default RDS configuration. + */ + static RouteConfiguration buildRouteConfiguration(String authority) { + io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = VirtualHost.newBuilder() + .addDomains(authority) + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute( + RouteAction.newBuilder().setCluster(CLUSTER_NAME).build()).build()).build(); + return RouteConfiguration.newBuilder().setName(RDS_NAME).addVirtualHosts(virtualHost).build(); + } + + /** + * Builds a new default CDS configuration. + */ + static Cluster buildCluster() { + return Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig( + Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_NAME) + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder().build()) + .build()) + .build()) + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .build(); + } + + /** + * Builds a new default EDS configuration. + */ + static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int port) { + Address address = Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder().setAddress(hostName).setPortValue(port).build()).build(); + LocalityLbEndpoints endpoints = LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(10)) + .setPriority(0) + .addLbEndpoints( + LbEndpoint.newBuilder() + .setEndpoint( + Endpoint.newBuilder().setAddress(address).build()) + .setHealthStatus(HealthStatus.HEALTHY) + .build()).build(); + return ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_NAME) + .addEndpoints(endpoints) + .build(); + } + + /** + * Builds a new client listener. + */ + static Listener buildClientListener(String name) { + HttpFilter httpFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + ApiListener apiListener = ApiListener.newBuilder().setApiListener(Any.pack( + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 + .HttpConnectionManager.newBuilder() + .setRds( + Rds.newBuilder() + .setRouteConfigName(RDS_NAME) + .setConfigSource( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance()))) + .addAllHttpFilters(Collections.singletonList(httpFilter)) + .build(), + HTTP_CONNECTION_MANAGER_TYPE_URL)).build(); + return Listener.newBuilder() + .setName(name) + .setApiListener(apiListener).build(); + } + + /** + * Builds a new server listener. + */ + static Listener buildServerListener() { + HttpFilter routerFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig( + Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + VirtualHost virtualHost = io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() + .setName("virtual-host-0") + .addDomains("*") + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setNonForwardingAction(NonForwardingAction.newBuilder().build()) + .build()).build(); + RouteConfiguration routeConfig = RouteConfiguration.newBuilder() + .addVirtualHosts(virtualHost) + .build(); + io.envoyproxy.envoy.config.listener.v3.Filter filter = Filter.newBuilder() + .setName("network-filter-0") + .setTypedConfig( + Any.pack( + HttpConnectionManager.newBuilder() + .setRouteConfig(routeConfig) + .addAllHttpFilters(Collections.singletonList(routerFilter)) + .build())).build(); + FilterChainMatch filterChainMatch = FilterChainMatch.newBuilder() + .setSourceType(FilterChainMatch.ConnectionSourceType.ANY) + .build(); + FilterChain filterChain = FilterChain.newBuilder() + .setName("filter-chain-0") + .setFilterChainMatch(filterChainMatch) + .addFilters(filter) + .build(); + return Listener.newBuilder() + .setName(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT) + .setTrafficDirection(TrafficDirection.INBOUND) + .addFilterChains(filterChain) + .build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index 0d929939109..6892324a9bc 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -17,10 +17,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -49,13 +45,13 @@ import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.XdsClient.ResourceMetadata; import io.grpc.xds.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; -import java.util.EnumMap; +import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -77,10 +73,14 @@ public class CsdsServiceTest { EnvoyProtoData.Node.newBuilder().setId(NODE_ID).build(); private static final BootstrapInfo BOOTSTRAP_INFO = BootstrapInfo.builder() .servers(ImmutableList.of( - ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), true))) + ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()))) .node(BOOTSTRAP_NODE) .build(); - private static final XdsClient XDS_CLIENT_NO_RESOURCES = new FakeXdsClient(); + private static final FakeXdsClient XDS_CLIENT_NO_RESOURCES = new FakeXdsClient(); + private static final XdsResourceType LDS = XdsListenerResource.getInstance(); + private static final XdsResourceType CDS = XdsClusterResource.getInstance(); + private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); + private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); @RunWith(JUnit4.class) public static class ServiceTests { @@ -126,7 +126,7 @@ public void fetchClientConfig_invalidArgument() { public void fetchClientConfig_unexpectedException() { XdsClient throwingXdsClient = new FakeXdsClient() { @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.immediateFailedFuture( new IllegalArgumentException("IllegalArgumentException")); @@ -150,12 +150,12 @@ public void fetchClientConfig_unexpectedException() { public void fetchClientConfig_interruptedException() { XdsClient throwingXdsClient = new FakeXdsClient() { @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.submit( - new Callable>>() { + new Callable, Map>>() { @Override - public Map> call() { + public Map, Map> call() { Thread.currentThread().interrupt(); return null; } @@ -264,7 +264,7 @@ private void verifyResponse(ClientStatusResponse response) { assertThat(response.getConfigCount()).isEqualTo(1); ClientConfig clientConfig = response.getConfig(0); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(clientConfig); + verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); } private void verifyRequestInvalidResponseStatus(Status status) { @@ -321,18 +321,29 @@ public void metadataStatusToClientStatus() { @Test public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() throws InterruptedException { - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(new FakeXdsClient() { + FakeXdsClient fakeXdsClient = new FakeXdsClient() { @Override - protected Map> + protected Map, Map> getSubscribedResourcesMetadata() { - return new ImmutableMap.Builder>() + return new ImmutableMap.Builder, Map>() .put(LDS, ImmutableMap.of("subscribedResourceName.LDS", METADATA_ACKED_LDS)) .put(RDS, ImmutableMap.of("subscribedResourceName.RDS", METADATA_ACKED_RDS)) .put(CDS, ImmutableMap.of("subscribedResourceName.CDS", METADATA_ACKED_CDS)) .put(EDS, ImmutableMap.of("subscribedResourceName.EDS", METADATA_ACKED_EDS)) - .build(); + .buildOrThrow(); } - }); + + @Override + public Map> getSubscribedResourceTypesWithTypeUrl() { + return ImmutableMap.of( + LDS.typeUrl(), LDS, + RDS.typeUrl(), RDS, + CDS.typeUrl(), CDS, + EDS.typeUrl(), EDS + ); + } + }; + ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient); verifyClientConfigNode(clientConfig); @@ -340,7 +351,8 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() // is propagated to the correct resource types. int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(4); - EnumMap configDumps = mapConfigDumps(clientConfig); + Map, GenericXdsConfig> configDumps = mapConfigDumps(fakeXdsClient, + clientConfig); assertThat(configDumps.keySet()).containsExactly(LDS, RDS, CDS, EDS); // LDS. @@ -373,7 +385,7 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() public void getClientConfigForXdsClient_noSubscribedResources() throws InterruptedException { ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(clientConfig); + verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); } } @@ -381,10 +393,11 @@ public void getClientConfigForXdsClient_noSubscribedResources() throws Interrupt * Assuming {@link MetadataToProtoTests} passes, and metadata converted to corresponding * config dumps correctly, perform a minimal verification of the general shape of ClientConfig. */ - private static void verifyClientConfigNoResources(ClientConfig clientConfig) { + private static void verifyClientConfigNoResources(FakeXdsClient xdsClient, + ClientConfig clientConfig) { int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(0); - EnumMap configDumps = mapConfigDumps(clientConfig); + Map, GenericXdsConfig> configDumps = mapConfigDumps(xdsClient, clientConfig); assertThat(configDumps).isEmpty(); } @@ -398,25 +411,28 @@ private static void verifyClientConfigNode(ClientConfig clientConfig) { assertThat(node).isEqualTo(BOOTSTRAP_NODE.toEnvoyProtoNode()); } - private static EnumMap mapConfigDumps(ClientConfig config) { - EnumMap xdsConfigMap = new EnumMap<>(ResourceType.class); + private static Map, GenericXdsConfig> mapConfigDumps(FakeXdsClient client, + ClientConfig config) { + Map, GenericXdsConfig> xdsConfigMap = new HashMap<>(); List xdsConfigList = config.getGenericXdsConfigsList(); for (GenericXdsConfig genericXdsConfig : xdsConfigList) { - ResourceType type = ResourceType.fromTypeUrl(genericXdsConfig.getTypeUrl()); - assertThat(type).isNotEqualTo(ResourceType.UNKNOWN); + XdsResourceType type = client.getSubscribedResourceTypesWithTypeUrl() + .get(genericXdsConfig.getTypeUrl()); + assertThat(type).isNotNull(); assertThat(xdsConfigMap).doesNotContainKey(type); xdsConfigMap.put(type, genericXdsConfig); } return xdsConfigMap; } - private static class FakeXdsClient extends XdsClient { - protected Map> getSubscribedResourcesMetadata() { + private static class FakeXdsClient extends XdsClient implements XdsClient.ResourceStore { + protected Map, Map> + getSubscribedResourcesMetadata() { return ImmutableMap.of(); } @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.immediateFuture(getSubscribedResourcesMetadata()); } @@ -425,6 +441,18 @@ protected Map> getSubscribedResource BootstrapInfo getBootstrapInfo() { return BOOTSTRAP_INFO; } + + @Nullable + @Override + public Collection getSubscribedResources(ServerInfo serverInfo, + XdsResourceType type) { + return null; + } + + @Override + public Map> getSubscribedResourceTypesWithTypeUrl() { + return ImmutableMap.of(); + } } private static class FakeXdsClientPoolFactory implements XdsClientPoolFactory { diff --git a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java new file mode 100644 index 00000000000..faa79444071 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.net.InetSocketAddress; +import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; + +/** + * This rule creates a new server instance in the "data plane" that is configured by a "control + * plane" xDS server. + */ +public class DataPlaneRule extends TestWatcher { + private static final Logger logger = Logger.getLogger(DataPlaneRule.class.getName()); + + private static final String SERVER_HOST_NAME = "test-server"; + private static final String SCHEME = "test-xds"; + + private final ControlPlaneRule controlPlane; + private Server server; + private HashSet channels = new HashSet<>(); + + /** + * Creates a new {@link DataPlaneRule} that is connected to the given {@link ControlPlaneRule}. + */ + public DataPlaneRule(ControlPlaneRule controlPlane) { + this.controlPlane = controlPlane; + } + + /** + * Returns the server instance. + */ + public Server getServer() { + return server; + } + + /** + * Returns a newly created {@link ManagedChannel} to the server. + */ + public ManagedChannel getManagedChannel() { + ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + SERVER_HOST_NAME, + InsecureChannelCredentials.create()).build(); + channels.add(channel); + return channel; + } + + @Override + protected void starting(Description description) { + // Let the control plane know about our new server. + controlPlane.setLdsConfig(ControlPlaneRule.buildServerListener(), + ControlPlaneRule.buildClientListener(SERVER_HOST_NAME) + ); + + // Start up the server. + try { + startServer(controlPlane.defaultBootstrapOverride()); + } catch (Exception e) { + throw new AssertionError("unable to start the data plane server", e); + } + + // Provide the rest of the configuration to the control plane. + controlPlane.setRdsConfig(ControlPlaneRule.buildRouteConfiguration(SERVER_HOST_NAME)); + controlPlane.setCdsConfig(ControlPlaneRule.buildCluster()); + InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); + controlPlane.setEdsConfig( + ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), + edsInetSocketAddress.getPort())); + } + + @Override + protected void finished(Description description) { + if (server != null) { + // Shut down any lingering open channels to the server. + for (ManagedChannel channel : channels) { + if (!channel.isShutdown()) { + channel.shutdownNow(); + } + } + + // Shut down the server itself. + server.shutdownNow(); + try { + if (!server.awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down data plane server", e); + } + } + } + + private void startServer(Map bootstrapOverride) throws Exception { + ServerInterceptor metadataInterceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata requestHeaders, ServerCallHandler next) { + logger.fine("Received following metadata: " + requestHeaders); + + // Make a copy of the headers so that it can be read in a thread-safe manner when copying + // it to the response headers. + Metadata headersToReturn = new Metadata(); + headersToReturn.merge(requestHeaders); + + return next.startCall(new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata responseHeaders) { + responseHeaders.merge(headersToReturn); + super.sendHeaders(responseHeaders); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + } + }, requestHeaders); + } + }; + + SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = + new SimpleServiceGrpc.SimpleServiceImplBase() { + @Override + public void unaryRpc( + SimpleRequest request, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder().setResponseMessage("Hi, xDS!").build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + }; + + XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( + 0, InsecureServerCredentials.create()) + .addService(simpleServiceImpl) + .intercept(metadataInterceptor) + .overrideBootstrapForTest(bootstrapOverride); + server = serverBuilder.build().start(); + logger.log(Level.FINE, "data plane server started"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/EnvoyProtoDataTest.java b/xds/src/test/java/io/grpc/xds/EnvoyProtoDataTest.java index 1e372e5974e..046f4abcec3 100644 --- a/xds/src/test/java/io/grpc/xds/EnvoyProtoDataTest.java +++ b/xds/src/test/java/io/grpc/xds/EnvoyProtoDataTest.java @@ -86,41 +86,6 @@ public void convertNode() { .addClientFeatures("feature-2") .build(); assertThat(node.toEnvoyProtoNode()).isEqualTo(nodeProto); - - @SuppressWarnings("deprecation") // Deprecated v2 API setBuildVersion(). - io.envoyproxy.envoy.api.v2.core.Node nodeProtoV2 = - io.envoyproxy.envoy.api.v2.core.Node.newBuilder() - .setId("node-id") - .setCluster("cluster") - .setMetadata(Struct.newBuilder() - .putFields("TRAFFICDIRECTOR_INTERCEPTION_PORT", - Value.newBuilder().setStringValue("ENVOY_PORT").build()) - .putFields("TRAFFICDIRECTOR_NETWORK_NAME", - Value.newBuilder().setStringValue("VPC_NETWORK_NAME").build())) - .setLocality( - io.envoyproxy.envoy.api.v2.core.Locality.newBuilder() - .setRegion("region") - .setZone("zone") - .setSubZone("subzone")) - .addListeningAddresses( - io.envoyproxy.envoy.api.v2.core.Address.newBuilder() - .setSocketAddress( - io.envoyproxy.envoy.api.v2.core.SocketAddress.newBuilder() - .setAddress("www.foo.com") - .setPortValue(8080))) - .addListeningAddresses( - io.envoyproxy.envoy.api.v2.core.Address.newBuilder() - .setSocketAddress( - io.envoyproxy.envoy.api.v2.core.SocketAddress.newBuilder() - .setAddress("www.bar.com") - .setPortValue(8088))) - .setBuildVersion("v1") - .setUserAgentName("agent") - .setUserAgentVersion("1.1") - .addClientFeatures("feature-1") - .addClientFeatures("feature-2") - .build(); - assertThat(node.toEnvoyProtoNodeV2()).isEqualTo(nodeProtoV2); } @Test diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java new file mode 100644 index 00000000000..0c3cf61b28e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; + +import com.github.xds.type.v3.TypedStruct; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.LoadBalancerRegistry; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.RuleChain; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Xds integration tests using a local control plane, implemented in {@link + * XdsTestControlPlaneService}. Test cases can inject xds configs to the control plane for testing. + * + *

    Test components: + * 1) A Control Plane {@link XdsTestControlPlaneService} accepts xds requests from multiple clients + * from the Data Plane, see {@link ControlPlaneRule}. + * 2) A test xDS server {@link XdsServerWrapper}, see {@link DataPlaneRule}. + * 3) A test xDS client that uses a testing scheme {@link XdsNameResolverProvider#createForTest}, + * see {@link DataPlaneRule}. + * + *

    The configuration dependency and ephemeral port allocation requires the components to + * be initialized in a certain order: + * 1) Start the Control Plane server {@link XdsTestControlPlaneService}. After start the bootstrap + * information (w/ Control Plane's address) can be constructed for the Data Plane to initialize. + * 2) Set LDS and RDS config at the Control Plane. Get the bootstrap file from the Control + * Plane from 1). And then start the test xDS server (requires LDS/RDS and bootstrap file to start). + * 3) Construct EDS config w/ test server address from 2). Set CDS and EDS Config at the Control + * Plane. Then start the test xDS client (requires EDS to do xDS name resolution). + */ +@RunWith(JUnit4.class) +public class FakeControlPlaneXdsIntegrationTest { + + public ControlPlaneRule controlPlane; + public DataPlaneRule dataPlane; + + /** + * The {@link ControlPlaneRule} should run before the {@link DataPlaneRule}. + */ + @Rule + public RuleChain ruleChain() { + controlPlane = new ControlPlaneRule(); + dataPlane = new DataPlaneRule(controlPlane); + return RuleChain.outerRule(controlPlane).around(dataPlane); + } + + @Test + public void pingPong() throws Exception { + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } + + @Test + public void pingPong_metadataLoadBalancer() throws Exception { + MetadataLoadBalancerProvider metadataLbProvider = new MetadataLoadBalancerProvider(); + try { + LoadBalancerRegistry.getDefaultRegistry().register(metadataLbProvider); + + // Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls. + Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer") + .setValue(Struct.newBuilder() + .putFields("metadataKey", Value.newBuilder().setStringValue("foo").build()) + .putFields("metadataValue", Value.newBuilder().setStringValue("bar").build())) + .build()))).build(); + Policy wrrLocalityPolicy = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig( + Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))) + .build(); + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder() + .addPolicies(wrrLocalityPolicy)).build()); + + ResponseHeaderClientInterceptor responseHeaderInterceptor + = new ResponseHeaderClientInterceptor(); + + // We add an interceptor to catch the response headers from the server. + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + dataPlane.getManagedChannel()).withInterceptors(responseHeaderInterceptor); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + + // Make sure we got back the header we configured the LB with. + assertThat(responseHeaderInterceptor.reponseHeaders.get( + Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar"); + } finally { + LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLbProvider); + } + } + + // Captures response headers from the server. + private static class ResponseHeaderClientInterceptor implements ClientInterceptor { + Metadata reponseHeaders; + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + super.start(new ForwardingClientCallListener() { + @Override + protected ClientCall.Listener delegate() { + return responseListener; + } + + @Override + public void onHeaders(Metadata headers) { + reponseHeaders = headers; + } + }, headers); + } + }; + } + } + + /** + * Basic test to make sure RING_HASH configuration works. + */ + @Test + public void pingPong_ringHash() { + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder() + .setLbPolicy(LbPolicy.RING_HASH).build()); + + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } +} diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index a4efb226ce9..685102477cc 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -18,7 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; -import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -40,8 +40,8 @@ import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -64,7 +64,6 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; - import org.junit.After; import org.junit.Rule; import org.junit.Test; diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java index 2e8519b150d..51ec8adcc24 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java @@ -96,9 +96,9 @@ public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'choiceCount'"); + .isEqualTo("Invalid 'choiceCount' in least_request_experimental config"); } @Test @@ -107,9 +107,9 @@ public void parseLoadBalancingConfig_invalid_tooSmallSize() throws IOException { ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'choiceCount'"); + .isEqualTo("Invalid 'choiceCount' in least_request_experimental config"); } @Test diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 2d09dbfe1fc..e7a3a28e6aa 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -28,6 +28,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; @@ -154,8 +155,9 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -205,9 +207,10 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -227,8 +230,9 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); verify(newSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); @@ -246,25 +250,16 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - // test going from non-empty to empty - loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes(affinity) - .build()); - - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); - verifyNoMoreInteractions(mockHelper); } @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); Ref subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); @@ -304,9 +299,10 @@ public void pickAfterConfigChange() { final LeastRequestConfig oldConfig = new LeastRequestConfig(4); final LeastRequestConfig newConfig = new LeastRequestConfig(6); final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) .setLoadBalancingPolicyConfig(oldConfig).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(2)) @@ -316,9 +312,10 @@ public void pickAfterConfigChange() { pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(oldConfig.choiceCount)).nextInt(1); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) .setLoadBalancingPolicyConfig(newConfig).build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); @@ -331,9 +328,10 @@ public void pickAfterConfigChange() { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); @@ -350,9 +348,10 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -388,9 +387,10 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -418,10 +418,11 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { public void pickerLeastRequest() throws Exception { int choiceCount = 2; // This should add inFlight counters to all subchannels. - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) .build()); + assertThat(addressesAccepted).isTrue(); assertEquals(3, loadBalancer.getSubchannels().size()); @@ -504,10 +505,11 @@ public void nameResolutionErrorWithNoChannels() throws Exception { public void nameResolutionErrorWithActiveChannels() throws Exception { int choiceCount = 8; final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) .setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -537,9 +539,10 @@ public void subchannelStateIsolation() throws Exception { Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -575,10 +578,14 @@ public void subchannelStateIsolation() throws Exception { assertThat(pickers.hasNext()).isFalse(); } - @Test(expected = IllegalArgumentException.class) + @Test public void readyPicker_emptyList() { - // ready picker list must be non-empty - new ReadyPicker(Collections.emptyList(), 2, mockRandom); + try { + // ready picker list must be non-empty + new ReadyPicker(Collections.emptyList(), 2, mockRandom); + fail(); + } catch (IllegalArgumentException expected) { + } } @Test @@ -608,6 +615,15 @@ public void internalPickerComparisons() { assertFalse(ready5.isEquivalentTo(ready6)); } + @Test + public void emptyAddresses() { + assertThat(loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(affinity) + .build())).isFalse(); + } + private static List getList(SubchannelPicker picker) { return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : Collections.emptyList(); diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java new file mode 100644 index 00000000000..c7217cb45e3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -0,0 +1,364 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.github.xds.type.v3.TypedStruct; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; +import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; +import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import java.util.List; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit test for {@link LoadBalancerConfigFactory}. + */ +@RunWith(JUnit4.class) +public class LoadBalancerConfigFactoryTest { + + private static final Policy ROUND_ROBIN_POLICY = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig( + Any.pack(RoundRobin.getDefaultInstance()))).build(); + + private static final long RING_HASH_MIN_RING_SIZE = 1; + private static final long RING_HASH_MAX_RING_SIZE = 2; + private static final Policy RING_HASH_POLICY = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + RingHash.newBuilder().setMinimumRingSize(UInt64Value.of(RING_HASH_MIN_RING_SIZE)) + .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) + .setHashFunction(RingHash.HashFunction.XX_HASH).build()))).build(); + + private static final int LEAST_REQUEST_CHOICE_COUNT = 10; + private static final Policy LEAST_REQUEST_POLICY = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + LeastRequest.newBuilder().setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT)) + .build()))).build(); + + private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; + private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; + private static final double CUSTOM_POLICY_FIELD_VALUE = 2; + private static final Policy CUSTOM_POLICY = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + TypedStruct.newBuilder().setTypeUrl( + "type.googleapis.com/" + CUSTOM_POLICY_NAME).setValue( + Struct.newBuilder().putFields(CUSTOM_POLICY_FIELD_KEY, + Value.newBuilder().setNumberValue(CUSTOM_POLICY_FIELD_VALUE).build())) + .build()))).build(); + private static final Policy CUSTOM_POLICY_UDPA = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + com.github.udpa.udpa.type.v1.TypedStruct.newBuilder().setTypeUrl( + "type.googleapis.com/" + CUSTOM_POLICY_NAME).setValue( + Struct.newBuilder().putFields(CUSTOM_POLICY_FIELD_KEY, + Value.newBuilder().setNumberValue(CUSTOM_POLICY_FIELD_VALUE).build())) + .build()))).build(); + private static final FakeCustomLoadBalancerProvider CUSTOM_POLICY_PROVIDER + = new FakeCustomLoadBalancerProvider(); + + private static final LbConfig VALID_ROUND_ROBIN_CONFIG = new LbConfig("wrr_locality_experimental", + ImmutableMap.of("childPolicy", + ImmutableList.of(ImmutableMap.of("round_robin", ImmutableMap.of())))); + private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", + ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", + (double) RING_HASH_MAX_RING_SIZE)); + private static final LbConfig VALID_CUSTOM_CONFIG = new LbConfig(CUSTOM_POLICY_NAME, + ImmutableMap.of(CUSTOM_POLICY_FIELD_KEY, CUSTOM_POLICY_FIELD_VALUE)); + private static final LbConfig VALID_CUSTOM_CONFIG_IN_WRR = new LbConfig( + "wrr_locality_experimental", ImmutableMap.of("childPolicy", ImmutableList.of( + ImmutableMap.of(VALID_CUSTOM_CONFIG.getPolicyName(), + VALID_CUSTOM_CONFIG.getRawConfigValue())))); + private static final LbConfig VALID_LEAST_REQUEST_CONFIG = new LbConfig( + "least_request_experimental", + ImmutableMap.of("choiceCount", (double) LEAST_REQUEST_CHOICE_COUNT)); + + @After + public void deregisterCustomProvider() { + LoadBalancerRegistry.getDefaultRegistry().deregister(CUSTOM_POLICY_PROVIDER); + } + + @Test + public void roundRobin() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(ROUND_ROBIN_POLICY)); + + assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + } + + @Test + public void roundRobin_legacy() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.ROUND_ROBIN).build(); + + assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + } + + @Test + public void ringHash() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY)) + .build(); + + assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + } + + @Test + public void ringHash_legacy() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig( + RingHashLbConfig.newBuilder().setMinimumRingSize(UInt64Value.of(RING_HASH_MIN_RING_SIZE)) + .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) + .setHashFunction(HashFunction.XX_HASH)).build(); + + assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + } + + @Test + public void ringHash_invalidHash() { + Cluster cluster = newCluster( + Policy.newBuilder().setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack( + RingHash.newBuilder().setMinimumRingSize(UInt64Value.of(RING_HASH_MIN_RING_SIZE)) + .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) + .setHashFunction(RingHash.HashFunction.MURMUR_HASH_2).build()))).build()); + + assertResourceInvalidExceptionThrown(cluster, true, true, "Invalid ring hash function"); + } + + @Test + public void ringHash_invalidHash_legacy() { + Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig( + RingHashLbConfig.newBuilder().setHashFunction(HashFunction.MURMUR_HASH_2)).build(); + + assertResourceInvalidExceptionThrown(cluster, true, true, "invalid ring hash function"); + } + + @Test + public void leastRequest() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(LEAST_REQUEST_POLICY)) + .build(); + + assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); + } + + @Test + public void leastRequest_legacy() throws ResourceInvalidException { + System.setProperty("io.grpc.xds.experimentalEnableLeastRequest", "true"); + + Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST) + .setLeastRequestLbConfig( + LeastRequestLbConfig.newBuilder() + .setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))).build(); + + LbConfig lbConfig = newLbConfig(cluster, true, true); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental"); + assertThat( + JsonUtil.getNumberAsLong(childConfigs.get(0).getRawConfigValue(), "choiceCount")).isEqualTo( + LEAST_REQUEST_CHOICE_COUNT); + } + + @Test + public void leastRequest_notEnabled() { + System.setProperty("io.grpc.xds.experimentalEnableLeastRequest", "false"); + + Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST).build(); + + assertResourceInvalidExceptionThrown(cluster, false, true, "unsupported lb policy"); + } + + @Test + public void customRootLb_providerRegistered() throws ResourceInvalidException { + LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); + + assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, true)).isEqualTo(VALID_CUSTOM_CONFIG); + } + + @Test + public void customRootLb_providerNotRegistered() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(CUSTOM_POLICY)) + .build(); + + assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); + } + + // When a provider for the endpoint picking custom policy is available, the configuration should + // use it. + @Test + public void customLbInWrr_providerRegistered() throws ResourceInvalidException { + LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); + + Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); + + assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + } + + // When a provider for the endpoint picking custom policy is available, the configuration should + // use it. This one uses the legacy UDPA TypedStruct that is also supported. + @Test + public void customLbInWrr_providerRegistered_udpa() throws ResourceInvalidException { + LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); + + Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(buildWrrPolicy(CUSTOM_POLICY_UDPA, ROUND_ROBIN_POLICY))).build(); + + assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + } + + // When a provider for the custom wrr_locality child policy is NOT available, we should fall back + // to the round_robin that is also provided. + @Test + public void customLbInWrr_providerNotRegistered() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); + + assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + } + + // When a provider for the custom wrr_locality child policy is NOT available and no alternative + // child policy is provided, the configuration is invalid. + @Test + public void customLbInWrr_providerNotRegistered_noFallback() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(buildWrrPolicy(CUSTOM_POLICY))).build(); + + assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); + } + + @Test + public void customConfig_notEnabled() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY)) + .build(); + + // Custom LB flag not set, so we use old logic that will default to round_robin. + assertThat(newLbConfig(cluster, true, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + } + + @Test + public void maxRecursion() { + Cluster cluster = Cluster.newBuilder() + .setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies( + buildWrrPolicy( // Wheee... + buildWrrPolicy( // ...eee... + buildWrrPolicy( // ...eee! + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + buildWrrPolicy( + ROUND_ROBIN_POLICY))))))))))))))))))).build(); + + assertResourceInvalidExceptionThrown(cluster, false, true, + "Maximum LB config recursion depth reached"); + } + + private Cluster newCluster(Policy... policies) { + return Cluster.newBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder().addAllPolicies(Lists.newArrayList(policies))).build(); + } + + private static Policy buildWrrPolicy(Policy... childPolicy) { + return Policy.newBuilder().setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( + LoadBalancingPolicy.newBuilder().addAllPolicies(Lists.newArrayList(childPolicy))) + .build()))).build(); + } + + private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest, + boolean enableCustomConfig) + throws ResourceInvalidException { + return ServiceConfigUtil.unwrapLoadBalancingConfig( + LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, enableCustomConfig)); + } + + private void assertResourceInvalidExceptionThrown(Cluster cluster, boolean enableLeastRequest, + boolean enableCustomConfig, String expectedMessage) { + try { + newLbConfig(cluster, enableLeastRequest, enableCustomConfig); + } catch (ResourceInvalidException e) { + assertThat(e).hasMessageThat().contains(expectedMessage); + return; + } + fail("ResourceInvalidException not thrown"); + } + + private static class FakeCustomLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return null; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return CUSTOM_POLICY_NAME; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index 53952f89478..13944915b6b 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -33,12 +33,12 @@ import com.google.protobuf.Struct; import com.google.protobuf.Value; import com.google.protobuf.util.Durations; -import io.envoyproxy.envoy.api.v2.core.Node; -import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats; -import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats; -import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse; +import io.envoyproxy.envoy.config.core.v3.Node; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterStats; +import io.envoyproxy.envoy.config.endpoint.v3.UpstreamLocalityStats; +import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.ManagedChannel; @@ -71,7 +71,6 @@ /** * Unit tests for {@link LoadReportClient}. */ -// TODO(chengyuanzhang): missing LRS V3 test. @RunWith(JUnit4.class) public class LoadReportClientTest { // bootstrap node identifier @@ -172,7 +171,7 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, channel, Context.ROOT, false, NODE, + lrsClient = new LoadReportClient(loadStatsManager, channel, Context.ROOT, NODE, syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @@ -434,8 +433,15 @@ public void lrsStreamClosedAndRetried() { // Then breaks the RPC responseObserver.onError(Status.UNAVAILABLE.asException()); - // Will reset the retry sequence and retry immediately, because balancer has responded. + // Will reset the retry sequence retry after a delay. We want to always delay, to restrict any + // accidental closed loop of retries to 1 QPS. inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + // Fast-forward to a moment before the retry of backoff sequence 2 (2s) + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(2) - 1); + verifyNoMoreInteractions(mockLoadReportingService); + // Then time for retry + fakeClock.forwardNanos(1); inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture()); responseObserver = lrsResponseObserverCaptor.getValue(); assertThat(lrsRequestObservers).hasSize(1); @@ -446,12 +452,12 @@ public void lrsStreamClosedAndRetried() { fakeClock.forwardNanos(4); responseObserver.onError(Status.UNAVAILABLE.asException()); - // Will be on the first retry (2s) of backoff sequence 2. + // Will be on the second retry (20s) of backoff sequence 2. inOrder.verify(backoffPolicy2).nextBackoffNanos(); assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER)); // Fast-forward to a moment before the retry, the time spent in the last try is deducted. - fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(2) - 4 - 1); + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(20) - 4 - 1); verifyNoMoreInteractions(mockLoadReportingService); // Then time for retry fakeClock.forwardNanos(1); @@ -471,7 +477,8 @@ public void lrsStreamClosedAndRetried() { ClusterStats clusterStats = Iterables.getOnlyElement(request.getClusterStatsList()); assertThat(clusterStats.getClusterName()).isEqualTo(CLUSTER1); assertThat(clusterStats.getClusterServiceName()).isEqualTo(EDS_SERVICE_NAME1); - assertThat(Durations.toSeconds(clusterStats.getLoadReportInterval())).isEqualTo(1L + 10L + 2L); + assertThat(Durations.toSeconds(clusterStats.getLoadReportInterval())) + .isEqualTo(1L + 10L + 2L + 20L); assertThat(Iterables.getOnlyElement(clusterStats.getDroppedRequestsList()).getCategory()) .isEqualTo("lb"); assertThat(Iterables.getOnlyElement(clusterStats.getDroppedRequestsList()).getDroppedCount()) @@ -490,7 +497,7 @@ public void lrsStreamClosedAndRetried() { // Wrapping up verify(backoffPolicyProvider, times(2)).get(); verify(backoffPolicy1, times(2)).nextBackoffNanos(); - verify(backoffPolicy2, times(1)).nextBackoffNanos(); + verify(backoffPolicy2, times(2)).nextBackoffNanos(); } @Test diff --git a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java new file mode 100644 index 00000000000..ecc0112a2e0 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Metadata; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import java.util.Map; +import javax.annotation.Nonnull; + +/** + * A custom LB for testing purposes that simply delegates to round_robin and adds a metadata entry + * to each request. + */ +public class MetadataLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public NameResolver.ConfigOrError parseLoadBalancingPolicyConfig( + Map rawLoadBalancingPolicyConfig) { + String metadataKey = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataKey"); + if (metadataKey == null) { + return NameResolver.ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription("no 'metadataKey' defined")); + } + + String metadataValue = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataValue"); + if (metadataValue == null) { + return NameResolver.ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription("no 'metadataValue' defined")); + } + + return NameResolver.ConfigOrError.fromConfig( + new MetadataLoadBalancerConfig(metadataKey, metadataValue)); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + MetadataHelper metadataHelper = new MetadataHelper(helper); + return new MetadataLoadBalancer(metadataHelper, + LoadBalancerRegistry.getDefaultRegistry().getProvider("round_robin") + .newLoadBalancer(metadataHelper)); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "test.MetadataLoadBalancer"; + } + + static class MetadataLoadBalancerConfig { + + final String metadataKey; + final String metadataValue; + + MetadataLoadBalancerConfig(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + } + + static class MetadataLoadBalancer extends ForwardingLoadBalancer { + + private final MetadataHelper helper; + private final LoadBalancer delegateLb; + + MetadataLoadBalancer(MetadataHelper helper, LoadBalancer delegateLb) { + this.helper = helper; + this.delegateLb = delegateLb; + } + + @Override + protected LoadBalancer delegate() { + return delegateLb; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + MetadataLoadBalancerConfig config + = (MetadataLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + helper.setMetadata(config.metadataKey, config.metadataValue); + delegateLb.handleResolvedAddresses(resolvedAddresses); + } + } + + /** + * Wraps the picker that is provided when the balancing change updates with the {@link + * MetadataPicker} that injects the metadata entry. + */ + static class MetadataHelper extends ForwardingLoadBalancerHelper { + + private final Helper delegateHelper; + private String metadataKey; + private String metadataValue; + + MetadataHelper(Helper delegateHelper) { + this.delegateHelper = delegateHelper; + } + + void setMetadata(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + protected Helper delegate() { + return delegateHelper; + } + + @Override + public void updateBalancingState(@Nonnull ConnectivityState newState, + @Nonnull SubchannelPicker newPicker) { + delegateHelper.updateBalancingState(newState, + new MetadataPicker(newPicker, metadataKey, metadataValue)); + } + } + + /** + * Includes the rpc-behavior metadata entry on each subchannel pick. + */ + static class MetadataPicker extends SubchannelPicker { + + private final SubchannelPicker delegatePicker; + private final String metadataKey; + private final String metadataValue; + + MetadataPicker(SubchannelPicker delegatePicker, String metadataKey, String metadataValue) { + this.delegatePicker = delegatePicker; + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getHeaders() + .put(Metadata.Key.of(metadataKey, Metadata.ASCII_STRING_MARSHALLER), metadataValue); + return delegatePicker.pickSubchannel(args); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java index 420e92cf9cd..a005f40fad7 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java @@ -22,7 +22,9 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doReturn; @@ -399,6 +401,137 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { verify(balancer3).shutdown(); } + @Test + public void idleToConnectingDoesNotTriggerFailOver() { + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + Helper helper0 = Iterables.getOnlyElement(fooHelpers); + + // p0 gets IDLE. + helper0.updateBalancingState( + IDLE, + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); + + // p0 goes to CONNECTING + helper0.updateBalancingState( + IDLE, + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); + + // no failover happened + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + } + + @Test + public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + Helper helper0 = Iterables.getOnlyElement(fooHelpers); + + // p0 gets IDLE. + helper0.updateBalancingState( + IDLE, + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); + + // p0 goes to CONNECTING, reset failover timer + fakeClock.forwardTime(5, TimeUnit.SECONDS); + helper0.updateBalancingState( + CONNECTING, + BUFFER_PICKER); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); + + // failover happens + fakeClock.forwardTime(10, TimeUnit.SECONDS); + assertThat(fooBalancers).hasSize(2); + assertThat(fooHelpers).hasSize(2); + } + + @Test + public void readyToConnectDoesNotFailOverButUpdatesPicker() { + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + Helper helper0 = Iterables.getOnlyElement(fooHelpers); + + // p0 gets READY. + final Subchannel subchannel0 = mock(Subchannel.class); + helper0.updateBalancingState( + READY, + new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(subchannel0); + } + }); + assertCurrentPickerPicksSubchannel(subchannel0); + + // p0 goes to CONNECTING + helper0.updateBalancingState( + IDLE, + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); + + // no failover happened + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + + // resolution update without priority change does not trigger failover + Attributes.Key fooKey = Attributes.Key.create("fooKey"); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .setAttributes(Attributes.newBuilder().set(fooKey, "barVal").build()) + .build()); + + assertCurrentPickerIsBufferPicker(); + + // no failover happened + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + } + @Test public void typicalPriorityFailOverFlowWithIdleUpdate() { PriorityChildConfig priorityChildConfig0 = @@ -425,16 +558,10 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { Helper helper0 = Iterables.getOnlyElement(fooHelpers); // p0 gets IDLE. - final Subchannel subchannel0 = mock(Subchannel.class); helper0.updateBalancingState( IDLE, - new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel0); - } - }); - assertCurrentPickerPicksIdleSubchannel(subchannel0); + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); // p0 fails over to p1 immediately. helper0.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.ABORTED)); @@ -452,32 +579,20 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Helper helper2 = Iterables.getLast(fooHelpers); // p2 gets IDLE - final Subchannel subchannel1 = mock(Subchannel.class); helper2.updateBalancingState( IDLE, - new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel1); - } - }); - assertCurrentPickerPicksIdleSubchannel(subchannel1); + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); // p0 gets back to IDLE - final Subchannel subchannel2 = mock(Subchannel.class); helper0.updateBalancingState( IDLE, - new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel2); - } - }); - assertCurrentPickerPicksIdleSubchannel(subchannel2); + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); // p2 fails but does not affect overall picker helper2.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.UNAVAILABLE)); - assertCurrentPickerPicksIdleSubchannel(subchannel2); + assertCurrentPickerIsBufferPicker(); // p0 fails over to p3 immediately since p1 already timeout and p2 already in TRANSIENT_FAILURE. helper0.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.UNAVAILABLE)); @@ -497,32 +612,20 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { assertCurrentPickerReturnsError(Status.Code.DATA_LOSS, "foo"); // p2 gets back to IDLE - final Subchannel subchannel3 = mock(Subchannel.class); helper2.updateBalancingState( IDLE, - new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel3); - } - }); - assertCurrentPickerPicksIdleSubchannel(subchannel3); + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); // p0 gets back to IDLE - final Subchannel subchannel4 = mock(Subchannel.class); helper0.updateBalancingState( IDLE, - new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel4); - } - }); - assertCurrentPickerPicksIdleSubchannel(subchannel4); + BUFFER_PICKER); + assertCurrentPickerIsBufferPicker(); // p0 fails over to p2 and picker is updated to p2's existing picker. helper0.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.UNAVAILABLE)); - assertCurrentPickerPicksIdleSubchannel(subchannel3); + assertCurrentPickerIsBufferPicker(); // Deactivate child balancer get deleted. fakeClock.forwardTime(15, TimeUnit.MINUTES); @@ -574,7 +677,7 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); // LB shutdown and subchannel state change can happen simultaneously. If shutdown runs first, // any further balancing state update should be ignored. @@ -584,6 +687,37 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { verifyNoMoreInteractions(helper); } + @Test + public void noDuplicateOverallBalancingStateUpdate() { + FakeLoadBalancerProvider fakeLbProvider = new FakeLoadBalancerProvider(); + + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), false); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0), + ImmutableList.of("p0")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + + priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + + verify(helper, times(6)).updateBalancingState(any(), any()); + } + private void assertLatestConnectivityState(ConnectivityState expectedState) { verify(helper, atLeastOnce()) .updateBalancingState(connectivityStateCaptor.capture(), pickerCaptor.capture()); @@ -607,9 +741,54 @@ private void assertCurrentPickerPicksSubchannel(Subchannel expectedSubchannelToP assertThat(pickResult.getSubchannel()).isEqualTo(expectedSubchannelToPick); } - private void assertCurrentPickerPicksIdleSubchannel(Subchannel expectedSubchannelToPick) { + private void assertCurrentPickerIsBufferPicker() { assertLatestConnectivityState(IDLE); PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(pickResult.getSubchannel()).isEqualTo(expectedSubchannelToPick); + assertThat(pickResult).isEqualTo(PickResult.withNoResult()); + } + + private static class FakeLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "foo"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new FakeLoadBalancer(helper); + } + } + + static class FakeLoadBalancer extends LoadBalancer { + + private Helper helper; + + FakeLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL)); + } + + @Override + public void handleNameResolutionError(Status error) { + } + + @Override + public void shutdown() { + } } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 052868a2fb1..87615a125c0 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -29,8 +29,10 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.JsonParser; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; +import io.grpc.xds.RingHashOptions; import java.io.IOException; import java.lang.Thread.UncaughtExceptionHandler; +import java.util.Locale; import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; @@ -98,7 +100,7 @@ public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } @@ -109,21 +111,90 @@ public void parseLoadBalancingConfig_invalid_minGreaterThanMax() throws IOExcept ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } @Test - public void parseLoadBalancingConfig_invalid_ringTooLarge() throws IOException { - long ringSize = RingHashLoadBalancerProvider.MAX_RING_SIZE + 1; - String lbConfig = String.format("{\"minRingSize\" : 10, \"maxRingSize\" : %d}", ringSize); + public void parseLoadBalancingConfig_ringTooLargeUsesCap() throws IOException { + long ringSize = RingHashOptions.MAX_RING_SIZE_CAP + 1; + String lbConfig = + String.format(Locale.US, "{\"minRingSize\" : 10, \"maxRingSize\" : %d}", ringSize); ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); - assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); - assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10); + assertThat(config.maxRingSize).isEqualTo(RingHashOptions.DEFAULT_RING_SIZE_CAP); + } + + @Test + public void parseLoadBalancingConfig_ringCapCanBeRaised() throws IOException { + RingHashOptions.setRingSizeCap(RingHashOptions.MAX_RING_SIZE_CAP); + long ringSize = RingHashOptions.MAX_RING_SIZE_CAP; + String lbConfig = + String.format( + Locale.US, "{\"minRingSize\" : %d, \"maxRingSize\" : %d}", ringSize, ringSize); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + // Reset to avoid affecting subsequent test cases + RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); + } + + @Test + public void parseLoadBalancingConfig_ringCapIsClampedTo8M() throws IOException { + RingHashOptions.setRingSizeCap(RingHashOptions.MAX_RING_SIZE_CAP + 1); + long ringSize = RingHashOptions.MAX_RING_SIZE_CAP + 1; + String lbConfig = + String.format( + Locale.US, "{\"minRingSize\" : %d, \"maxRingSize\" : %d}", ringSize, ringSize); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + // Reset to avoid affecting subsequent test cases + RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); + } + + @Test + public void parseLoadBalancingConfig_ringCapCanBeLowered() throws IOException { + RingHashOptions.setRingSizeCap(1); + long ringSize = 2; + String lbConfig = + String.format( + Locale.US, "{\"minRingSize\" : %d, \"maxRingSize\" : %d}", ringSize, ringSize); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(1); + assertThat(config.maxRingSize).isEqualTo(1); + // Reset to avoid affecting subsequent test cases + RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); + } + + @Test + public void parseLoadBalancingConfig_ringCapLowerLimitIs1() throws IOException { + RingHashOptions.setRingSizeCap(0); + long ringSize = 2; + String lbConfig = + String.format( + Locale.US, "{\"minRingSize\" : %d, \"maxRingSize\" : %d}", ringSize, ringSize); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(1); + assertThat(config.maxRingSize).isEqualTo(1); + // Reset to avoid affecting subsequent test cases + RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @Test @@ -132,7 +203,7 @@ public void parseLoadBalancingConfig_zeroMinRingSize() throws IOException { ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } @@ -143,7 +214,7 @@ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); - assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 5a9bb7ff4a8..ed2fff0d244 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -25,6 +25,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -35,6 +36,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.Iterables; +import com.google.common.primitives.UnsignedInteger; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ConnectivityStateInfo; @@ -56,8 +58,10 @@ import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; +import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -95,6 +99,7 @@ public void uncaughtException(Thread t, Throwable e) { private final Map, Subchannel> subchannels = new HashMap<>(); private final Map subchannelStateListeners = new HashMap<>(); + private final Deque connectionRequestedQueue = new ArrayDeque<>(); private final XxHash64 hashFunc = XxHash64.INSTANCE; @Mock private Helper helper; @@ -123,6 +128,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable { return null; } }).when(subchannel).start(any(SubchannelStateListener.class)); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + connectionRequestedQueue.offer(subchannel); + return null; + } + }).when(subchannel).requestConnection(); return subchannel; } }); @@ -138,15 +150,17 @@ public void tearDown() { for (Subchannel subchannel : subchannels.values()) { verify(subchannel).shutdown(); } + connectionRequestedQueue.clear(); } @Test public void subchannelLazyConnectUntilPicked() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); // one server - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); verify(subchannel, never()).requestConnection(); @@ -175,9 +189,10 @@ public void subchannelLazyConnectUntilPicked() { public void subchannelNotAutoReconnectAfterReenteringIdle() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); // one server - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); InOrder inOrder = Mockito.inOrder(helper, subchannel); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -205,9 +220,10 @@ public void aggregateSubchannelStates_connectingReadyIdleFailure() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -216,18 +232,21 @@ public void aggregateSubchannelStates_connectingReadyIdleFailure() { subchannels.get(Collections.singletonList(servers.get(0))), ConnectivityStateInfo.forNonError(CONNECTING)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(0); // two in CONNECTING deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(CONNECTING)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(0); // one in CONNECTING, one in READY deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + verifyConnection(0); // one in TRANSIENT_FAILURE, one in READY deliverSubchannelState( @@ -236,25 +255,37 @@ public void aggregateSubchannelStates_connectingReadyIdleFailure() { Status.UNKNOWN.withDescription("unknown failure"))); inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + verifyConnection(0); // one in TRANSIENT_FAILURE, one in IDLE deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(1); verifyNoMoreInteractions(helper); } + private void verifyConnection(int times) { + for (int i = 0; i < times; i++) { + Subchannel connectOnce = connectionRequestedQueue.poll(); + assertThat(connectOnce).isNotNull(); + clearInvocations(connectOnce); + } + assertThat(connectionRequestedQueue.poll()).isNull(); + } + @Test public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1, 1); InOrder inOrder = Mockito.inOrder(helper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -264,7 +295,8 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("not found"))); inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(1); // two in TRANSIENT_FAILURE, two in IDLE deliverSubchannelState( @@ -274,6 +306,7 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(1); // two in TRANSIENT_FAILURE, one in CONNECTING, one in IDLE // The overall state is dominated by the two in TRANSIENT_FAILURE. @@ -282,6 +315,7 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { ConnectivityStateInfo.forNonError(CONNECTING)); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(0); // three in TRANSIENT_FAILURE, one in CONNECTING deliverSubchannelState( @@ -291,12 +325,14 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(0); // three in TRANSIENT_FAILURE, one in READY deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(2))), ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + verifyConnection(0); verifyNoMoreInteractions(helper); } @@ -305,9 +341,10 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { public void subchannelStayInTransientFailureUntilBecomeReady() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); reset(helper); @@ -320,15 +357,20 @@ public void subchannelStayInTransientFailureUntilBecomeReady() { verify(helper, times(3)).refreshNameResolution(); // Stays in IDLE when until there are two or more subchannels in TRANSIENT_FAILURE. - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verify(helper, times(2)) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(3); verifyNoMoreInteractions(helper); + reset(helper); // Simulate underlying subchannel auto reconnect after backoff. for (Subchannel subchannel : subchannels.values()) { deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); } + verify(helper, times(3)) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(3); verifyNoMoreInteractions(helper); // Simulate one subchannel enters READY. @@ -337,13 +379,61 @@ public void subchannelStayInTransientFailureUntilBecomeReady() { verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); } + @Test + public void updateConnectionIterator() { + RingHashConfig config = new RingHashConfig(10, 100); + List servers = createWeightedServerAddrs(1, 1, 1); + InOrder inOrder = Mockito.inOrder(helper); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + deliverSubchannelState( + subchannels.get(Collections.singletonList(servers.get(0))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("connection lost"))); + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(1); + + servers = createWeightedServerAddrs(1,1); + addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(1); + + deliverSubchannelState( + subchannels.get(Collections.singletonList(servers.get(1))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("connection lost"))); + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(1); + + deliverSubchannelState( + subchannels.get(Collections.singletonList(servers.get(0))), + ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(helper) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verifyConnection(1); + } + @Test public void ignoreShutdownSubchannelStateChange() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -361,9 +451,10 @@ public void ignoreShutdownSubchannelStateChange() { public void deterministicPickWithHostsPartiallyRemoved() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); InOrder inOrder = Mockito.inOrder(helper); inOrder.verify(helper, times(5)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -389,9 +480,10 @@ public void deterministicPickWithHostsPartiallyRemoved() { Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build(); updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr)); } - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(subchannels.get(Collections.singletonList(servers.get(0)))) .updateAddresses(Collections.singletonList(updatedServers.get(0))); verify(subchannels.get(Collections.singletonList(servers.get(1)))) @@ -406,9 +498,10 @@ public void deterministicPickWithHostsPartiallyRemoved() { public void deterministicPickWithNewHostsAdded() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); InOrder inOrder = Mockito.inOrder(helper); inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -430,9 +523,10 @@ public void deterministicPickWithNewHostsAdded() { assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1)); servers = createWeightedServerAddrs(1, 1, 1, 1, 1); // server2, server3, server4 added - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel()) @@ -445,9 +539,10 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE reset(helper); @@ -457,16 +552,15 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { // "[FakeSocketAddress-server2]_0" long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0"); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash); // Bring down server0 to force trying server2. deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(0))), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); - verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); @@ -476,6 +570,7 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) .requestConnection(); // no excessive connection + reset(helper); deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(2))), ConnectivityStateInfo.forNonError(CONNECTING)); @@ -495,14 +590,21 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(2)); } + private PickSubchannelArgsImpl getDefaultPickSubchannelArgs(long rpcHash) { + return new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + } + @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE reset(helper); @@ -526,16 +628,17 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied"))); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verify(subchannels.get(Collections.singletonList(servers.get(1)))) - .requestConnection(); // LB attempts to recover by itself + verifyConnection(2); // LB attempts to recover by itself PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC assertThat(result.getStatus().getCode()) .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(1))), times(2)) - .requestConnection(); // kickoff connection to server3 (next first non-failing) + verify(subchannels.get(Collections.singletonList(servers.get(1)))) + .requestConnection(); // kickoff connection to server3 (next first non-failing) + verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection(); // Now connecting to server1. deliverSubchannelState( @@ -565,9 +668,10 @@ public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -579,6 +683,7 @@ public void allSubchannelsInTransientFailure() { } verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(3); // Picking subchannel triggers connection. RPC hash hits server0. PickSubchannelArgs args = new PickSubchannelArgsImpl( @@ -589,51 +694,326 @@ public void allSubchannelsInTransientFailure() { assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) .isEqualTo("[FakeSocketAddress-server0] unreachable"); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))) + .requestConnection(); } @Test - public void hostSelectionProportionalToWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring - List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 - loadBalancer.handleResolvedAddresses( + public void firstSubchannelIdle() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); - // Bring all subchannels to READY. - Map pickCounts = new HashMap<>(); - for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - pickCounts.put(subchannel.getAddresses(), 0); - } - verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture()); - SubchannelPicker picker = pickerCaptor.getValue(); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("unreachable"))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); - for (int i = 0; i < 10000; i++) { - long hash = hashFunc.hashInt(i); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hash)); - Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); - EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); - pickCounts.put(addr, pickCounts.get(addr) + 1); - } + // Picking subchannel triggers connection. RPC hash hits server0. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) + .requestConnection(); + } - // Actual distribution: server0 = 104, server1 = 808, server2 = 9088 - double ratio01 = (double) pickCounts.get(servers.get(0)) / pickCounts.get(servers.get(1)); - double ratio12 = (double) pickCounts.get(servers.get(1)) / pickCounts.get(servers.get(2)); - assertThat(ratio01).isWithin(0.03).of((double) 1 / 10); - assertThat(ratio12).isWithin(0.03).of((double) 10 / 100); + @Test + public void firstSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), + ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), + ConnectivityStateInfo.forNonError(CONNECTING)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + // Picking subchannel triggers connection. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(0))), never()) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) + .requestConnection(); + } + + @Test + public void firstSubchannelFailure() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // ring: + // "[FakeSocketAddress-server1]_0" + // "[FakeSocketAddress-server0]_0" + // "[FakeSocketAddress-server2]_0" + + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("unreachable"))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); + + // Picking subchannel triggers connection. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); + } + + @Test + public void secondSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // ring: + // "[FakeSocketAddress-server1]_0" + // "[FakeSocketAddress-server0]_0" + // "[FakeSocketAddress-server2]_0" + + Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + deliverSubchannelState(firstSubchannel, + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( + firstSubchannel.getAddresses().getAddresses() + "unreachable"))); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), + ConnectivityStateInfo.forNonError(CONNECTING)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); + + // Picking subchannel triggers connection. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); } @Test - public void hostSelectionProportionalToRepeatedAddressCount() { - RingHashConfig config = new RingHashConfig(10000, 100000); - List servers = createRepeatedServerAddrs(1, 10, 100); // 1:10:100 - loadBalancer.handleResolvedAddresses( + public void secondSubchannelFailure() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // ring: + // "[FakeSocketAddress-server1]_0" + // "[FakeSocketAddress-server0]_0" + // "[FakeSocketAddress-server2]_0" + + Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + deliverSubchannelState(firstSubchannel, + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( + firstSubchannel.getAddresses().getAddresses() + " unreachable"))); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("unreachable"))); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(2); + + // Picking subchannel triggers connection. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isFalse(); + assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(result.getStatus().getDescription()) + .isEqualTo("[FakeSocketAddress-server0] unreachable"); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1)))) + .requestConnection(); + } + + @Test + public void thirdSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // ring: + // "[FakeSocketAddress-server1]_0" + // "[FakeSocketAddress-server0]_0" + // "[FakeSocketAddress-server2]_0" + + Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + deliverSubchannelState(firstSubchannel, + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( + firstSubchannel.getAddresses().getAddresses() + " unreachable"))); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("unreachable"))); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), + ConnectivityStateInfo.forNonError(CONNECTING)); + verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(2); + + // Picking subchannel triggers connection. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isFalse(); + assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(result.getStatus().getDescription()) + .isEqualTo("[FakeSocketAddress-server0] unreachable"); + verify(subchannels.get(Collections.singletonList(servers.get(0)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))) + .requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); + } + + @Test + public void stickyTransientFailure() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + // Bring one subchannel to TRANSIENT_FAILURE. + Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + deliverSubchannelState(firstSubchannel, + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription( + firstSubchannel.getAddresses().getAddresses() + " unreachable"))); + + verify(helper).updateBalancingState(eq(CONNECTING), any()); + verifyConnection(1); + deliverSubchannelState(firstSubchannel, ConnectivityStateInfo.forNonError(IDLE)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); + + // Picking subchannel triggers connection. RPC hash hits server0. + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); + } + + @Test + public void largeWeights() { + RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + List servers = + createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + + // Try value between max signed and max unsigned int + servers = createWeightedServerAddrs(Integer.MAX_VALUE + 100L, 100); // (MAX+100):100 + addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); + + // Try a negative value + servers = createWeightedServerAddrs(10, -20, 100); // 10:-20:100 + addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isFalse(); + + // Try an individual value larger than max unsigned int + long maxUnsigned = UnsignedInteger.MAX_VALUE.longValue(); + servers = createWeightedServerAddrs(maxUnsigned + 10, 10, 100); // uMAX+10:10:100 + addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isFalse(); + + // Try a sum of values larger than max unsigned int + servers = createWeightedServerAddrs(Integer.MAX_VALUE, Integer.MAX_VALUE, 100); // MAX:MAX:100 + addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isFalse(); + } + + @Test + public void hostSelectionProportionalToWeights() { + RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -658,7 +1038,7 @@ public void hostSelectionProportionalToRepeatedAddressCount() { // Actual distribution: server0 = 104, server1 = 808, server2 = 9088 double ratio01 = (double) pickCounts.get(servers.get(0)) / pickCounts.get(servers.get(1)); - double ratio12 = (double) pickCounts.get(servers.get(1)) / pickCounts.get(servers.get(11)); + double ratio12 = (double) pickCounts.get(servers.get(1)) / pickCounts.get(servers.get(2)); assertThat(ratio01).isWithin(0.03).of((double) 1 / 10); assertThat(ratio12).isWithin(0.03).of((double) 10 / 100); } @@ -679,9 +1059,10 @@ public void nameResolutionErrorWithNoActiveSubchannels() { public void nameResolutionErrorWithActiveSubchannels() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -698,6 +1079,30 @@ public void nameResolutionErrorWithActiveSubchannels() { verifyNoMoreInteractions(helper); } + @Test + public void duplicateAddresses() { + RingHashConfig config = new RingHashConfig(10, 100); + List servers = createRepeatedServerAddrs(1, 2, 3); + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isFalse(); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + + PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); + assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC + assertThat(result.getStatus().getCode()) + .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash + String description = result.getStatus().getDescription(); + assertThat(description).startsWith( + "Ring hash lb error: EDS resolution was successful, but there were duplicate addresses: "); + assertThat(description).contains("Address: FakeSocketAddress-server1, count: 2"); + assertThat(description).contains("Address: FakeSocketAddress-server2, count: 3"); + } + private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo state) { subchannelStateListeners.get(subchannel).onSubchannelState(state); } diff --git a/xds/src/test/java/io/grpc/xds/RouteLookupServiceClusterSpecifierPluginTest.java b/xds/src/test/java/io/grpc/xds/RouteLookupServiceClusterSpecifierPluginTest.java index c883a9d257d..25b47232c45 100644 --- a/xds/src/test/java/io/grpc/xds/RouteLookupServiceClusterSpecifierPluginTest.java +++ b/xds/src/test/java/io/grpc/xds/RouteLookupServiceClusterSpecifierPluginTest.java @@ -89,7 +89,7 @@ public void parseConfigWithAllFieldsGiven() { .put("cacheSizeBytes", "5000") .put("validTargets", ImmutableList.of("valid-target")) .put("defaultTarget","default-target") - .build()); + .buildOrThrow()); } @Test @@ -131,6 +131,6 @@ public void parseConfigWithOptionalFieldsUnspecified() { .put("lookupServiceTimeout", "1.234s") .put("cacheSizeBytes", "5000") .put("validTargets", ImmutableList.of("valid-target")) - .build()); + .buildOrThrow()); } } diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 14a8f1ce743..58bbcce737c 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -67,7 +67,7 @@ public void noServer() throws XdsInitializationException { @Test public void sharedXdsClientObjectPool() throws XdsInitializationException { - ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), false); + ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); @@ -84,7 +84,7 @@ public void sharedXdsClientObjectPool() throws XdsInitializationException { @Test public void refCountedXdsClientObjectPool_delayedCreation() { - ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), false); + ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool(bootstrapInfo); @@ -96,7 +96,7 @@ public void refCountedXdsClientObjectPool_delayedCreation() { @Test public void refCountedXdsClientObjectPool_refCounted() { - ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), false); + ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool(bootstrapInfo); @@ -115,7 +115,7 @@ public void refCountedXdsClientObjectPool_refCounted() { @Test public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadyShutdown() { - ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), false); + ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool(bootstrapInfo); diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java index 7d9d30385e4..1bbd02b753f 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java @@ -179,7 +179,8 @@ public void tearDown() { @Test public void handleResolvedAddresses() { - ArgumentCaptor resolvedAddressesCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor resolvedAddressesCaptor = + ArgumentCaptor.forClass(ResolvedAddresses.class); Attributes.Key fakeKey = Attributes.Key.create("fake_key"); Object fakeValue = new Object(); @@ -260,8 +261,8 @@ public void handleResolvedAddresses() { @Test public void handleNameResolutionError() { - ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); // Error before any child balancer created. weightedTargetLb.handleNameResolutionError(Status.DATA_LOSS); @@ -326,7 +327,7 @@ public void balancingStateUpdatedFromChildBalancers() { new ErrorPicker(Status.DATA_LOSS), new ErrorPicker(Status.DATA_LOSS) }; - ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); // One child balancer goes to TRANSIENT_FAILURE. childHelpers.get(1).updateBalancingState(TRANSIENT_FAILURE, failurePickers[1]); @@ -402,4 +403,73 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { weightedChildHelper0.updateBalancingState(READY, mock(SubchannelPicker.class)); verifyNoMoreInteractions(helper); } + + // When the ChildHelper is asked to update the overall balancing state, it should not do that if + // the update was triggered by the parent LB that will handle triggering the overall state update. + @Test + public void noDuplicateOverallBalancingStateUpdate() { + FakeLoadBalancerProvider fakeLbProvider = new FakeLoadBalancerProvider(); + + Map targets = ImmutableMap.of( + "target0", new WeightedPolicySelection( + weights[0], new PolicySelection(fakeLbProvider, configs[0])), + "target3", new WeightedPolicySelection( + weights[3], new PolicySelection(fakeLbProvider, configs[3]))); + weightedTargetLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) + .build()); + + // Both of the two child LB policies will call the helper to update the balancing state. + // But since those calls happen during the handling of teh resolved addresses of the parent + // WeightedTargetLLoadBalancer, the overall balancing state should only be updated once. + verify(helper, times(1)).updateBalancingState(any(), any()); + + } + + private static class FakeLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "foo"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new FakeLoadBalancer(helper); + } + } + + static class FakeLoadBalancer extends LoadBalancer { + + private Helper helper; + + FakeLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL)); + } + + @Override + public void handleNameResolutionError(Status error) { + } + + @Override + public void shutdown() { + } + } } diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java new file mode 100644 index 00000000000..d251f3677d8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java @@ -0,0 +1,69 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver; +import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link WrrLocalityLoadBalancerProvider}. + */ +@RunWith(JUnit4.class) +public class WrrLocalityLoadBalancerProviderTest { + + @Test + public void provided() { + LoadBalancerProvider provider = + LoadBalancerRegistry.getDefaultRegistry().getProvider( + XdsLbPolicies.WRR_LOCALITY_POLICY_NAME); + assertThat(provider).isInstanceOf(WrrLocalityLoadBalancerProvider.class); + } + + @Test + public void providesLoadBalancer() { + Helper helper = mock(Helper.class); + when(helper.getAuthority()).thenReturn("api.google.com"); + LoadBalancerProvider provider = new WrrLocalityLoadBalancerProvider(); + LoadBalancer loadBalancer = provider.newLoadBalancer(helper); + assertThat(loadBalancer).isInstanceOf(WrrLocalityLoadBalancer.class); + } + + @Test + public void parseConfig() { + Map rawConfig = ImmutableMap.of("childPolicy", + ImmutableList.of(ImmutableMap.of("round_robin", ImmutableMap.of()))); + + WrrLocalityLoadBalancerProvider provider = new WrrLocalityLoadBalancerProvider(); + NameResolver.ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(rawConfig); + WrrLocalityConfig config = (WrrLocalityConfig) configOrError.getConfig(); + assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java new file mode 100644 index 00000000000..344876aa348 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java @@ -0,0 +1,272 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.testing.EqualsTester; +import io.grpc.Attributes; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; +import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; +import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Tests for {@link WrrLocalityLoadBalancerProvider}. + */ +@RunWith(JUnit4.class) +public class WrrLocalityLoadBalancerTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private LoadBalancerProvider mockWeightedTargetProvider; + @Mock + private LoadBalancer mockWeightedTargetLb; + @Mock + private LoadBalancerProvider mockChildProvider; + @Mock + private LoadBalancer mockChildLb; + @Mock + private Helper mockHelper; + + @Captor + private ArgumentCaptor resolvedAddressesCaptor; + @Captor + private ArgumentCaptor connectivityStateCaptor; + @Captor + private ArgumentCaptor errorPickerCaptor; + + private WrrLocalityLoadBalancer loadBalancer; + private LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + @Before + public void setUp() { + when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); + + when(mockWeightedTargetProvider.newLoadBalancer(isA(Helper.class))).thenReturn( + mockWeightedTargetLb); + when(mockWeightedTargetProvider.getPolicyName()).thenReturn(WEIGHTED_TARGET_POLICY_NAME); + when(mockWeightedTargetProvider.isAvailable()).thenReturn(true); + lbRegistry.register(mockWeightedTargetProvider); + + when(mockChildProvider.newLoadBalancer(isA(Helper.class))).thenReturn(mockChildLb); + when(mockChildProvider.getPolicyName()).thenReturn("round_robin"); + lbRegistry.register(mockWeightedTargetProvider); + + loadBalancer = new WrrLocalityLoadBalancer(mockHelper, lbRegistry); + } + + @Test + public void handleResolvedAddresses() { + // A two locality cluster with a mock child LB policy. + Locality localityOne = Locality.create("region1", "zone1", "subzone1"); + Locality localityTwo = Locality.create("region2", "zone2", "subzone2"); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + + // The child config is delivered wrapped in the wrr_locality config and the locality weights + // in a ResolvedAddresses attribute. + WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); + deliverAddresses(wlConfig, + ImmutableList.of( + makeAddress("addr1", localityOne, 1), + makeAddress("addr2", localityTwo, 2))); + + // Assert that the child policy and the locality weights were correctly mapped to a + // WeightedTargetConfig. + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + Object config = resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig(); + assertThat(config).isInstanceOf(WeightedTargetConfig.class); + WeightedTargetConfig wtConfig = (WeightedTargetConfig) config; + assertThat(wtConfig.targets).hasSize(2); + assertThat(wtConfig.targets).containsEntry(localityOne.toString(), + new WeightedPolicySelection(1, childPolicy)); + assertThat(wtConfig.targets).containsEntry(localityTwo.toString(), + new WeightedPolicySelection(2, childPolicy)); + } + + @Test + public void handleResolvedAddresses_noLocalityWeights() { + // A two locality cluster with a mock child LB policy. + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + + // The child config is delivered wrapped in the wrr_locality config and the locality weights + // in a ResolvedAddresses attribute. + WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); + deliverAddresses(wlConfig, ImmutableList.of( + makeAddress("addr", Locality.create("test-region", "test-zone", "test-subzone"), null))); + + // With no locality weights, we should get a TRANSIENT_FAILURE. + verify(mockHelper).getAuthority(); + verify(mockHelper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), + isA(ErrorPicker.class)); + } + + @Test + public void handleNameResolutionError_noChildLb() { + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + + verify(mockHelper).updateBalancingState(connectivityStateCaptor.capture(), + errorPickerCaptor.capture()); + assertThat(connectivityStateCaptor.getValue()).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); + assertThat(errorPickerCaptor.getValue().toString()).isEqualTo( + new ErrorPicker(Status.DEADLINE_EXCEEDED).toString()); + } + + @Test + public void handleNameResolutionError_withChildLb() { + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), + ImmutableList.of( + makeAddress("addr1", Locality.create("test-region1", "test-zone", "test-subzone"), 1))); + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + + verify(mockHelper, never()).updateBalancingState(isA(ConnectivityState.class), + isA(ErrorPicker.class)); + verify(mockWeightedTargetLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + } + + @Test + public void localityWeightAttributeNotPropagated() { + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + + WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); + deliverAddresses(wlConfig, ImmutableList.of( + makeAddress("addr1", Locality.create("test-region1", "test-zone", "test-subzone"), 1))); + + // Assert that the child policy and the locality weights were correctly mapped to a + // WeightedTargetConfig. + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + + //assertThat(resolvedAddressesCaptor.getValue().getAttributes() + // .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); + } + + @Test + public void shutdown() { + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), + ImmutableList.of( + makeAddress("addr", Locality.create("test-region", "test-zone", "test-subzone"), 1))); + loadBalancer.shutdown(); + + verify(mockWeightedTargetLb).shutdown(); + } + + @Test + public void configEquality() { + WrrLocalityConfig configOne = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); + WrrLocalityConfig configTwo = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); + WrrLocalityConfig differentConfig = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, "config")); + + new EqualsTester().addEqualityGroup(configOne, configTwo).addEqualityGroup(differentConfig) + .testEquals(); + } + + private void deliverAddresses(WrrLocalityConfig config, List addresses) { + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(addresses).setLoadBalancingPolicyConfig(config) + .build()); + } + + /** + * Create a locality-labeled address. + */ + private static EquivalentAddressGroup makeAddress(final String name, Locality locality, + Integer localityWeight) { + class FakeSocketAddress extends SocketAddress { + private final String name; + + private FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public int hashCode() { + return Objects.hash(name); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof FakeSocketAddress)) { + return false; + } + FakeSocketAddress that = (FakeSocketAddress) o; + return Objects.equals(name, that.name); + } + + @Override + public String toString() { + return name; + } + } + + Attributes.Builder attrBuilder = Attributes.newBuilder() + .set(InternalXdsAttributes.ATTR_LOCALITY, locality); + if (localityWeight != null) { + attrBuilder.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT, localityWeight); + } + + EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), + attrBuilder.build()); + return AddressFilter.setPathFilter(eag, Collections.singletonList(locality.toString())); + } +} diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java similarity index 80% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java index 2b75c02d4dc..051d890aea4 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java @@ -17,6 +17,8 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.envoyproxy.envoy.config.route.v3.RouteAction.ClusterSpecifierCase.CLUSTER_SPECIFIER_PLUGIN; +import static org.junit.Assert.fail; import com.github.udpa.udpa.type.v1.TypedStruct; import com.google.common.collect.ImmutableMap; @@ -24,12 +26,12 @@ import com.google.common.collect.Iterables; import com.google.protobuf.Any; import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.StringValue; import com.google.protobuf.Struct; import com.google.protobuf.UInt32Value; -import com.google.protobuf.UInt64Value; import com.google.protobuf.Value; import com.google.protobuf.util.Durations; import com.google.re2j.Pattern; @@ -37,9 +39,6 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; -import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; -import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; -import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.CidrRange; @@ -47,6 +46,7 @@ import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.core.v3.PathConfigSource; import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; import io.envoyproxy.envoy.config.core.v3.SocketAddress; @@ -105,16 +105,17 @@ import io.grpc.ClientInterceptor; import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerRegistry; import io.grpc.Status.Code; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.lookup.v1.GrpcKeyBuilder; import io.grpc.lookup.v1.GrpcKeyBuilder.Name; import io.grpc.lookup.v1.NameMatcher; import io.grpc.lookup.v1.RouteLookupClusterSpecifier; import io.grpc.lookup.v1.RouteLookupConfig; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; -import io.grpc.xds.ClientXdsClient.StructOrError; import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; import io.grpc.xds.Endpoints.LbEndpoint; @@ -127,15 +128,16 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; -import io.grpc.xds.XdsClient.CdsUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsResourceType.StructOrError; +import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.Matchers.FractionMatcher; import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.util.Arrays; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -148,10 +150,10 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class ClientXdsClientDataTest { +public class XdsClientImplDataTest { private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create(), true); + ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule @@ -164,22 +166,22 @@ public class ClientXdsClientDataTest { @Before public void setUp() { - originalEnableRetry = ClientXdsClient.enableRetry; + originalEnableRetry = XdsResourceType.enableRetry; assertThat(originalEnableRetry).isTrue(); - originalEnableRbac = ClientXdsClient.enableRbac; + originalEnableRbac = XdsResourceType.enableRbac; assertThat(originalEnableRbac).isTrue(); - originalEnableRouteLookup = ClientXdsClient.enableRouteLookup; + originalEnableRouteLookup = XdsResourceType.enableRouteLookup; assertThat(originalEnableRouteLookup).isFalse(); - originalEnableLeastRequest = ClientXdsClient.enableLeastRequest; + originalEnableLeastRequest = XdsResourceType.enableLeastRequest; assertThat(originalEnableLeastRequest).isFalse(); } @After public void tearDown() { - ClientXdsClient.enableRetry = originalEnableRetry; - ClientXdsClient.enableRbac = originalEnableRbac; - ClientXdsClient.enableRouteLookup = originalEnableRouteLookup; - ClientXdsClient.enableLeastRequest = originalEnableLeastRequest; + XdsResourceType.enableRetry = originalEnableRetry; + XdsResourceType.enableRbac = originalEnableRbac; + XdsResourceType.enableRouteLookup = originalEnableRouteLookup; + XdsResourceType.enableLeastRequest = originalEnableLeastRequest; } @Test @@ -194,8 +196,8 @@ public void parseRoute_withRouteAction() { io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() .setCluster("cluster-foo")) .build(); - StructOrError struct = ClientXdsClient.parseRoute( - proto, filterRegistry, false, ImmutableMap.of()); + StructOrError struct = XdsRouteConfigureResource.parseRoute( + proto, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()) .isEqualTo( @@ -217,8 +219,8 @@ public void parseRoute_withNonForwardingAction() { .setPath("/service/method")) .setNonForwardingAction(NonForwardingAction.getDefaultInstance()) .build(); - StructOrError struct = ClientXdsClient.parseRoute( - proto, filterRegistry, false, ImmutableMap.of()); + StructOrError struct = XdsRouteConfigureResource.parseRoute( + proto, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct()) .isEqualTo( Route.forNonForwardingAction( @@ -236,8 +238,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setMatch(io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPath("")) .setRedirect(RedirectAction.getDefaultInstance()) .build(); - res = ClientXdsClient.parseRoute( - redirectRoute, filterRegistry, false, ImmutableMap.of()); + res = XdsRouteConfigureResource.parseRoute( + redirectRoute, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of()); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: REDIRECT"); @@ -248,8 +250,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setMatch(io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPath("")) .setDirectResponse(DirectResponseAction.getDefaultInstance()) .build(); - res = ClientXdsClient.parseRoute( - directResponseRoute, filterRegistry, false, ImmutableMap.of()); + res = XdsRouteConfigureResource.parseRoute( + directResponseRoute, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of()); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: DIRECT_RESPONSE"); @@ -260,8 +262,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setMatch(io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPath("")) .setFilterAction(FilterAction.getDefaultInstance()) .build(); - res = ClientXdsClient.parseRoute( - filterRoute, filterRegistry, false, ImmutableMap.of()); + res = XdsRouteConfigureResource.parseRoute( + filterRoute, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of()); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: FILTER_ACTION"); @@ -282,8 +284,8 @@ public void parseRoute_skipRouteWithUnsupportedMatcher() { io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() .setCluster("cluster-foo")) .build(); - assertThat(ClientXdsClient.parseRoute( - proto, filterRegistry, false, ImmutableMap.of())) + assertThat(XdsRouteConfigureResource.parseRoute( + proto, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of())) .isNull(); } @@ -299,8 +301,8 @@ public void parseRoute_skipRouteWithUnsupportedAction() { io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() .setClusterHeader("cluster header")) // cluster_header action not supported .build(); - assertThat(ClientXdsClient.parseRoute( - proto, filterRegistry, false, ImmutableMap.of())) + assertThat(XdsRouteConfigureResource.parseRoute( + proto, filterRegistry, false, ImmutableMap.of(), ImmutableSet.of())) .isNull(); } @@ -319,7 +321,7 @@ public void parseRouteMatch_withHeaderMatcher() { .setName(":method") .setExactMatch("PUT")) .build(); - StructOrError struct = ClientXdsClient.parseRouteMatch(proto); + StructOrError struct = XdsRouteConfigureResource.parseRouteMatch(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()) .isEqualTo( @@ -343,7 +345,7 @@ public void parseRouteMatch_withRuntimeFractionMatcher() { .setNumerator(30) .setDenominator(FractionalPercent.DenominatorType.HUNDRED))) .build(); - StructOrError struct = ClientXdsClient.parseRouteMatch(proto); + StructOrError struct = XdsRouteConfigureResource.parseRouteMatch(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()) .isEqualTo( @@ -358,7 +360,7 @@ public void parsePathMatcher_withFullPath() { io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder() .setPath("/service/method") .build(); - StructOrError struct = ClientXdsClient.parsePathMatcher(proto); + StructOrError struct = XdsRouteConfigureResource.parsePathMatcher(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( PathMatcher.fromPath("/service/method", false)); @@ -368,7 +370,7 @@ public void parsePathMatcher_withFullPath() { public void parsePathMatcher_withPrefix() { io.envoyproxy.envoy.config.route.v3.RouteMatch proto = io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPrefix("/").build(); - StructOrError struct = ClientXdsClient.parsePathMatcher(proto); + StructOrError struct = XdsRouteConfigureResource.parsePathMatcher(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( PathMatcher.fromPrefix("/", false)); @@ -380,7 +382,7 @@ public void parsePathMatcher_withSafeRegEx() { io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder() .setSafeRegex(RegexMatcher.newBuilder().setRegex(".")) .build(); - StructOrError struct = ClientXdsClient.parsePathMatcher(proto); + StructOrError struct = XdsRouteConfigureResource.parsePathMatcher(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo(PathMatcher.fromRegEx(Pattern.compile("."))); } @@ -393,7 +395,7 @@ public void parseHeaderMatcher_withExactMatch() { .setName(":method") .setExactMatch("PUT") .build(); - StructOrError struct1 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct1 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct1.getErrorDetail()).isNull(); assertThat(struct1.getStruct()).isEqualTo( HeaderMatcher.forExactValue(":method", "PUT", false)); @@ -407,7 +409,7 @@ public void parseHeaderMatcher_withSafeRegExMatch() { .setName(":method") .setSafeRegexMatch(RegexMatcher.newBuilder().setRegex("P*")) .build(); - StructOrError struct3 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct3 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct3.getErrorDetail()).isNull(); assertThat(struct3.getStruct()).isEqualTo( HeaderMatcher.forSafeRegEx(":method", Pattern.compile("P*"), false)); @@ -420,7 +422,7 @@ public void parseHeaderMatcher_withRangeMatch() { .setName("timeout") .setRangeMatch(Int64Range.newBuilder().setStart(10L).setEnd(20L)) .build(); - StructOrError struct4 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct4 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct4.getErrorDetail()).isNull(); assertThat(struct4.getStruct()).isEqualTo( HeaderMatcher.forRange("timeout", HeaderMatcher.Range.create(10L, 20L), false)); @@ -433,7 +435,7 @@ public void parseHeaderMatcher_withPresentMatch() { .setName("user-agent") .setPresentMatch(true) .build(); - StructOrError struct5 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct5 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct5.getErrorDetail()).isNull(); assertThat(struct5.getStruct()).isEqualTo( HeaderMatcher.forPresent("user-agent", true, false)); @@ -447,7 +449,7 @@ public void parseHeaderMatcher_withPrefixMatch() { .setName("authority") .setPrefixMatch("service-foo") .build(); - StructOrError struct6 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct6 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct6.getErrorDetail()).isNull(); assertThat(struct6.getStruct()).isEqualTo( HeaderMatcher.forPrefix("authority", "service-foo", false)); @@ -461,7 +463,7 @@ public void parseHeaderMatcher_withSuffixMatch() { .setName("authority") .setSuffixMatch("googleapis.com") .build(); - StructOrError struct7 = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct7 = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct7.getErrorDetail()).isNull(); assertThat(struct7.getStruct()).isEqualTo( HeaderMatcher.forSuffix("authority", "googleapis.com", false)); @@ -475,11 +477,33 @@ public void parseHeaderMatcher_malformedRegExPattern() { .setName(":method") .setSafeRegexMatch(RegexMatcher.newBuilder().setRegex("[")) .build(); - StructOrError struct = ClientXdsClient.parseHeaderMatcher(proto); + StructOrError struct = XdsRouteConfigureResource.parseHeaderMatcher(proto); assertThat(struct.getErrorDetail()).isNotNull(); assertThat(struct.getStruct()).isNull(); } + @Test + @SuppressWarnings("deprecation") + public void parseHeaderMatcher_withStringMatcher() { + io.envoyproxy.envoy.type.matcher.v3.StringMatcher stringMatcherProto = + io.envoyproxy.envoy.type.matcher.v3.StringMatcher.newBuilder() + .setPrefix("service-foo") + .setIgnoreCase(false) + .build(); + + io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = + io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() + .setName("authority") + .setStringMatch(stringMatcherProto) + .setInvertMatch(false) + .build(); + StructOrError struct = XdsRouteConfigureResource.parseHeaderMatcher(proto); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct()).isEqualTo( + HeaderMatcher.forString("authority", Matchers.StringMatcher + .forPrefix("service-foo", false), false)); + } + @Test public void parseRouteAction_withCluster() { io.envoyproxy.envoy.config.route.v3.RouteAction proto = @@ -487,8 +511,8 @@ public void parseRouteAction_withCluster() { .setCluster("cluster-foo") .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); assertThat(struct.getStruct().weightedClusters()).isNull(); @@ -511,8 +535,8 @@ public void parseRouteAction_withWeightedCluster() { .setWeight(UInt32Value.newBuilder().setValue(70)))) .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isNull(); assertThat(struct.getStruct().weightedClusters()).containsExactly( @@ -520,6 +544,28 @@ public void parseRouteAction_withWeightedCluster() { ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); } + @Test + public void parseRouteAction_weightedClusterSum() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(0))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(0)))) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); + assertThat(struct.getErrorDetail()).isEqualTo("Sum of cluster weights should be above 0."); + } + @Test public void parseRouteAction_withTimeoutByGrpcTimeoutHeaderMax() { io.envoyproxy.envoy.config.route.v3.RouteAction proto = @@ -531,8 +577,8 @@ public void parseRouteAction_withTimeoutByGrpcTimeoutHeaderMax() { .setMaxStreamDuration(Durations.fromMillis(20L))) .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -546,8 +592,8 @@ public void parseRouteAction_withTimeoutByMaxStreamDuration() { .setMaxStreamDuration(Durations.fromSeconds(5L))) .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -558,14 +604,14 @@ public void parseRouteAction_withTimeoutUnset() { .setCluster("cluster-foo") .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().timeoutNano()).isNull(); } @Test public void parseRouteAction_withRetryPolicy() { - ClientXdsClient.enableRetry = true; + XdsResourceType.enableRetry = true; RetryPolicy.Builder builder = RetryPolicy.newBuilder() .setNumRetries(UInt32Value.of(3)) .setRetryBackOff( @@ -581,8 +627,8 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder.build()) .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); RouteAction.RetryPolicy retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxAttempts()).isEqualTo(4); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(500)); @@ -605,8 +651,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder.build()) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().retryPolicy()).isNotNull(); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()).isEmpty(); @@ -618,8 +664,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()).isEqualTo("No base_interval specified in retry_backoff"); // max_interval unset @@ -628,8 +674,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(500 * 10)); @@ -639,8 +685,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()) .isEqualTo("base_interval in retry_backoff must be positive"); @@ -652,8 +698,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -665,8 +711,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -678,8 +724,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().retryPolicy().initialBackoff()) .isEqualTo(Durations.fromMillis(1)); assertThat(struct.getStruct().retryPolicy().maxBackoff()) @@ -694,8 +740,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(25)); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(250)); @@ -713,8 +759,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); @@ -731,8 +777,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); @@ -749,8 +795,8 @@ public void parseRouteAction_withRetryPolicy() { .setCluster("cluster-foo") .setRetryPolicy(builder) .build(); - struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); } @@ -780,15 +826,15 @@ public void parseRouteAction_withHashPolicies() { io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setFilterState( FilterState.newBuilder() - .setKey(ClientXdsClient.HASH_POLICY_FILTER_STATE_KEY))) + .setKey(XdsResourceType.HASH_POLICY_FILTER_STATE_KEY))) .addHashPolicy( io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setQueryParameter( QueryParameter.newBuilder().setName("param"))) // unsupported .build(); StructOrError struct = - ClientXdsClient.parseRouteAction(proto, filterRegistry, false, - ImmutableMap.of()); + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); List policies = struct.getStruct().hashPolicies(); assertThat(policies).hasSize(2); assertThat(policies.get(0).type()).isEqualTo(HashPolicy.Type.HEADER); @@ -801,6 +847,30 @@ public void parseRouteAction_withHashPolicies() { assertThat(policies.get(1).isTerminal()).isFalse(); } + @Test + public void parseRouteAction_custerSpecifierNotSet() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); + assertThat(struct).isNull(); + } + + @Test + public void parseRouteAction_clusterSpecifier_routeLookupDisabled() { + XdsResourceType.enableRouteLookup = false; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, false, + ImmutableMap.of(), ImmutableSet.of()); + assertThat(struct).isNull(); + } + @Test public void parseClusterWeight() { io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto = @@ -809,7 +879,7 @@ public void parseClusterWeight() { .setWeight(UInt32Value.newBuilder().setValue(30)) .build(); ClusterWeight clusterWeight = - ClientXdsClient.parseClusterWeight(proto, filterRegistry, false).getStruct(); + XdsRouteConfigureResource.parseClusterWeight(proto, filterRegistry, false).getStruct(); assertThat(clusterWeight.name()).isEqualTo("cluster-foo"); assertThat(clusterWeight.weight()).isEqualTo(30); } @@ -831,7 +901,7 @@ public void parseLocalityLbEndpoints_withHealthyEndpoints() { .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY) .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight .build(); - StructOrError struct = ClientXdsClient.parseLocalityLbEndpoints(proto); + StructOrError struct = XdsEndpointResource.parseLocalityLbEndpoints(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( @@ -855,7 +925,7 @@ public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() { .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN) .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight .build(); - StructOrError struct = ClientXdsClient.parseLocalityLbEndpoints(proto); + StructOrError struct = XdsEndpointResource.parseLocalityLbEndpoints(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( @@ -879,7 +949,7 @@ public void parseLocalityLbEndpoints_withUnHealthyEndpoints() { .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.UNHEALTHY) .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight .build(); - StructOrError struct = ClientXdsClient.parseLocalityLbEndpoints(proto); + StructOrError struct = XdsEndpointResource.parseLocalityLbEndpoints(proto); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( @@ -903,7 +973,7 @@ public void parseLocalityLbEndpoints_ignorZeroWeightLocality() { .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN) .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight .build(); - assertThat(ClientXdsClient.parseLocalityLbEndpoints(proto)).isNull(); + assertThat(XdsEndpointResource.parseLocalityLbEndpoints(proto)).isNull(); } @Test @@ -923,7 +993,7 @@ public void parseLocalityLbEndpoints_invalidPriority() { .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN) .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight .build(); - StructOrError struct = ClientXdsClient.parseLocalityLbEndpoints(proto); + StructOrError struct = XdsEndpointResource.parseLocalityLbEndpoints(proto); assertThat(struct.getErrorDetail()).isEqualTo("negative priority"); } @@ -933,7 +1003,7 @@ public void parseHttpFilter_unsupportedButOptional() { .setIsOptional(true) .setTypedConfig(Any.pack(StringValue.of("unsupported"))) .build(); - assertThat(ClientXdsClient.parseHttpFilter(httpFilter, filterRegistry, true)).isNull(); + assertThat(XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true)).isNull(); } private static class SimpleFilterConfig implements FilterConfig { @@ -994,7 +1064,7 @@ public void parseHttpFilter_typedStructMigration() { .setTypeUrl("test-url") .setValue(rawStruct) .build())).build(); - FilterConfig config = ClientXdsClient.parseHttpFilter(httpFilter, filterRegistry, + FilterConfig config = XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true).getStruct(); assertThat(((SimpleFilterConfig)config).getConfig()).isEqualTo(rawStruct); @@ -1005,7 +1075,7 @@ public void parseHttpFilter_typedStructMigration() { .setTypeUrl("test-url") .setValue(rawStruct) .build())).build(); - config = ClientXdsClient.parseHttpFilter(httpFilterNewTypeStruct, filterRegistry, + config = XdsListenerResource.parseHttpFilter(httpFilterNewTypeStruct, filterRegistry, true).getStruct(); assertThat(((SimpleFilterConfig)config).getConfig()).isEqualTo(rawStruct); } @@ -1031,8 +1101,8 @@ public void parseOverrideHttpFilter_typedStructMigration() { .setValue(rawStruct1) .build()) ); - Map map = ClientXdsClient.parseOverrideFilterConfigs(rawFilterMap, - filterRegistry).getStruct(); + Map map = XdsRouteConfigureResource.parseOverrideFilterConfigs( + rawFilterMap, filterRegistry).getStruct(); assertThat(((SimpleFilterConfig)map.get("struct-0")).getConfig()).isEqualTo(rawStruct0); assertThat(((SimpleFilterConfig)map.get("struct-1")).getConfig()).isEqualTo(rawStruct1); } @@ -1044,7 +1114,7 @@ public void parseHttpFilter_unsupportedAndRequired() { .setName("unsupported.filter") .setTypedConfig(Any.pack(StringValue.of("string value"))) .build(); - assertThat(ClientXdsClient.parseHttpFilter(httpFilter, filterRegistry, true) + assertThat(XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported " @@ -1060,7 +1130,7 @@ public void parseHttpFilter_routerFilterForClient() { .setName("envoy.router") .setTypedConfig(Any.pack(Router.getDefaultInstance())) .build(); - FilterConfig config = ClientXdsClient.parseHttpFilter( + FilterConfig config = XdsListenerResource.parseHttpFilter( httpFilter, filterRegistry, true /* isForClient */).getStruct(); assertThat(config.typeUrl()).isEqualTo(RouterFilter.TYPE_URL); } @@ -1074,7 +1144,7 @@ public void parseHttpFilter_routerFilterForServer() { .setName("envoy.router") .setTypedConfig(Any.pack(Router.getDefaultInstance())) .build(); - FilterConfig config = ClientXdsClient.parseHttpFilter( + FilterConfig config = XdsListenerResource.parseHttpFilter( httpFilter, filterRegistry, false /* isForClient */).getStruct(); assertThat(config.typeUrl()).isEqualTo(RouterFilter.TYPE_URL); } @@ -1101,7 +1171,7 @@ public void parseHttpFilter_faultConfigForClient() { .setDenominator(DenominatorType.HUNDRED))) .build())) .build(); - FilterConfig config = ClientXdsClient.parseHttpFilter( + FilterConfig config = XdsListenerResource.parseHttpFilter( httpFilter, filterRegistry, true /* isForClient */).getStruct(); assertThat(config).isInstanceOf(FaultConfig.class); } @@ -1129,7 +1199,7 @@ public void parseHttpFilter_faultConfigUnsupportedForServer() { .build())) .build(); StructOrError config = - ClientXdsClient.parseHttpFilter(httpFilter, filterRegistry, false /* isForClient */); + XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, false /* isForClient */); assertThat(config.getErrorDetail()).isEqualTo( "HttpFilter [envoy.fault](" + FaultFilter.TYPE_URL + ") is required but " + "unsupported for server"); @@ -1157,7 +1227,7 @@ public void parseHttpFilter_rbacConfigForServer() { .build()) .build())) .build(); - FilterConfig config = ClientXdsClient.parseHttpFilter( + FilterConfig config = XdsListenerResource.parseHttpFilter( httpFilter, filterRegistry, false /* isForClient */).getStruct(); assertThat(config).isInstanceOf(RbacConfig.class); } @@ -1185,7 +1255,7 @@ public void parseHttpFilter_rbacConfigUnsupportedForClient() { .build())) .build(); StructOrError config = - ClientXdsClient.parseHttpFilter(httpFilter, filterRegistry, true /* isForClient */); + XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true /* isForClient */); assertThat(config.getErrorDetail()).isEqualTo( "HttpFilter [envoy.auth](" + RbacFilter.TYPE_URL + ") is required but " + "unsupported for client"); @@ -1210,7 +1280,8 @@ public void parseOverrideRbacFilterConfig() { .build(); Map configOverrides = ImmutableMap.of("envoy.auth", Any.pack(rbacPerRoute)); Map parsedConfigs = - ClientXdsClient.parseOverrideFilterConfigs(configOverrides, filterRegistry).getStruct(); + XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + .getStruct(); assertThat(parsedConfigs).hasSize(1); assertThat(parsedConfigs).containsKey("envoy.auth"); assertThat(parsedConfigs.get("envoy.auth")).isInstanceOf(RbacConfig.class); @@ -1230,7 +1301,8 @@ public void parseOverrideFilterConfigs_unsupportedButOptional() { .setIsOptional(true).setConfig(Any.pack(StringValue.of("string value"))) .build())); Map parsedConfigs = - ClientXdsClient.parseOverrideFilterConfigs(configOverrides, filterRegistry).getStruct(); + XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + .getStruct(); assertThat(parsedConfigs).hasSize(1); assertThat(parsedConfigs).containsKey("envoy.fault"); } @@ -1248,7 +1320,7 @@ public void parseOverrideFilterConfigs_unsupportedAndRequired() { Any.pack(io.envoyproxy.envoy.config.route.v3.FilterConfig.newBuilder() .setIsOptional(false).setConfig(Any.pack(StringValue.of("string value"))) .build())); - assertThat(ClientXdsClient.parseOverrideFilterConfigs(configOverrides, filterRegistry) + assertThat(XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported"); @@ -1258,7 +1330,7 @@ public void parseOverrideFilterConfigs_unsupportedAndRequired() { Any.pack(httpFault), "unsupported.filter", Any.pack(StringValue.of("string value"))); - assertThat(ClientXdsClient.parseOverrideFilterConfigs(configOverrides, filterRegistry) + assertThat(XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported"); @@ -1271,8 +1343,8 @@ public void parseHttpConnectionManager_xffNumTrustedHopsUnsupported() HttpConnectionManager hcm = HttpConnectionManager.newBuilder().setXffNumTrustedHops(2).build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager with xff_num_trusted_hops unsupported"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* does not matter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* does not matter */, true /* does not matter */); } @@ -1285,8 +1357,8 @@ public void parseHttpConnectionManager_OriginalIpDetectionExtensionsMustEmpty() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager with original_ip_detection_extensions unsupported"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* does not matter */, false); + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* does not matter */, false); } @Test @@ -1300,8 +1372,8 @@ public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager neither has inlined route_config nor RDS"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* does not matter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* does not matter */, true /* does not matter */); } @@ -1319,8 +1391,8 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, true /* parseHttpFilter */, true /* does not matter */); } @@ -1337,8 +1409,8 @@ public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, true /* parseHttpFilter */, true /* does not matter */); } @@ -1355,8 +1427,8 @@ public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("A terminal HttpFilter must be the last filter: terminal"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, true /* parseHttpFilter */, true); } @@ -1371,8 +1443,8 @@ public void parseHttpConnectionManager_unknownFilters() throws ResourceInvalidEx .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, true /* parseHttpFilter */, true /* does not matter */); } @@ -1383,14 +1455,14 @@ public void parseHttpConnectionManager_emptyFilters() throws ResourceInvalidExce .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Missing HttpFilter in HttpConnectionManager."); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, true /* parseHttpFilter */, true /* does not matter */); } @Test public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception { - ClientXdsClient.enableRouteLookup = true; + XdsResourceType.enableRouteLookup = true; RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1429,8 +1501,8 @@ public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception .addRoutes(route))) .build(); - io.grpc.xds.HttpConnectionManager parsedHcm = ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* parseHttpFilter */, + io.grpc.xds.HttpConnectionManager parsedHcm = XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); VirtualHost virtualHost = Iterables.getOnlyElement(parsedHcm.virtualHosts()); @@ -1443,7 +1515,7 @@ public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception @Test public void parseHttpConnectionManager_duplicatePluginName() throws Exception { - ClientXdsClient.enableRouteLookup = true; + XdsResourceType.enableRouteLookup = true; RouteLookupConfig routeLookupConfig1 = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1506,14 +1578,14 @@ public void parseHttpConnectionManager_duplicatePluginName() throws Exception { thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Multiple ClusterSpecifierPlugins with the same name: rls-plugin-1"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); } @Test public void parseHttpConnectionManager_pluginNameNotFound() throws Exception { - ClientXdsClient.enableRouteLookup = true; + XdsResourceType.enableRouteLookup = true; RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1555,15 +1627,93 @@ public void parseHttpConnectionManager_pluginNameNotFound() throws Exception { thrown.expect(ResourceInvalidException.class); thrown.expectMessage("ClusterSpecifierPlugin for [invalid-plugin-name] not found"); - ClientXdsClient.parseHttpConnectionManager( - hcm, new HashSet(), filterRegistry, false /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); } + + @Test + public void parseHttpConnectionManager_optionalPlugin() throws ResourceInvalidException { + XdsResourceType.enableRouteLookup = true; + + // RLS Plugin, and a route to it. + RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() + .addGrpcKeybuilders( + GrpcKeyBuilder.newBuilder() + .addNames(Name.newBuilder().setService("service1")) + .addNames(Name.newBuilder().setService("service2")) + .addHeaders( + NameMatcher.newBuilder().setKey("key1").addNames("v1").setRequiredMatch(true))) + .setLookupService("rls-cbt.googleapis.com") + .setLookupServiceTimeout(Durations.fromMillis(1234)) + .setCacheSizeBytes(5000) + .addValidTargets("valid-target") + .build(); + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin rlsPlugin = + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin.newBuilder() + .setExtension( + TypedExtensionConfig.newBuilder() + .setName("rls-plugin-1") + .setTypedConfig(Any.pack( + RouteLookupClusterSpecifier.newBuilder() + .setRouteLookupConfig(routeLookupConfig) + .build()))) + .build(); + io.envoyproxy.envoy.config.route.v3.Route rlsRoute = + io.envoyproxy.envoy.config.route.v3.Route.newBuilder() + .setName("rls-route-1") + .setMatch(io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPrefix("")) + .setRoute(io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin("rls-plugin-1")) + .build(); + + // Unknown optional plugin, and a route to it. + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin optionalPlugin = + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin.newBuilder() + .setIsOptional(true) + .setExtension( + TypedExtensionConfig.newBuilder() + .setName("optional-plugin-1") + .setTypedConfig(Any.pack(StringValue.of("unregistered"))) + .build()) + .build(); + io.envoyproxy.envoy.config.route.v3.Route optionalRoute = + io.envoyproxy.envoy.config.route.v3.Route.newBuilder() + .setName("optional-route-1") + .setMatch(io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder().setPrefix("")) + .setRoute(io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin("optional-plugin-1")) + .build(); + + + // Build and parse the route. + RouteConfiguration routeConfig = RouteConfiguration.newBuilder() + .addClusterSpecifierPlugins(rlsPlugin) + .addClusterSpecifierPlugins(optionalPlugin) + .addVirtualHosts( + io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() + .setName("virtual-host-1") + .addRoutes(rlsRoute) + .addRoutes(optionalRoute)) + .build(); + io.grpc.xds.HttpConnectionManager parsedHcm = XdsListenerResource.parseHttpConnectionManager( + HttpConnectionManager.newBuilder().setRouteConfig(routeConfig).build(), filterRegistry, + false /* parseHttpFilter */, true /* does not matter */); + + // Verify that the only route left is the one with the registered RLS plugin `rls-plugin-1`, + // while the route with unregistered optional `optional-plugin-`1 has been skipped. + VirtualHost virtualHost = Iterables.getOnlyElement(parsedHcm.virtualHosts()); + Route parsedRoute = Iterables.getOnlyElement(virtualHost.routes()); + NamedPluginConfig namedPluginConfig = + parsedRoute.routeAction().namedClusterSpecifierPluginConfig(); + assertThat(namedPluginConfig.name()).isEqualTo("rls-plugin-1"); + assertThat(namedPluginConfig.config()).isInstanceOf(RlsPluginConfig.class); + } + @Test public void parseHttpConnectionManager_validateRdsConfigSource() throws Exception { - ClientXdsClient.enableRouteLookup = true; - Set rdsResources = new HashSet<>(); + XdsResourceType.enableRouteLookup = true; HttpConnectionManager hcm1 = HttpConnectionManager.newBuilder() @@ -1572,8 +1722,8 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .setConfigSource( ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance()))) .build(); - ClientXdsClient.parseHttpConnectionManager( - hcm1, rdsResources, filterRegistry, false /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm1, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); HttpConnectionManager hcm2 = @@ -1583,8 +1733,8 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .setConfigSource( ConfigSource.newBuilder().setSelf(SelfConfigSource.getDefaultInstance()))) .build(); - ClientXdsClient.parseHttpConnectionManager( - hcm2, rdsResources, filterRegistry, false /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm2, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); HttpConnectionManager hcm3 = @@ -1592,13 +1742,14 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .setRds(Rds.newBuilder() .setRouteConfigName("rds-config-foo") .setConfigSource( - ConfigSource.newBuilder().setPath("foo-path"))) + ConfigSource.newBuilder() + .setPathConfigSource(PathConfigSource.newBuilder().setPath("foo-path")))) .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "HttpConnectionManager contains invalid RDS: must specify ADS or self ConfigSource"); - ClientXdsClient.parseHttpConnectionManager( - hcm3, rdsResources, filterRegistry, false /* parseHttpFilter */, + XdsListenerResource.parseHttpConnectionManager( + hcm3, filterRegistry, false /* parseHttpFilter */, true /* does not matter */); } @@ -1636,7 +1787,8 @@ public ConfigOrError parsePlugin(Message rawProtoMessage .setTypedConfig(Any.pack(typedStruct))) .build(); - PluginConfig pluginConfig = ClientXdsClient.parseClusterSpecifierPlugin(pluginProto, registry); + PluginConfig pluginConfig = XdsRouteConfigureResource + .parseClusterSpecifierPlugin(pluginProto, registry); assertThat(pluginConfig).isInstanceOf(TestPluginConfig.class); } @@ -1674,7 +1826,8 @@ public ConfigOrError parsePlugin(Message rawProtoMessage .setTypedConfig(Any.pack(typedStruct))) .build(); - PluginConfig pluginConfig = ClientXdsClient.parseClusterSpecifierPlugin(pluginProto, registry); + PluginConfig pluginConfig = XdsRouteConfigureResource + .parseClusterSpecifierPlugin(pluginProto, registry); assertThat(pluginConfig).isInstanceOf(TestPluginConfig.class); } @@ -1691,79 +1844,83 @@ public void parseClusterSpecifierPlugin_unregisteredPlugin() throws Exception { thrown.expectMessage( "Unsupported ClusterSpecifierPlugin type: type.googleapis.com/google.protobuf.StringValue"); - ClientXdsClient.parseClusterSpecifierPlugin(pluginProto, registry); + XdsRouteConfigureResource.parseClusterSpecifierPlugin(pluginProto, registry); } @Test - public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInvalidException { - Cluster cluster = Cluster.newBuilder() - .setName("cluster-foo.googleapis.com") - .setType(DiscoveryType.EDS) - .setEdsClusterConfig( - EdsClusterConfig.newBuilder() - .setEdsConfig( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.getDefaultInstance())) - .setServiceName("service-foo.googleapis.com")) - .setLbPolicy(LbPolicy.RING_HASH) - .build(); + public void parseClusterSpecifierPlugin_unregisteredPlugin_optional() + throws ResourceInvalidException { + ClusterSpecifierPluginRegistry registry = ClusterSpecifierPluginRegistry.newRegistry(); + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin pluginProto = + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin.newBuilder() + .setExtension(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(StringValue.of("unregistered")))) + .setIsOptional(true) + .build(); - CdsUpdate update = ClientXdsClient.processCluster( - cluster, new HashSet(), null, LRS_SERVER_INFO); - assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.RING_HASH); - assertThat(update.minRingSize()) - .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE); - assertThat(update.maxRingSize()) - .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE); + PluginConfig pluginConfig = XdsRouteConfigureResource + .parseClusterSpecifierPlugin(pluginProto, registry); + assertThat(pluginConfig).isNull(); } @Test - public void parseCluster_leastRequestLbPolicy_defaultLbConfig() throws ResourceInvalidException { - ClientXdsClient.enableLeastRequest = true; - Cluster cluster = Cluster.newBuilder() - .setName("cluster-foo.googleapis.com") - .setType(DiscoveryType.EDS) - .setEdsClusterConfig( - EdsClusterConfig.newBuilder() - .setEdsConfig( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.getDefaultInstance())) - .setServiceName("service-foo.googleapis.com")) - .setLbPolicy(LbPolicy.LEAST_REQUEST) + public void parseClusterSpecifierPlugin_brokenPlugin() { + ClusterSpecifierPluginRegistry registry = ClusterSpecifierPluginRegistry.newRegistry(); + + Any failingAny = Any.newBuilder() + .setTypeUrl("type.googleapis.com/xds.type.v3.TypedStruct") + .setValue(ByteString.copyFromUtf8("fail")) .build(); - CdsUpdate update = ClientXdsClient.processCluster( - cluster, new HashSet(), null, LRS_SERVER_INFO); - assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.LEAST_REQUEST); - assertThat(update.choiceCount()) - .isEqualTo(ClientXdsClient.DEFAULT_LEAST_REQUEST_CHOICE_COUNT); + TypedExtensionConfig brokenPlugin = TypedExtensionConfig.newBuilder() + .setName("bad-plugin-1") + .setTypedConfig(failingAny) + .build(); + + try { + XdsRouteConfigureResource.parseClusterSpecifierPlugin( + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin.newBuilder() + .setExtension(brokenPlugin) + .build(), + registry); + fail("Expected ResourceInvalidException"); + } catch (ResourceInvalidException e) { + assertThat(e).hasMessageThat() + .startsWith("ClusterSpecifierPlugin [bad-plugin-1] contains invalid proto"); + } } @Test - public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { - Cluster cluster = Cluster.newBuilder() - .setName("cluster-foo.googleapis.com") - .setType(DiscoveryType.EDS) - .setEdsClusterConfig( - EdsClusterConfig.newBuilder() - .setEdsConfig( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.getDefaultInstance())) - .setServiceName("service-foo.googleapis.com")) - .setLbPolicy(LbPolicy.ROUND_ROBIN) - .addTransportSocketMatches( - Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) + public void parseClusterSpecifierPlugin_brokenPlugin_optional() { + ClusterSpecifierPluginRegistry registry = ClusterSpecifierPluginRegistry.newRegistry(); + + Any failingAny = Any.newBuilder() + .setTypeUrl("type.googleapis.com/xds.type.v3.TypedStruct") + .setValue(ByteString.copyFromUtf8("fail")) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); - ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); + TypedExtensionConfig brokenPlugin = TypedExtensionConfig.newBuilder() + .setName("bad-plugin-1") + .setTypedConfig(failingAny) + .build(); + + // Despite being optional, still should fail. + try { + XdsRouteConfigureResource.parseClusterSpecifierPlugin( + io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin.newBuilder() + .setIsOptional(true) + .setExtension(brokenPlugin) + .build(), + registry); + fail("Expected ResourceInvalidException"); + } catch (ResourceInvalidException e) { + assertThat(e).hasMessageThat() + .startsWith("ClusterSpecifierPlugin [bad-plugin-1] contains invalid proto"); + } } @Test - public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMax() - throws ResourceInvalidException { + public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") .setType(DiscoveryType.EDS) @@ -1774,21 +1931,18 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMa .setAds(AggregatedConfigSource.getDefaultInstance())) .setServiceName("service-foo.googleapis.com")) .setLbPolicy(LbPolicy.RING_HASH) - .setRingHashLbConfig( - RingHashLbConfig.newBuilder() - .setHashFunction(HashFunction.XX_HASH) - .setMinimumRingSize(UInt64Value.newBuilder().setValue(1000L)) - .setMaximumRingSize(UInt64Value.newBuilder().setValue(100L))) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(update.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("ring_hash_experimental"); } @Test - public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize() - throws ResourceInvalidException { + public void parseCluster_leastRequestLbPolicy_defaultLbConfig() throws ResourceInvalidException { + XdsResourceType.enableLeastRequest = true; Cluster cluster = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") .setType(DiscoveryType.EDS) @@ -1798,25 +1952,21 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance())) .setServiceName("service-foo.googleapis.com")) - .setLbPolicy(LbPolicy.RING_HASH) - .setRingHashLbConfig( - RingHashLbConfig.newBuilder() - .setHashFunction(HashFunction.XX_HASH) - .setMinimumRingSize(UInt64Value.newBuilder().setValue(1000L)) - .setMaximumRingSize( - UInt64Value.newBuilder() - .setValue(ClientXdsClient.MAX_RING_HASH_LB_POLICY_RING_SIZE + 1))) + .setLbPolicy(LbPolicy.LEAST_REQUEST) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(update.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental"); } @Test - public void parseCluster_leastRequestLbPolicy_invalidChoiceCountConfig_tooSmallChoiceCount() - throws ResourceInvalidException { - ClientXdsClient.enableLeastRequest = true; + public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") .setType(DiscoveryType.EDS) @@ -1826,21 +1976,20 @@ public void parseCluster_leastRequestLbPolicy_invalidChoiceCountConfig_tooSmallC ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance())) .setServiceName("service-foo.googleapis.com")) - .setLbPolicy(LbPolicy.LEAST_REQUEST) - .setLeastRequestLbConfig( - LeastRequestLbConfig.newBuilder() - .setChoiceCount(UInt32Value.newBuilder().setValue(1)) - ) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .addTransportSocketMatches( + Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid least_request_lb_config"); - ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); + thrown.expectMessage( + "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); + XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); } @Test public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidException { - Set retainedEdsResources = new HashSet<>(); Cluster cluster1 = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") .setType(DiscoveryType.EDS) @@ -1852,7 +2001,8 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti .setServiceName("service-foo.googleapis.com")) .setLbPolicy(LbPolicy.ROUND_ROBIN) .build(); - ClientXdsClient.processCluster(cluster1, retainedEdsResources, null, LRS_SERVER_INFO); + XdsClusterResource.processCluster(cluster1, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); Cluster cluster2 = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") @@ -1865,7 +2015,8 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti .setServiceName("service-foo.googleapis.com")) .setLbPolicy(LbPolicy.ROUND_ROBIN) .build(); - ClientXdsClient.processCluster(cluster2, retainedEdsResources, null, LRS_SERVER_INFO); + XdsClusterResource.processCluster(cluster2, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); Cluster cluster3 = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") @@ -1874,7 +2025,7 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti EdsClusterConfig.newBuilder() .setEdsConfig( ConfigSource.newBuilder() - .setPath("foo-path")) + .setPathConfigSource(PathConfigSource.newBuilder().setPath("foo-path"))) .setServiceName("service-foo.googleapis.com")) .setLbPolicy(LbPolicy.ROUND_ROBIN) .build(); @@ -1883,7 +2034,8 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti thrown.expectMessage( "Cluster cluster-foo.googleapis.com: field eds_cluster_config must be set to indicate to" + " use EDS over ADS or self ConfigSource"); - ClientXdsClient.processCluster(cluster3, retainedEdsResources, null, LRS_SERVER_INFO); + XdsClusterResource.processCluster(cluster3, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); } @Test @@ -1895,8 +2047,18 @@ public void parseServerSideListener_invalidTrafficDirection() throws ResourceInv .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 with invalid traffic direction: OUTBOUND"); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, true /* does not matter */); + } + + @Test + public void parseServerSideListener_noTrafficDirection() throws ResourceInvalidException { + Listener listener = + Listener.newBuilder() + .setName("listener1") + .build(); + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, true /* does not matter */); } @Test @@ -1909,8 +2071,8 @@ public void parseServerSideListener_listenerFiltersPresent() throws ResourceInva .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have listener_filters"); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, true /* does not matter */); } @Test @@ -1923,8 +2085,8 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have use_original_dst set to true"); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener,null, filterRegistry, null, true /* does not matter */); } @Test @@ -1972,8 +2134,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("FilterChainMatch must be unique. Found duplicate:"); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, true /* does not matter */); } @Test @@ -2021,8 +2183,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("FilterChainMatch must be unique. Found duplicate:"); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener,null, filterRegistry, null, true /* does not matter */); } @Test @@ -2070,8 +2232,8 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .setTrafficDirection(TrafficDirection.INBOUND) .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); - ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, true /* does not matter */); } @Test @@ -2085,8 +2247,8 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, null, + XdsListenerResource.parseFilterChain( + filterChain, null, filterRegistry, null, null, true /* does not matter */); } @@ -2104,8 +2266,8 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, null, + XdsListenerResource.parseFilterChain( + filterChain, null, filterRegistry, null, null, true /* does not matter */); } @@ -2123,8 +2285,8 @@ public void parseFilterChain_filterMissingTypedConfig() throws ResourceInvalidEx thrown.expectMessage( "FilterChain filter-chain-foo contains filter envoy.http_connection_manager " + "without typed_config"); - ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, null, + XdsListenerResource.parseFilterChain( + filterChain, null, filterRegistry, null, null, true /* does not matter */); } @@ -2146,8 +2308,8 @@ public void parseFilterChain_unsupportedFilter() throws ResourceInvalidException thrown.expectMessage( "FilterChain filter-chain-foo contains filter unsupported with unsupported " + "typed_config type unsupported-type-url"); - ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, null, + XdsListenerResource.parseFilterChain( + filterChain, null, filterRegistry, null, null, true /* does not matter */); } @@ -2174,11 +2336,11 @@ public void parseFilterChain_noName() throws ResourceInvalidException { .build())) .build(); - EnvoyServerProtoData.FilterChain parsedFilterChain1 = ClientXdsClient.parseFilterChain( - filterChain1, new HashSet(), null, filterRegistry, null, + EnvoyServerProtoData.FilterChain parsedFilterChain1 = XdsListenerResource.parseFilterChain( + filterChain1, null, filterRegistry, null, null, true /* does not matter */); - EnvoyServerProtoData.FilterChain parsedFilterChain2 = ClientXdsClient.parseFilterChain( - filterChain2, new HashSet(), null, filterRegistry, null, + EnvoyServerProtoData.FilterChain parsedFilterChain2 = XdsListenerResource.parseFilterChain( + filterChain2, null, filterRegistry, null, null, true /* does not matter */); assertThat(parsedFilterChain1.name()).isEqualTo(parsedFilterChain2.name()); } @@ -2190,7 +2352,7 @@ public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with tls_params is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2200,7 +2362,7 @@ public void validateCommonTlsContext_customHandshaker() throws ResourceInvalidEx .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with custom_handshaker is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2210,7 +2372,7 @@ public void validateCommonTlsContext_validationContext() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2222,7 +2384,7 @@ public void validateCommonTlsContext_validationContextSdsSecretConfig() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_sds_secret_config is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2236,7 +2398,7 @@ public void validateCommonTlsContext_validationContextCertificateProvider() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_certificate_provider is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2251,7 +2413,7 @@ public void validateCommonTlsContext_validationContextCertificateProviderInstanc thrown.expectMessage( "common-tls-context with validation_context_certificate_provider_instance is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2262,7 +2424,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredFo thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "tls_certificate_provider_instance is required in downstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, true); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, true); } @Test @@ -2273,7 +2435,7 @@ public void validateCommonTlsContext_tlsNewCertificateProviderInstance() .setTlsCertificateProviderInstance( CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) .build(); - ClientXdsClient + XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @@ -2285,7 +2447,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance() .setTlsCertificateCertificateProviderInstance( CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) .build(); - ClientXdsClient + XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @@ -2300,7 +2462,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBoot thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "CertificateProvider instance name 'bad-name' not defined in the bootstrap file."); - ClientXdsClient + XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @@ -2315,7 +2477,7 @@ public void validateCommonTlsContext_validationContextProviderInstance() CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) .build()) .build(); - ClientXdsClient + XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } @@ -2333,7 +2495,7 @@ public void validateCommonTlsContext_validationContextProviderInstance_absentInB thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); - ClientXdsClient + XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } @@ -2345,7 +2507,7 @@ public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInval .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("tls_certificate_provider_instance is unset"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2357,7 +2519,7 @@ public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "tls_certificate_provider_instance is unset"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2371,7 +2533,7 @@ public void validateCommonTlsContext_tlsCertificateCertificateProvider() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "tls_certificate_provider_instance is unset"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2381,7 +2543,7 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2394,7 +2556,7 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "ca_certificate_provider_instance is required in upstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2414,7 +2576,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("match_subject_alt_names only allowed in upstream_tls_context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); } @Test @@ -2434,7 +2596,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_spki in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -2454,7 +2616,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_hash in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -2475,7 +2637,7 @@ public void validateCommonTlsContext_combinedValContextDfltValContextRequireSign thrown.expectMessage( "require_signed_certificate_timestamp in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -2494,7 +2656,7 @@ public void validateCommonTlsContext_combinedValidationContextWithDefaultValidat .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("crl in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -2514,7 +2676,7 @@ public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomV thrown.expect(ResourceInvalidException.class); thrown.expectMessage("custom_validator_config in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -2522,7 +2684,7 @@ public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInv DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in downstream-tls-context"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, null); + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, null); } @Test @@ -2542,7 +2704,7 @@ public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("downstream-tls-context with require-sni is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test @@ -2563,7 +2725,7 @@ public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceIn thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "downstream-tls-context with ocsp_staple_policy value STRICT_STAPLING is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test @@ -2571,27 +2733,29 @@ public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInval UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in upstream-tls-context"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext, null); + XdsClusterResource.validateUpstreamTlsContext(upstreamTlsContext, null); } @Test public void validateResourceName() { String traditionalResource = "cluster1.google.com"; - assertThat(XdsClient.isResourceNameValid(traditionalResource, ResourceType.CDS.typeUrl())) - .isTrue(); - assertThat(XdsClient.isResourceNameValid(traditionalResource, ResourceType.RDS.typeUrlV2())) + assertThat(XdsClient.isResourceNameValid(traditionalResource, + XdsClusterResource.getInstance().typeUrl())) .isTrue(); String invalidPath = "xdstp:/abc/efg"; - assertThat(XdsClient.isResourceNameValid(invalidPath, ResourceType.CDS.typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(invalidPath, + XdsClusterResource.getInstance().typeUrl())).isFalse(); String invalidPath2 = "xdstp:///envoy.config.route.v3.RouteConfiguration"; - assertThat(XdsClient.isResourceNameValid(invalidPath2, ResourceType.RDS.typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(invalidPath2, + XdsRouteConfigureResource.getInstance().typeUrl())).isFalse(); String typeMatch = "xdstp:///envoy.config.route.v3.RouteConfiguration/foo/route1"; - assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.LDS.typeUrl())).isFalse(); - assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.RDS.typeUrl())).isTrue(); - assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.RDS.typeUrlV2())).isFalse(); + assertThat(XdsClient.isResourceNameValid(typeMatch, + XdsListenerResource.getInstance().typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(typeMatch, + XdsRouteConfigureResource.getInstance().typeUrl())).isTrue(); } @Test diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java similarity index 59% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java index ba182be76d1..0f18d3d387f 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java @@ -18,11 +18,9 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; +import static io.grpc.xds.XdsClientImpl.XdsChannelFactory.DEFAULT_XDS_CHANNEL_FACTORY; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -35,19 +33,29 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.Any; +import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; import io.envoyproxy.envoy.config.route.v3.FilterConfig; +import io.envoyproxy.envoy.config.route.v3.WeightedCluster; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; +import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc; import io.grpc.BindableService; import io.grpc.ChannelCredentials; import io.grpc.Context; import io.grpc.Context.CancellableContext; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; +import io.grpc.Server; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.inprocess.InProcessChannelBuilder; @@ -56,35 +64,38 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.FakeClock.TaskFilter; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.internal.TimeProvider; +import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.AuthorityInfo; +import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.ClientXdsClient.XdsChannelFactory; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.EnvoyProtoData.Node; +import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; import io.grpc.xds.FaultConfig.FractionalPercent.DenominatorType; import io.grpc.xds.LoadStatsManager2.ClusterDropStats; -import io.grpc.xds.XdsClient.CdsResourceWatcher; -import io.grpc.xds.XdsClient.CdsUpdate; -import io.grpc.xds.XdsClient.CdsUpdate.ClusterType; -import io.grpc.xds.XdsClient.CdsUpdate.LbPolicy; -import io.grpc.xds.XdsClient.EdsResourceWatcher; -import io.grpc.xds.XdsClient.EdsUpdate; -import io.grpc.xds.XdsClient.LdsResourceWatcher; -import io.grpc.xds.XdsClient.LdsUpdate; -import io.grpc.xds.XdsClient.RdsResourceWatcher; -import io.grpc.xds.XdsClient.RdsUpdate; import io.grpc.xds.XdsClient.ResourceMetadata; import io.grpc.xds.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.XdsClient.ResourceMetadata.UpdateFailureState; +import io.grpc.xds.XdsClient.ResourceUpdate; import io.grpc.xds.XdsClient.ResourceWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsClientImpl.XdsChannelFactory; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.io.IOException; import java.util.ArrayDeque; import java.util.Arrays; @@ -92,6 +103,8 @@ import java.util.List; import java.util.Map; import java.util.Queue; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; @@ -103,6 +116,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -110,10 +124,12 @@ import org.mockito.MockitoAnnotations; /** - * Tests for {@link ClientXdsClient}. + * Tests for {@link XdsClientImpl}. */ @RunWith(JUnit4.class) -public abstract class ClientXdsClientTestBase { +// The base class was used to test both xds v2 and v3. V2 is dropped now so the base class is not +// necessary. Still keep it for future version usage. Remove if too much trouble to maintain. +public abstract class XdsClientImplTestBase { private static final String SERVER_URI = "trafficdirector.googleapis.com"; private static final String SERVER_URI_CUSTOME_AUTHORITY = "trafficdirector2.googleapis.com"; private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; @@ -126,11 +142,17 @@ public abstract class ClientXdsClientTestBase { private static final String VERSION_1 = "42"; private static final String VERSION_2 = "43"; private static final String VERSION_3 = "44"; - private static final Node NODE = Node.newBuilder().build(); + private static final String NODE_ID = "cool-node-id"; + private static final Node NODE = Node.newBuilder().setId(NODE_ID).build(); private static final Any FAILING_ANY = MessageFactory.FAILING_ANY; private static final ChannelCredentials CHANNEL_CREDENTIALS = InsecureChannelCredentials.create(); - private final ServerInfo lrsServerInfo = - ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, useProtocolV3()); + private static final XdsResourceType LDS = XdsListenerResource.getInstance(); + private static final XdsResourceType CDS = XdsClusterResource.getInstance(); + private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); + private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); + + // xDS control plane server info. + private ServerInfo xdsServerInfo; private static final FakeClock.TaskFilter RPC_RETRY_TASK_FILTER = new FakeClock.TaskFilter() { @@ -176,8 +198,12 @@ public boolean shouldAccept(Runnable command) { public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private final FakeClock fakeClock = new FakeClock(); - protected final Queue resourceDiscoveryCalls = new ArrayDeque<>(); + protected final BlockingDeque resourceDiscoveryCalls = + new LinkedBlockingDeque<>(1); + protected final BlockingDeque resourceDiscoveryCallsV2 = + new LinkedBlockingDeque<>(1); protected final Queue loadReportCalls = new ArrayDeque<>(); + protected final Queue loadReportCallsV2 = new ArrayDeque<>(); protected final AtomicBoolean adsEnded = new AtomicBoolean(true); protected final AtomicBoolean lrsEnded = new AtomicBoolean(true); private final MessageFactory mf = createMessageFactory(); @@ -206,7 +232,7 @@ public long currentTimeNanos() { // CDS test resources. private final Any testClusterRoundRobin = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, - null, false, null, "envoy.transport_sockets.tls", null + null, false, null, "envoy.transport_sockets.tls", null, null )); // EDS test resources. @@ -235,6 +261,7 @@ public long currentTimeNanos() { private ArgumentCaptor edsUpdateCaptor; @Captor private ArgumentCaptor errorCaptor; + @Mock private BackoffPolicy.Provider backoffPolicyProvider; @Mock @@ -242,23 +269,28 @@ public long currentTimeNanos() { @Mock private BackoffPolicy backoffPolicy2; @Mock - private LdsResourceWatcher ldsResourceWatcher; + private ResourceWatcher ldsResourceWatcher; @Mock - private RdsResourceWatcher rdsResourceWatcher; + private ResourceWatcher rdsResourceWatcher; @Mock - private CdsResourceWatcher cdsResourceWatcher; + private ResourceWatcher cdsResourceWatcher; @Mock - private EdsResourceWatcher edsResourceWatcher; + private ResourceWatcher edsResourceWatcher; @Mock private TlsContextManager tlsContextManager; private ManagedChannel channel; private ManagedChannel channelForCustomAuthority; private ManagedChannel channelForEmptyAuthority; - private ClientXdsClient xdsClient; + private XdsClientImpl xdsClient; private boolean originalEnableFaultInjection; private boolean originalEnableRbac; private boolean originalEnableLeastRequest; + private boolean originalEnableFederation; + private Server xdsServer; + private final String serverName = InProcessServerBuilder.generateName(); + private BindableService adsService = createAdsService(); + private BindableService lrsService = createLrsService(); @Before public void setUp() throws IOException { @@ -269,21 +301,22 @@ public void setUp() throws IOException { when(backoffPolicy2.nextBackoffNanos()).thenReturn(20L, 200L); // Start the server and the client. - originalEnableFaultInjection = ClientXdsClient.enableFaultInjection; - ClientXdsClient.enableFaultInjection = true; - originalEnableRbac = ClientXdsClient.enableRbac; + originalEnableFaultInjection = XdsResourceType.enableFaultInjection; + XdsResourceType.enableFaultInjection = true; + originalEnableRbac = XdsResourceType.enableRbac; assertThat(originalEnableRbac).isTrue(); - originalEnableLeastRequest = ClientXdsClient.enableLeastRequest; - ClientXdsClient.enableLeastRequest = true; - final String serverName = InProcessServerBuilder.generateName(); - cleanupRule.register( - InProcessServerBuilder - .forName(serverName) - .addService(createAdsService()) - .addService(createLrsService()) - .directExecutor() - .build() - .start()); + originalEnableLeastRequest = XdsResourceType.enableLeastRequest; + XdsResourceType.enableLeastRequest = true; + originalEnableFederation = BootstrapperImpl.enableFederation; + xdsServer = cleanupRule.register(InProcessServerBuilder + .forName(serverName) + .addService(adsService) + .addService(createAdsServiceV2()) + .addService(lrsService) + .addService(createLrsServiceV2()) + .directExecutor() + .build() + .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); XdsChannelFactory xdsChannelFactory = new XdsChannelFactory() { @@ -310,27 +343,28 @@ ManagedChannel create(ServerInfo serverInfo) { } }; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, + ignoreResourceDeletion()); Bootstrapper.BootstrapInfo bootstrapInfo = Bootstrapper.BootstrapInfo.builder() - .servers(Arrays.asList( - Bootstrapper.ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, useProtocolV3()))) - .node(EnvoyProtoData.Node.newBuilder().build()) + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) .authorities(ImmutableMap.of( "authority.xds.com", AuthorityInfo.create( "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS, useProtocolV3()))), + SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), "", AuthorityInfo.create( "xdstp:///envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS, useProtocolV3()))))) + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) .certProviders(ImmutableMap.of("cert-instance-name", CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) .build(); xdsClient = - new ClientXdsClient( + new XdsClientImpl( xdsChannelFactory, bootstrapInfo, Context.ROOT, @@ -346,9 +380,10 @@ SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS, useProtocolV3()))))) @After public void tearDown() { - ClientXdsClient.enableFaultInjection = originalEnableFaultInjection; - ClientXdsClient.enableRbac = originalEnableRbac; - ClientXdsClient.enableLeastRequest = originalEnableLeastRequest; + XdsResourceType.enableFaultInjection = originalEnableFaultInjection; + XdsResourceType.enableRbac = originalEnableRbac; + XdsResourceType.enableLeastRequest = originalEnableLeastRequest; + BootstrapperImpl.enableFederation = originalEnableFederation; xdsClient.shutdown(); channel.shutdown(); // channel not owned by XdsClient assertThat(adsEnded.get()).isTrue(); @@ -358,6 +393,9 @@ public void tearDown() { protected abstract boolean useProtocolV3(); + /** Whether ignore_resource_deletion server feature is enabled for the given test. */ + protected abstract boolean ignoreResourceDeletion(); + protected abstract BindableService createAdsService(); protected abstract BindableService createLrsService(); @@ -383,15 +421,32 @@ protected static boolean matchErrorDetail( private void verifySubscribedResourcesMetadataSizes( int ldsSize, int cdsSize, int rdsSize, int edsSize) { - Map> subscribedResourcesMetadata = + Map, Map> subscribedResourcesMetadata = awaitSubscribedResourcesMetadata(); - assertThat(subscribedResourcesMetadata.get(LDS)).hasSize(ldsSize); - assertThat(subscribedResourcesMetadata.get(CDS)).hasSize(cdsSize); - assertThat(subscribedResourcesMetadata.get(RDS)).hasSize(rdsSize); - assertThat(subscribedResourcesMetadata.get(EDS)).hasSize(edsSize); + Map> subscribedTypeUrls = + xdsClient.getSubscribedResourceTypesWithTypeUrl(); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, LDS, ldsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, CDS, cdsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, RDS, rdsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, EDS, edsSize); + } + + private void verifyResourceCount( + Map> subscribedTypeUrls, + Map, Map> subscribedResourcesMetadata, + XdsResourceType type, + int size) { + if (size == 0) { + assertThat(subscribedTypeUrls.containsKey(type.typeUrl())).isFalse(); + assertThat(subscribedResourcesMetadata.containsKey(type)).isFalse(); + } else { + assertThat(subscribedTypeUrls.containsKey(type.typeUrl())).isTrue(); + assertThat(subscribedResourcesMetadata.get(type)).hasSize(size); + } } - private Map> awaitSubscribedResourcesMetadata() { + private Map, Map> + awaitSubscribedResourcesMetadata() { try { return xdsClient.getSubscribedResourcesMetadataSnapshot().get(20, TimeUnit.SECONDS); } catch (Exception e) { @@ -403,20 +458,20 @@ private Map> awaitSubscribedResource } /** Verify the resource requested, but not updated. */ - private void verifyResourceMetadataRequested(ResourceType type, String resourceName) { + private void verifyResourceMetadataRequested(XdsResourceType type, String resourceName) { verifyResourceMetadata( type, resourceName, null, ResourceMetadataStatus.REQUESTED, "", 0, false); } /** Verify that the requested resource does not exist. */ - private void verifyResourceMetadataDoesNotExist(ResourceType type, String resourceName) { + private void verifyResourceMetadataDoesNotExist(XdsResourceType type, String resourceName) { verifyResourceMetadata( type, resourceName, null, ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, false); } /** Verify the resource to be acked. */ private void verifyResourceMetadataAcked( - ResourceType type, String resourceName, Any rawResource, String versionInfo, + XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTimeNanos) { verifyResourceMetadata(type, resourceName, rawResource, ResourceMetadataStatus.ACKED, versionInfo, updateTimeNanos, false); @@ -427,7 +482,7 @@ private void verifyResourceMetadataAcked( * corresponding i-th element of {@code List failedDetails}. */ private void verifyResourceMetadataNacked( - ResourceType type, String resourceName, Any rawResource, String versionInfo, + XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTime, String failedVersion, long failedUpdateTimeNanos, List failedDetails) { ResourceMetadata resourceMetadata = @@ -449,7 +504,7 @@ private void verifyResourceMetadataNacked( } private ResourceMetadata verifyResourceMetadata( - ResourceType type, String resourceName, Any rawResource, ResourceMetadataStatus status, + XdsResourceType type, String resourceName, Any rawResource, ResourceMetadataStatus status, String versionInfo, long updateTimeNanos, boolean hasErrorState) { ResourceMetadata metadata = awaitSubscribedResourcesMetadata().get(type).get(resourceName); assertThat(metadata).isNotNull(); @@ -467,11 +522,83 @@ private ResourceMetadata verifyResourceMetadata( return metadata; } + private void verifyStatusWithNodeId(Status status, Code expectedCode, String expectedMsg) { + assertThat(status.getCode()).isEqualTo(expectedCode); + assertThat(status.getCause()).isNull(); + // Watcher.onError propagates status description to the channel, and we want to + // augment the description with the node id. + String description = (expectedMsg.isEmpty() ? "" : expectedMsg + " ") + "nodeID: " + NODE_ID; + assertThat(status.getDescription()).isEqualTo(description); + } + /** - * Helper method to validate {@link XdsClient.EdsUpdate} created for the test CDS resource - * {@link ClientXdsClientTestBase#testClusterLoadAssignment}. + * Verifies the LDS update against the golden Listener with vhosts {@link #testListenerVhosts}. */ - private void validateTestClusterLoadAssigment(EdsUpdate edsUpdate) { + private void verifyGoldenListenerVhosts(LdsUpdate ldsUpdate) { + assertThat(ldsUpdate.listener()).isNull(); + HttpConnectionManager hcm = ldsUpdate.httpConnectionManager(); + assertThat(hcm.rdsName()).isNull(); + assertThat(hcm.virtualHosts()).hasSize(VHOST_SIZE); + verifyGoldenHcm(hcm); + } + + /** + * Verifies the LDS update against the golden Listener with RDS name {@link #testListenerRds}. + */ + private void verifyGoldenListenerRds(LdsUpdate ldsUpdate) { + assertThat(ldsUpdate.listener()).isNull(); + HttpConnectionManager hcm = ldsUpdate.httpConnectionManager(); + assertThat(hcm.rdsName()).isEqualTo(RDS_RESOURCE); + assertThat(hcm.virtualHosts()).isNull(); + verifyGoldenHcm(hcm); + } + + private void verifyGoldenHcm(HttpConnectionManager hcm) { + if (useProtocolV3()) { + // The last configured filter has to be a terminal filter. + assertThat(hcm.httpFilterConfigs()).isNotNull(); + assertThat(hcm.httpFilterConfigs()).hasSize(1); + assertThat(hcm.httpFilterConfigs().get(0).name).isEqualTo("terminal"); + assertThat(hcm.httpFilterConfigs().get(0).filterConfig).isEqualTo(RouterFilter.ROUTER_CONFIG); + } else { + assertThat(hcm.httpFilterConfigs()).isNull(); + } + } + + /** + * Verifies the RDS update against the golden route config {@link #testRouteConfig}. + */ + private void verifyGoldenRouteConfig(RdsUpdate rdsUpdate) { + assertThat(rdsUpdate.virtualHosts).hasSize(VHOST_SIZE); + for (VirtualHost vhost : rdsUpdate.virtualHosts) { + assertThat(vhost.name()).contains("do not care"); + assertThat(vhost.domains()).hasSize(1); + assertThat(vhost.routes()).hasSize(1); + } + } + + /** + * Verifies the CDS update against the golden Round Robin Cluster {@link #testClusterRoundRobin}. + */ + private void verifyGoldenClusterRoundRobin(CdsUpdate cdsUpdate) { + assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate.edsServiceName()).isNull(); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); + assertThat(cdsUpdate.lrsServerInfo()).isNull(); + assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + } + + /** + * Verifies the EDS update against the golden Cluster with load assignment + * {@link #testClusterLoadAssignment}. + */ + private void validateGoldenClusterLoadAssignment(EdsUpdate edsUpdate) { assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); assertThat(edsUpdate.dropPolicies) .containsExactly( @@ -488,7 +615,8 @@ private void validateTestClusterLoadAssigment(EdsUpdate edsUpdate) { @Test public void ldsResourceNotFound() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); Any listener = Any.pack(mf.buildListenerWithApiListener("bar.googleapis.com", mf.buildRouteConfiguration("route-bar.googleapis.com", mf.buildOpaqueVirtualHosts(1)))); @@ -500,16 +628,36 @@ public void ldsResourceNotFound() { verifyResourceMetadataRequested(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } + @Test + public void ldsResourceUpdated_withXdstpResourceName_withUnknownAuthority() { + BootstrapperImpl.enableFederation = true; + String ldsResourceName = useProtocolV3() + ? "xdstp://unknown.example.com/envoy.config.listener.v3.Listener/listener1" + : "xdstp://unknown.example.com/envoy.api.v2.Listener/listener1"; + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceName, + ldsResourceWatcher); + verify(ldsResourceWatcher).onError(errorCaptor.capture()); + Status error = errorCaptor.getValue(); + assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(error.getDescription()).isEqualTo( + "Wrong configuration: xds server does not exist for resource " + ldsResourceName); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(),ldsResourceName, + ldsResourceWatcher); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + } + @Test public void ldsResponseErrorHandling_allResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); call.sendResponse(LDS, ImmutableList.of(FAILING_ANY, FAILING_ANY), VERSION_1, "0000"); @@ -525,7 +673,8 @@ public void ldsResponseErrorHandling_allResourcesFailedUnpack() { @Test public void ldsResponseErrorHandling_someResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); // Correct resource is in the middle to ensure processing continues on errors. @@ -547,14 +696,14 @@ public void ldsResponseErrorHandling_someResourcesFailedUnpack() { * Tests a subscribed LDS resource transitioned to and from the invalid state. * * @see - * A40-csds-support.md. + * A40-csds-support.md */ @Test public void ldsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchLdsResource("A", ldsResourceWatcher); - xdsClient.watchLdsResource("B", ldsResourceWatcher); - xdsClient.watchLdsResource("C", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -587,7 +736,12 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataDoesNotExist(LDS, "C"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(LDS, "C"); + } else { + // When resource deletion is disabled, {C} stays ACKed in the previous version VERSION_1. + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + } call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // LDS -> {B, C} version 3 @@ -597,7 +751,12 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(LDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exist // {B, C} -> ACK, version 3 - verifyResourceMetadataDoesNotExist(LDS, "A"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(LDS, "A"); + } else { + // When resource deletion is disabled, {A} stays ACKed in the previous version VERSION_2. + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); + } verifyResourceMetadataAcked(LDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(LDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(LDS, subscribedResourceNames, VERSION_3, "0002", NODE); @@ -605,14 +764,14 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { } @Test - public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscriptioin() { + public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchLdsResource("A", ldsResourceWatcher); - xdsClient.watchRdsResource("A.1", rdsResourceWatcher); - xdsClient.watchLdsResource("B", ldsResourceWatcher); - xdsClient.watchRdsResource("B.1", rdsResourceWatcher); - xdsClient.watchLdsResource("C", ldsResourceWatcher); - xdsClient.watchRdsResource("C.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C.1", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -661,26 +820,48 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti verifyResourceMetadataNacked( LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, errorsV2); - verifyResourceMetadataDoesNotExist(LDS, "C"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(LDS, "C"); + } else { + // When resource deletion is disabled, {C} stays ACKed in the previous version VERSION_1. + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + } call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); - // {A.1} -> does not exist + // {A.1} -> version 1 // {B.1} -> version 1 - // {C.1} -> does not exist - verifyResourceMetadataDoesNotExist(RDS, "A.1"); + // {C.1} -> does not exist because {C} does not exist + verifyResourceMetadataAcked(RDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); - verifyResourceMetadataDoesNotExist(RDS, "C.1"); + // Verify {C.1} stays in the previous version VERSION_1, no matter {C} is deleted or not. + verifyResourceMetadataAcked(RDS, "C.1", resourcesV11.get("C.1"), VERSION_1, + TIME_INCREMENT * 2); } @Test public void ldsResourceFound_containsVirtualHosts() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); // Client sends an ACK LDS request. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + } + + @Test + public void wrappedLdsResource() { + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + + // Client sends an ACK LDS request. + call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -688,45 +869,48 @@ public void ldsResourceFound_containsVirtualHosts() { @Test public void ldsResourceFound_containsRdsName() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); // Client sends an ACK LDS request. call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @Test + @SuppressWarnings("unchecked") public void cachedLdsResource_data() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); // Client sends an ACK LDS request. call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - LdsResourceWatcher watcher = mock(LdsResourceWatcher.class); - xdsClient.watchLdsResource(LDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); verify(watcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @Test + @SuppressWarnings("unchecked") public void cachedLdsResource_absent() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); // Add another watcher. - LdsResourceWatcher watcher = mock(LdsResourceWatcher.class); - xdsClient.watchLdsResource(LDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); verify(watcher).onResourceDoesNotExist(LDS_RESOURCE); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); @@ -735,35 +919,65 @@ public void cachedLdsResource_absent() { @Test public void ldsResourceUpdated() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); // Updated LDS response. call.sendResponse(LDS, testListenerRds, VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); assertThat(channelForCustomAuthority).isNull(); assertThat(channelForEmptyAuthority).isNull(); } + @Test + public void cancelResourceWatcherNotRemoveUrlSubscribers() { + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), + LDS_RESOURCE + "1", ldsResourceWatcher); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), LDS_RESOURCE + "1", + ldsResourceWatcher); + + // Updated LDS response. + Any testListenerVhosts2 = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, + mf.buildRouteConfiguration("new", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); + call.sendResponse(LDS, testListenerVhosts2, VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts2, VERSION_2, + TIME_INCREMENT * 2); + } + @Test public void ldsResourceUpdated_withXdstpResourceName() { + BootstrapperImpl.enableFederation = true; String ldsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/listener1" : "xdstp://authority.xds.com/envoy.api.v2.Listener/listener1"; - DiscoveryRpcCall call = startResourceWatcher(LDS, ldsResourceName, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); verifyResourceMetadataRequested(LDS, ldsResourceName); @@ -772,18 +986,19 @@ public void ldsResourceUpdated_withXdstpResourceName() { call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @Test public void ldsResourceUpdated_withXdstpResourceName_withEmptyAuthority() { + BootstrapperImpl.enableFederation = true; String ldsResourceName = useProtocolV3() ? "xdstp:///envoy.config.listener.v3.Listener/listener1" : "xdstp:///envoy.api.v2.Listener/listener1"; - DiscoveryRpcCall call = startResourceWatcher(LDS, ldsResourceName, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); assertThat(channelForEmptyAuthority).isNotNull(); verifyResourceMetadataRequested(LDS, ldsResourceName); @@ -792,18 +1007,19 @@ public void ldsResourceUpdated_withXdstpResourceName_withEmptyAuthority() { call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @Test public void ldsResourceUpdated_withXdstpResourceName_witUnorderedContextParams() { + BootstrapperImpl.enableFederation = true; String ldsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/listener1/a?bar=2&foo=1" : "xdstp://authority.xds.com/envoy.api.v2.Listener/listener1/a?bar=2&foo=1"; - DiscoveryRpcCall call = startResourceWatcher(LDS, ldsResourceName, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); String ldsResourceNameWithUnorderedContextParams = useProtocolV3() @@ -819,10 +1035,12 @@ public void ldsResourceUpdated_withXdstpResourceName_witUnorderedContextParams() @Test public void ldsResourceUpdated_withXdstpResourceName_withWrongType() { + BootstrapperImpl.enableFederation = true; String ldsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/listener1" : "xdstp://authority.xds.com/envoy.api.v2.Listener/listener1"; - DiscoveryRpcCall call = startResourceWatcher(LDS, ldsResourceName, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); String ldsResourceNameWithWrongType = @@ -839,10 +1057,12 @@ public void ldsResourceUpdated_withXdstpResourceName_withWrongType() { @Test public void rdsResourceUpdated_withXdstpResourceName_withWrongType() { + BootstrapperImpl.enableFederation = true; String rdsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.route.v3.RouteConfiguration/route1" : "xdstp://authority.xds.com/envoy.api.v2.RouteConfiguration/route1"; - DiscoveryRpcCall call = startResourceWatcher(RDS, rdsResourceName, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + rdsResourceName, rdsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); String rdsResourceNameWithWrongType = @@ -856,19 +1076,40 @@ public void rdsResourceUpdated_withXdstpResourceName_withWrongType() { "Unsupported resource name: " + rdsResourceNameWithWrongType + " for type: RDS")); } + @Test + public void rdsResourceUpdated_withXdstpResourceName_unknownAuthority() { + BootstrapperImpl.enableFederation = true; + String rdsResourceName = useProtocolV3() + ? "xdstp://unknown.example.com/envoy.config.route.v3.RouteConfiguration/route1" + : "xdstp://unknown.example.com/envoy.api.v2.RouteConfiguration/route1"; + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceName, + rdsResourceWatcher); + verify(rdsResourceWatcher).onError(errorCaptor.capture()); + Status error = errorCaptor.getValue(); + assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(error.getDescription()).isEqualTo( + "Wrong configuration: xds server does not exist for resource " + rdsResourceName); + assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); + xdsClient.cancelXdsResourceWatch( + XdsRouteConfigureResource.getInstance(),rdsResourceName, rdsResourceWatcher); + assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); + } + @Test public void cdsResourceUpdated_withXdstpResourceName_withWrongType() { + BootstrapperImpl.enableFederation = true; String cdsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.cluster.v3.Cluster/cluster1" : "xdstp://authority.xds.com/envoy.api.v2.Cluster/cluster1"; - DiscoveryRpcCall call = startResourceWatcher(CDS, cdsResourceName, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), cdsResourceName, + cdsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); String cdsResourceNameWithWrongType = "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/cluster1"; Any testClusterConfig = Any.pack(mf.buildEdsCluster( cdsResourceNameWithWrongType, null, "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null)); + "envoy.transport_sockets.tls", null, null)); call.sendResponse(CDS, testClusterConfig, VERSION_1, "0000"); call.verifyRequestNack( CDS, cdsResourceName, "", "0000", NODE, @@ -876,12 +1117,33 @@ public void cdsResourceUpdated_withXdstpResourceName_withWrongType() { "Unsupported resource name: " + cdsResourceNameWithWrongType + " for type: CDS")); } + @Test + public void cdsResourceUpdated_withXdstpResourceName_unknownAuthority() { + BootstrapperImpl.enableFederation = true; + String cdsResourceName = useProtocolV3() + ? "xdstp://unknown.example.com/envoy.config.cluster.v3.Cluster/cluster1" + : "xdstp://unknown.example.com/envoy.api.v2.Cluster/cluster1"; + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceName, + cdsResourceWatcher); + verify(cdsResourceWatcher).onError(errorCaptor.capture()); + Status error = errorCaptor.getValue(); + assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(error.getDescription()).isEqualTo( + "Wrong configuration: xds server does not exist for resource " + cdsResourceName); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(),cdsResourceName, + cdsResourceWatcher); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + } + @Test public void edsResourceUpdated_withXdstpResourceName_withWrongType() { + BootstrapperImpl.enableFederation = true; String edsResourceName = useProtocolV3() ? "xdstp://authority.xds.com/envoy.config.endpoint.v3.ClusterLoadAssignment/cluster1" : "xdstp://authority.xds.com/envoy.api.v2.ClusterLoadAssignment/cluster1"; - DiscoveryRpcCall call = startResourceWatcher(EDS, edsResourceName, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), edsResourceName, + edsResourceWatcher); assertThat(channelForCustomAuthority).isNotNull(); String edsResourceNameWithWrongType = @@ -899,10 +1161,30 @@ public void edsResourceUpdated_withXdstpResourceName_withWrongType() { "Unsupported resource name: " + edsResourceNameWithWrongType + " for type: EDS")); } + @Test + public void edsResourceUpdated_withXdstpResourceName_unknownAuthority() { + BootstrapperImpl.enableFederation = true; + String edsResourceName = useProtocolV3() + ? "xdstp://unknown.example.com/envoy.config.endpoint.v3.ClusterLoadAssignment/cluster1" + : "xdstp://unknown.example.com/envoy.api.v2.ClusterLoadAssignment/cluster1"; + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), + edsResourceName, edsResourceWatcher); + verify(edsResourceWatcher).onError(errorCaptor.capture()); + Status error = errorCaptor.getValue(); + assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(error.getDescription()).isEqualTo( + "Wrong configuration: xds server does not exist for resource " + edsResourceName); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), + edsResourceName, edsResourceWatcher); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + } + @Test public void ldsResourceUpdate_withFaultInjection() { Assume.assumeTrue(useProtocolV3()); - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); Any listener = Any.pack( mf.buildListenerWithApiListener( LDS_RESOURCE, @@ -966,15 +1248,17 @@ public void ldsResourceUpdate_withFaultInjection() { @Test public void ldsResourceDeleted() { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + Assume.assumeFalse(ignoreResourceDeletion()); + + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -986,14 +1270,56 @@ public void ldsResourceDeleted() { verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } + /** + * When ignore_resource_deletion server feature is on, xDS client should keep the deleted listener + * on empty response, and resume the normal work when LDS contains the listener again. + * */ @Test + public void ldsResourceDeleted_ignoreResourceDeletion() { + Assume.assumeTrue(ignoreResourceDeletion()); + + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + // Empty LDS response does not delete the listener. + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + // The resource is still ACKED at VERSION_1 (no changes). + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + // onResourceDoesNotExist not called + verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); + + // Next update is correct, and contains the listener again. + call.sendResponse(LDS, testListenerVhosts, VERSION_3, "0003"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_3, "0003", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + // LDS is now ACKEd at VERSION_3. + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_3, + TIME_INCREMENT * 3); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + verifyNoMoreInteractions(ldsResourceWatcher); + } + + @Test + @SuppressWarnings("unchecked") public void multipleLdsWatchers() { String ldsResourceTwo = "bar.googleapis.com"; - LdsResourceWatcher watcher1 = mock(LdsResourceWatcher.class); - LdsResourceWatcher watcher2 = mock(LdsResourceWatcher.class); - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchLdsResource(ldsResourceTwo, watcher1); - xdsClient.watchLdsResource(ldsResourceTwo, watcher2); + ResourceWatcher watcher1 = mock(ResourceWatcher.class); + ResourceWatcher watcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, ImmutableList.of(LDS_RESOURCE, ldsResourceTwo), "", "", NODE); // Both LDS resources were requested. @@ -1001,7 +1327,7 @@ public void multipleLdsWatchers() { verifyResourceMetadataRequested(LDS, ldsResourceTwo); verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(ldsResourceTwo); verify(watcher2).onResourceDoesNotExist(ldsResourceTwo); @@ -1011,19 +1337,16 @@ public void multipleLdsWatchers() { Any listenerTwo = Any.pack(mf.buildListenerWithApiListenerForRds(ldsResourceTwo, RDS_RESOURCE)); call.sendResponse(LDS, ImmutableList.of(testListenerVhosts, listenerTwo), VERSION_1, "0000"); - // ldsResourceWatcher called with listenerVhosts. + // ResourceWatcher called with listenerVhosts. verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); // watcher1 called with listenerTwo. verify(watcher1).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); // watcher2 called with listenerTwo. verify(watcher2).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); // Metadata of both listeners is stored. verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); @@ -1033,10 +1356,11 @@ public void multipleLdsWatchers() { @Test public void rdsResourceNotFound() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); Any routeConfig = Any.pack(mf.buildRouteConfiguration("route-bar.googleapis.com", mf.buildOpaqueVirtualHosts(2))); - call.sendResponse(ResourceType.RDS, routeConfig, VERSION_1, "0000"); + call.sendResponse(RDS, routeConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); @@ -1044,7 +1368,7 @@ public void rdsResourceNotFound() { verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); @@ -1053,7 +1377,8 @@ public void rdsResourceNotFound() { @Test public void rdsResponseErrorHandling_allResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); call.sendResponse(RDS, ImmutableList.of(FAILING_ANY, FAILING_ANY), VERSION_1, "0000"); @@ -1069,7 +1394,8 @@ public void rdsResponseErrorHandling_allResourcesFailedUnpack() { @Test public void rdsResponseErrorHandling_someResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); // Correct resource is in the middle to ensure processing continues on errors. @@ -1087,18 +1413,63 @@ public void rdsResponseErrorHandling_someResourcesFailedUnpack() { verify(rdsResourceWatcher).onChanged(any(RdsUpdate.class)); } + @Test + public void rdsResponseErrorHandling_nackWeightedSumZero() { + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); + verifyResourceMetadataRequested(RDS, RDS_RESOURCE); + + io.envoyproxy.envoy.config.route.v3.RouteAction routeAction = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(0))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(0)))) + .build(); + io.envoyproxy.envoy.config.route.v3.Route route = + io.envoyproxy.envoy.config.route.v3.Route.newBuilder() + .setName("route-blade") + .setMatch( + io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder() + .setPath("/service/method")) + .setRoute(routeAction) + .build(); + + Any zeroWeightSum = Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, + Arrays.asList(mf.buildVirtualHost(Arrays.asList(route), ImmutableMap.of())))); + List resources = ImmutableList.of(zeroWeightSum); + call.sendResponse(RDS, resources, VERSION_1, "0000"); + + List errors = ImmutableList.of( + "RDS response RouteConfiguration \'route-configuration.googleapis.com\' validation error: " + + "RouteConfiguration contains invalid virtual host: Virtual host [do not care] " + + "contains invalid route : Route [route-blade] contains invalid RouteAction: " + + "Sum of cluster weights should be above 0."); + verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); + // The response is NACKed with the same error message. + call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); + verify(rdsResourceWatcher, never()).onChanged(any(RdsUpdate.class)); + } + /** * Tests a subscribed RDS resource transitioned to and from the invalid state. * * @see - * A40-csds-support.md. + * A40-csds-support.md */ @Test public void rdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchRdsResource("A", rdsResourceWatcher); - xdsClient.watchRdsResource("B", rdsResourceWatcher); - xdsClient.watchRdsResource("C", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(RDS, "A"); @@ -1153,43 +1524,63 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { @Test public void rdsResourceFound() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); + } + + @Test + public void wrappedRdsResource() { + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); + call.sendResponse(RDS, mf.buildWrappedResource(testRouteConfig), VERSION_1, "0000"); + + // Client sends an ACK RDS request. + call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); + verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); } @Test + @SuppressWarnings("unchecked") public void cachedRdsResource_data() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - RdsResourceWatcher watcher = mock(RdsResourceWatcher.class); - xdsClient.watchRdsResource(RDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); verify(watcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); } @Test + @SuppressWarnings("unchecked") public void cachedRdsResource_absent() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); // Add another watcher. - RdsResourceWatcher watcher = mock(RdsResourceWatcher.class); - xdsClient.watchRdsResource(RDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); verify(watcher).onResourceDoesNotExist(RDS_RESOURCE); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); @@ -1198,14 +1589,15 @@ public void cachedRdsResource_absent() { @Test public void rdsResourceUpdated() { - DiscoveryRpcCall call = startResourceWatcher(RDS, RDS_RESOURCE, rdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); // Initial RDS response. call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); // Updated RDS response. @@ -1224,8 +1616,10 @@ public void rdsResourceUpdated() { @Test public void rdsResourceDeletedByLdsApiListener() { - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, + ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + rdsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); @@ -1233,25 +1627,27 @@ public void rdsResourceDeletedByLdsApiListener() { DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().rdsName()) - .isEqualTo(RDS_RESOURCE); + verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); + // The Listener is getting replaced configured with an RDS name, to the one configured with + // vhosts. Expect the RDS resources to be discarded. + // Note that this must work the same despite the ignore_resource_deletion feature is on. + // This happens because the Listener is getting replaced, and not deleted. call.sendResponse(LDS, testListenerVhosts, VERSION_2, "0001"); verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()) - .hasSize(VHOST_SIZE); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); - verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verifyNoMoreInteractions(rdsResourceWatcher); + verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( LDS, LDS_RESOURCE, testListenerVhosts, VERSION_2, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); @@ -1260,8 +1656,10 @@ public void rdsResourceDeletedByLdsApiListener() { @Test public void rdsResourcesDeletedByLdsTcpListener() { Assume.assumeTrue(useProtocolV3()); - xdsClient.watchLdsResource(LISTENER_RESOURCE, ldsResourceWatcher); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); verifyResourceMetadataRequested(LDS, LISTENER_RESOURCE); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); @@ -1293,11 +1691,13 @@ public void rdsResourcesDeletedByLdsTcpListener() { // Simulates receiving the requested RDS resource. call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); // Simulates receiving an updated version of the requested LDS resource as a TCP listener // with a filter chain containing inlined RouteConfiguration. + // Note that this must work the same despite the ignore_resource_deletion feature is on. + // This happens because the Listener is getting replaced, and not deleted. hcmFilter = mf.buildHttpConnectionManagerFilter( null, mf.buildRouteConfiguration( @@ -1314,21 +1714,23 @@ public void rdsResourcesDeletedByLdsTcpListener() { parsedFilterChain = Iterables.getOnlyElement( ldsUpdateCaptor.getValue().listener().filterChains()); assertThat(parsedFilterChain.httpConnectionManager().virtualHosts()).hasSize(VHOST_SIZE); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); - verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); + verify(rdsResourceWatcher, never()).onResourceDoesNotExist(RDS_RESOURCE); + verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( LDS, LISTENER_RESOURCE, packedListener, VERSION_2, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); } @Test + @SuppressWarnings("unchecked") public void multipleRdsWatchers() { String rdsResourceTwo = "route-bar.googleapis.com"; - RdsResourceWatcher watcher1 = mock(RdsResourceWatcher.class); - RdsResourceWatcher watcher2 = mock(RdsResourceWatcher.class); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchRdsResource(rdsResourceTwo, watcher1); - xdsClient.watchRdsResource(rdsResourceTwo, watcher2); + ResourceWatcher watcher1 = mock(ResourceWatcher.class); + ResourceWatcher watcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(RDS, Arrays.asList(RDS_RESOURCE, rdsResourceTwo), "", "", NODE); // Both RDS resources were requested. @@ -1336,7 +1738,7 @@ public void multipleRdsWatchers() { verifyResourceMetadataRequested(RDS, rdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(rdsResourceTwo); verify(watcher2).onResourceDoesNotExist(rdsResourceTwo); @@ -1346,7 +1748,7 @@ public void multipleRdsWatchers() { call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(VHOST_SIZE); + verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifyResourceMetadataDoesNotExist(RDS, rdsResourceTwo); @@ -1367,22 +1769,23 @@ public void multipleRdsWatchers() { @Test public void cdsResourceNotFound() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); List clusters = ImmutableList.of( Any.pack(mf.buildEdsCluster("cluster-bar.googleapis.com", null, "round_robin", null, - null, false, null, "envoy.transport_sockets.tls", null)), + null, false, null, "envoy.transport_sockets.tls", null, null)), Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, - null, false, null, "envoy.transport_sockets.tls", null))); + null, false, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verifyNoInteractions(cdsResourceWatcher); + verifyNoInteractions(ldsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); @@ -1391,7 +1794,8 @@ public void cdsResourceNotFound() { @Test public void cdsResponseErrorHandling_allResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); call.sendResponse(CDS, ImmutableList.of(FAILING_ANY, FAILING_ANY), VERSION_1, "0000"); @@ -1402,12 +1806,13 @@ public void cdsResponseErrorHandling_allResourcesFailedUnpack() { call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of( "CDS response Resource index 0 - can't decode Cluster: ", "CDS response Resource index 1 - can't decode Cluster: ")); - verifyNoInteractions(cdsResourceWatcher); + verifyNoInteractions(ldsResourceWatcher); } @Test public void cdsResponseErrorHandling_someResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); // Correct resource is in the middle to ensure processing continues on errors. @@ -1430,14 +1835,14 @@ public void cdsResponseErrorHandling_someResourcesFailedUnpack() { * Tests a subscribed CDS resource transitioned to and from the invalid state. * * @see - * A40-csds-support.md. + * A40-csds-support.md */ @Test public void cdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchCdsResource("A", cdsResourceWatcher); - xdsClient.watchCdsResource("B", cdsResourceWatcher); - xdsClient.watchCdsResource("C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1448,13 +1853,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // CDS -> {A, B, C}, version 1 ImmutableMap resourcesV1 = ImmutableMap.of( "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 @@ -1467,7 +1872,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // Failed to parse endpoint B ImmutableMap resourcesV2 = ImmutableMap.of( "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "B", Any.pack(mf.buildClusterInvalid("B"))); call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); @@ -1478,21 +1883,31 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataDoesNotExist(CDS, "C"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(CDS, "C"); + } else { + // When resource deletion is disabled, {C} stays ACKed in the previous version VERSION_1. + verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + } call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // CDS -> {B, C} version 3 ImmutableMap resourcesV3 = ImmutableMap.of( "B", Any.pack(mf.buildEdsCluster("B", "B.3", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "C", Any.pack(mf.buildEdsCluster("C", "C.3", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exit // {B, C} -> ACK, version 3 - verifyResourceMetadataDoesNotExist(CDS, "A"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(CDS, "A"); + } else { + // When resource deletion is disabled, {A} stays ACKed in the previous version VERSION_2. + verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); + } verifyResourceMetadataAcked(CDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(CDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(CDS, subscribedResourceNames, VERSION_3, "0002", NODE); @@ -1501,12 +1916,12 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { @Test public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchCdsResource("A", cdsResourceWatcher); - xdsClient.watchEdsResource("A.1", edsResourceWatcher); - xdsClient.watchCdsResource("B", cdsResourceWatcher); - xdsClient.watchEdsResource("B.1", edsResourceWatcher); - xdsClient.watchCdsResource("C", cdsResourceWatcher); - xdsClient.watchEdsResource("C.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C.1", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1520,13 +1935,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // CDS -> {A, B, C}, version 1 ImmutableMap resourcesV1 = ImmutableMap.of( "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 @@ -1552,7 +1967,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // Failed to parse endpoint B ImmutableMap resourcesV2 = ImmutableMap.of( "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), "B", Any.pack(mf.buildClusterInvalid("B"))); call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); @@ -1564,32 +1979,49 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti verifyResourceMetadataNacked( CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, errorsV2); - verifyResourceMetadataDoesNotExist(CDS, "C"); + if (!ignoreResourceDeletion()) { + verifyResourceMetadataDoesNotExist(CDS, "C"); + } else { + // When resource deletion is disabled, {C} stays ACKed in the previous version VERSION_1. + verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + } call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); - // {A.1} -> does not exist + // {A.1} -> version 1 // {B.1} -> version 1 - // {C.1} -> does not exist - verifyResourceMetadataDoesNotExist(EDS, "A.1"); + // {C.1} -> does not exist because {C} does not exist + verifyResourceMetadataAcked(EDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); - verifyResourceMetadataDoesNotExist(EDS, "C.1"); + // Verify {C.1} stays in the previous version VERSION_1. {C1} deleted or not does not matter. + verifyResourceMetadataAcked(EDS, "C.1", resourcesV11.get("C.1"), VERSION_1, + TIME_INCREMENT * 2); } @Test public void cdsResourceFound() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), + CDS_RESOURCE, cdsResourceWatcher); call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isNull(); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + + @Test + public void wrappedCdsResource() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + call.sendResponse(CDS, mf.buildWrappedResource(testClusterRoundRobin), VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -1598,13 +2030,14 @@ public void cdsResourceFound() { @Test public void cdsResourceFound_leastRequestLbPolicy() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); Message leastRequestConfig = mf.buildLeastRequestLbConfig(3); Any clusterRingHash = Any.pack( mf.buildEdsCluster(CDS_RESOURCE, null, "least_request_experimental", null, - leastRequestConfig, false, null, "envoy.transport_sockets.tls", null + leastRequestConfig, false, null, "envoy.transport_sockets.tls", null, null )); - call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); + call.sendResponse(CDS, clusterRingHash, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); @@ -1613,8 +2046,12 @@ public void cdsResourceFound_leastRequestLbPolicy() { assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.LEAST_REQUEST); - assertThat(cdsUpdate.choiceCount()).isEqualTo(3); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental"); + assertThat(childConfigs.get(0).getRawConfigValue().get("choiceCount")).isEqualTo(3); assertThat(cdsUpdate.lrsServerInfo()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); @@ -1625,13 +2062,14 @@ public void cdsResourceFound_leastRequestLbPolicy() { @Test public void cdsResourceFound_ringHashLbPolicy() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); Message ringHashConfig = mf.buildRingHashLbConfig("xx_hash", 10L, 100L); Any clusterRingHash = Any.pack( mf.buildEdsCluster(CDS_RESOURCE, null, "ring_hash_experimental", ringHashConfig, null, - false, null, "envoy.transport_sockets.tls", null + false, null, "envoy.transport_sockets.tls", null, null )); - call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); + call.sendResponse(CDS, clusterRingHash, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); @@ -1640,9 +2078,12 @@ public void cdsResourceFound_ringHashLbPolicy() { assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.RING_HASH); - assertThat(cdsUpdate.minRingSize()).isEqualTo(10L); - assertThat(cdsUpdate.maxRingSize()).isEqualTo(100L); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("ring_hash_experimental"); + assertThat(JsonUtil.getNumberAsLong(lbConfig.getRawConfigValue(), "minRingSize")).isEqualTo( + 10L); + assertThat(JsonUtil.getNumberAsLong(lbConfig.getRawConfigValue(), "maxRingSize")).isEqualTo( + 100L); assertThat(cdsUpdate.lrsServerInfo()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); @@ -1653,7 +2094,8 @@ public void cdsResourceFound_ringHashLbPolicy() { @Test public void cdsResponseWithAggregateCluster() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); List candidates = Arrays.asList( "cluster1.googleapis.com", "cluster2.googleapis.com", "cluster3.googleapis.com"); Any clusterAggregate = @@ -1666,7 +2108,11 @@ public void cdsResponseWithAggregateCluster() { CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.AGGREGATE); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); assertThat(cdsUpdate.prioritizedClusterNames()).containsExactlyElementsIn(candidates).inOrder(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterAggregate, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -1674,10 +2120,11 @@ public void cdsResponseWithAggregateCluster() { @Test public void cdsResponseWithCircuitBreakers() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); Any clusterCircuitBreakers = Any.pack( mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", mf.buildCircuitBreakers(50, 200))); + "envoy.transport_sockets.tls", mf.buildCircuitBreakers(50, 200), null)); call.sendResponse(CDS, clusterCircuitBreakers, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1687,7 +2134,11 @@ public void cdsResponseWithCircuitBreakers() { assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); assertThat(cdsUpdate.lrsServerInfo()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isEqualTo(200L); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); @@ -1703,25 +2154,27 @@ public void cdsResponseWithCircuitBreakers() { @SuppressWarnings("deprecation") public void cdsResponseWithUpstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, null, true, mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), - "envoy.transport_sockets.tls", null)); + "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, - false, null, "envoy.transport_sockets.tls", null))); + false, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + verify(cdsResourceWatcher, times(1)) + .onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); CommonTlsContext.CertificateProviderInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getCombinedValidationContext() @@ -1739,20 +2192,21 @@ public void cdsResponseWithUpstreamTlsContext() { @SuppressWarnings("deprecation") public void cdsResponseWithNewUpstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, null,true, mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), - "envoy.transport_sockets.tls", null)); + "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, - false, null, "envoy.transport_sockets.tls", null))); + false, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1774,26 +2228,245 @@ public void cdsResponseWithNewUpstreamTlsContext() { @Test public void cdsResponseErrorHandling_badUpstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); // Management server sends back CDS response with UpstreamTlsContext. List clusters = ImmutableList.of(Any .pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, null, true, - mf.buildUpstreamTlsContext(null, null), "envoy.transport_sockets.tls", null))); + mf.buildUpstreamTlsContext(null, null), "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // The response NACKed with errors indicating indices of the failed resources. String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " - + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: " + + "io.grpc.xds.XdsClientImpl$ResourceInvalidException: " + "ca_certificate_provider_instance is required in upstream-tls-context"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(cdsResourceWatcher).onError(captor.capture()); - Status errorStatus = captor.getValue(); - assertThat(errorStatus.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); - assertThat(errorStatus.getDescription()).isEqualTo(errorMsg); + verify(cdsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + } + + /** + * CDS response containing OutlierDetection for a cluster. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithOutlierDetection() { + Assume.assumeTrue(useProtocolV3()); + XdsClusterResource.enableOutlierDetection = true; + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + + OutlierDetection outlierDetectionXds = OutlierDetection.newBuilder() + .setInterval(Durations.fromNanos(100)) + .setBaseEjectionTime(Durations.fromNanos(100)) + .setMaxEjectionTime(Durations.fromNanos(100)) + .setMaxEjectionPercent(UInt32Value.of(100)) + .setSuccessRateStdevFactor(UInt32Value.of(100)) + .setEnforcingSuccessRate(UInt32Value.of(100)) + .setSuccessRateMinimumHosts(UInt32Value.of(100)) + .setSuccessRateRequestVolume(UInt32Value.of(100)) + .setFailurePercentageThreshold(UInt32Value.of(100)) + .setEnforcingFailurePercentage(UInt32Value.of(100)) + .setFailurePercentageMinimumHosts(UInt32Value.of(100)) + .setFailurePercentageRequestVolume(UInt32Value.of(100)).build(); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, null, true, + mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null, outlierDetectionXds)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, + false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + + // The outlier detection config in CdsUpdate should match what we get from xDS. + EnvoyServerProtoData.OutlierDetection outlierDetection = cdsUpdate.outlierDetection(); + assertThat(outlierDetection).isNotNull(); + assertThat(outlierDetection.intervalNanos()).isEqualTo(100); + assertThat(outlierDetection.baseEjectionTimeNanos()).isEqualTo(100); + assertThat(outlierDetection.maxEjectionTimeNanos()).isEqualTo(100); + assertThat(outlierDetection.maxEjectionPercent()).isEqualTo(100); + + SuccessRateEjection successRateEjection = outlierDetection.successRateEjection(); + assertThat(successRateEjection).isNotNull(); + assertThat(successRateEjection.stdevFactor()).isEqualTo(100); + assertThat(successRateEjection.enforcementPercentage()).isEqualTo(100); + assertThat(successRateEjection.minimumHosts()).isEqualTo(100); + assertThat(successRateEjection.requestVolume()).isEqualTo(100); + + FailurePercentageEjection failurePercentageEjection + = outlierDetection.failurePercentageEjection(); + assertThat(failurePercentageEjection).isNotNull(); + assertThat(failurePercentageEjection.threshold()).isEqualTo(100); + assertThat(failurePercentageEjection.enforcementPercentage()).isEqualTo(100); + assertThat(failurePercentageEjection.minimumHosts()).isEqualTo(100); + assertThat(failurePercentageEjection.requestVolume()).isEqualTo(100); + + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + + /** + * CDS response containing OutlierDetection for a cluster, but support has not been enabled. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithOutlierDetection_supportDisabled() { + Assume.assumeTrue(useProtocolV3()); + XdsClusterResource.enableOutlierDetection = false; + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + + OutlierDetection outlierDetectionXds = OutlierDetection.newBuilder() + .setInterval(Durations.fromNanos(100)).build(); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, null, true, + mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null, outlierDetectionXds)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, + false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + + assertThat(cdsUpdate.outlierDetection()).isNull(); + + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + + /** + * CDS response containing OutlierDetection for a cluster. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithInvalidOutlierDetectionNacks() { + Assume.assumeTrue(useProtocolV3()); + XdsClusterResource.enableOutlierDetection = true; + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + + OutlierDetection outlierDetectionXds = OutlierDetection.newBuilder() + .setMaxEjectionPercent(UInt32Value.of(101)).build(); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, null, true, + mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null, outlierDetectionXds)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, + false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + + "Cluster cluster.googleapis.com: malformed outlier_detection: " + + "io.grpc.xds.XdsClientImpl$ResourceInvalidException: outlier_detection " + + "max_ejection_percent is > 100"; + call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); + verify(cdsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_invalidInterval() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setInterval(Duration.newBuilder().setSeconds(Long.MAX_VALUE)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_negativeInterval() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setInterval(Duration.newBuilder().setSeconds(-1)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_invalidBaseEjectionTime() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder() + .setBaseEjectionTime(Duration.newBuilder().setSeconds(Long.MAX_VALUE)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_negativeBaseEjectionTime() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setBaseEjectionTime(Duration.newBuilder().setSeconds(-1)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_invalidMaxEjectionTime() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder() + .setMaxEjectionTime(Duration.newBuilder().setSeconds(Long.MAX_VALUE)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_negativeMaxEjectionTime() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setMaxEjectionTime(Duration.newBuilder().setSeconds(-1)) + .build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_maxEjectionPercentTooHigh() throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setMaxEjectionPercent(UInt32Value.of(101)).build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_enforcingSuccessRateTooHigh() + throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setEnforcingSuccessRate(UInt32Value.of(101)).build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_failurePercentageThresholdTooHigh() + throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setFailurePercentageThreshold(UInt32Value.of(101)).build()); + } + + @Test(expected = ResourceInvalidException.class) + public void validateOutlierDetection_enforcingFailurePercentageTooHigh() + throws ResourceInvalidException { + XdsClusterResource.validateOutlierDetection( + OutlierDetection.newBuilder().setEnforcingFailurePercentage(UInt32Value.of(101)).build()); } /** @@ -1802,45 +2475,39 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { @Test public void cdsResponseErrorHandling_badTransportSocketName() { Assume.assumeTrue(useProtocolV3()); - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); // Management server sends back CDS response with UpstreamTlsContext. List clusters = ImmutableList.of(Any .pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, null, true, - mf.buildUpstreamTlsContext("secret1", "cert1"), "envoy.transport_sockets.bad", null))); + mf.buildUpstreamTlsContext("secret1", "cert1"), "envoy.transport_sockets.bad", null, + null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // The response NACKed with errors indicating indices of the failed resources. String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "transport-socket with name envoy.transport_sockets.bad not supported."; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(cdsResourceWatcher).onError(captor.capture()); - Status errorStatus = captor.getValue(); - assertThat(errorStatus.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); - assertThat(errorStatus.getDescription()).isEqualTo(errorMsg); + verify(cdsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); } @Test + @SuppressWarnings("unchecked") public void cachedCdsResource_data() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); // Client sends an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - CdsResourceWatcher watcher = mock(CdsResourceWatcher.class); - xdsClient.watchCdsResource(CDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, watcher); verify(watcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isNull(); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -1849,12 +2516,14 @@ public void cachedCdsResource_data() { } @Test + @SuppressWarnings("unchecked") public void cachedCdsResource_absent() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); - CdsResourceWatcher watcher = mock(CdsResourceWatcher.class); - xdsClient.watchCdsResource(CDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, watcher); verify(watcher).onResourceDoesNotExist(CDS_RESOURCE); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); @@ -1863,7 +2532,8 @@ public void cachedCdsResource_absent() { @Test public void cdsResourceUpdated() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); // Initial CDS response. @@ -1879,7 +2549,11 @@ public void cdsResourceUpdated() { assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); assertThat(cdsUpdate.lrsServerInfo()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); @@ -1889,7 +2563,7 @@ public void cdsResourceUpdated() { String edsService = "eds-service-bar.googleapis.com"; Any clusterEds = Any.pack( mf.buildEdsCluster(CDS_RESOURCE, edsService, "round_robin", null, null, true, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )); call.sendResponse(CDS, clusterEds, VERSION_2, "0001"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); @@ -1898,31 +2572,79 @@ public void cdsResourceUpdated() { assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(lrsServerInfo); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); + assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + // Assures that CDS updates identical to the current config are ignored. + @Test + public void cdsResourceUpdatedWithDuplicate() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + + String edsService = "eds-service-bar.googleapis.com"; + String transportSocketName = "envoy.transport_sockets.tls"; + Any roundRobinConfig = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, edsService, "round_robin", null, null, true, null, + transportSocketName, null, null + )); + Any ringHashConfig = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, edsService, "ring_hash_experimental", + mf.buildRingHashLbConfig("xx_hash", 1, 2), null, true, null, + transportSocketName, null, null + )); + Any leastRequestConfig = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, edsService, "least_request_experimental", + null, mf.buildLeastRequestLbConfig(2), true, null, + transportSocketName, null, null + )); + + // Configure with round robin, the update should be sent to the watcher. + call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0001"); + verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + + // Second update is identical, watcher should not get an additional update. + call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0002"); + verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + + // Now we switch to ring hash so the watcher should be notified. + call.sendResponse(CDS, ringHashConfig, VERSION_2, "0003"); + verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + + // Second update to ring hash should not result in watcher being notified. + call.sendResponse(CDS, ringHashConfig, VERSION_2, "0004"); + verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + + // Now we switch to least request so the watcher should be notified. + call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0005"); + verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + + // Second update to least request should not result in watcher being notified. + call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0006"); + verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + } + @Test public void cdsResourceDeleted() { - DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + Assume.assumeFalse(ignoreResourceDeletion()); + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); // Initial CDS response. call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isNull(); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isNull(); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -1935,21 +2657,65 @@ public void cdsResourceDeleted() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + /** + * When ignore_resource_deletion server feature is on, xDS client should keep the deleted cluster + * on empty response, and resume the normal work when CDS contains the cluster again. + * */ + @Test + public void cdsResourceDeleted_ignoreResourceDeletion() { + Assume.assumeTrue(ignoreResourceDeletion()); + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + verifyResourceMetadataRequested(CDS, CDS_RESOURCE); + + // Initial CDS response. + call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + // Empty LDS response does not delete the cluster. + call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); + + // The resource is still ACKED at VERSION_1 (no changes). + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + // onResourceDoesNotExist must not be called. + verify(ldsResourceWatcher, never()).onResourceDoesNotExist(CDS_RESOURCE); + + // Next update is correct, and contains the cluster again. + call.sendResponse(CDS, testClusterRoundRobin, VERSION_3, "0003"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_3, "0003", NODE); + verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); + verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_3, + TIME_INCREMENT * 3); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + verifyNoMoreInteractions(ldsResourceWatcher); + } + @Test + @SuppressWarnings("unchecked") public void multipleCdsWatchers() { String cdsResourceTwo = "cluster-bar.googleapis.com"; - CdsResourceWatcher watcher1 = mock(CdsResourceWatcher.class); - CdsResourceWatcher watcher2 = mock(CdsResourceWatcher.class); - xdsClient.watchCdsResource(CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchCdsResource(cdsResourceTwo, watcher1); - xdsClient.watchCdsResource(cdsResourceTwo, watcher2); + ResourceWatcher watcher1 = mock(ResourceWatcher.class); + ResourceWatcher watcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, Arrays.asList(CDS_RESOURCE, cdsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, cdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(cdsResourceTwo); verify(watcher2).onResourceDoesNotExist(cdsResourceTwo); @@ -1964,14 +2730,18 @@ public void multipleCdsWatchers() { Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, dnsHostAddr, dnsHostPort, "round_robin", null, null, false, null, null)), Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, null, true, - null, "envoy.transport_sockets.tls", null))); + null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); assertThat(cdsUpdate.lrsServerInfo()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); @@ -1980,8 +2750,12 @@ public void multipleCdsWatchers() { assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(lrsServerInfo); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); + assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); verify(watcher2).onChanged(cdsUpdateCaptor.capture()); @@ -1989,8 +2763,12 @@ public void multipleCdsWatchers() { assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(lrsServerInfo); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); + assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); assertThat(cdsUpdate.upstreamTlsContext()).isNull(); // Metadata of both clusters is stored. @@ -2001,7 +2779,8 @@ public void multipleCdsWatchers() { @Test public void edsResourceNotFound() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); Any clusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment( "cluster-bar.googleapis.com", ImmutableList.of(lbEndpointHealthy), @@ -2014,7 +2793,7 @@ public void edsResourceNotFound() { verifyResourceMetadataRequested(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); assertThat(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); @@ -2023,7 +2802,8 @@ public void edsResourceNotFound() { @Test public void edsResponseErrorHandling_allResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); call.sendResponse(EDS, ImmutableList.of(FAILING_ANY, FAILING_ANY), VERSION_1, "0000"); @@ -2039,7 +2819,8 @@ public void edsResponseErrorHandling_allResourcesFailedUnpack() { @Test public void edsResponseErrorHandling_someResourcesFailedUnpack() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); // Correct resource is in the middle to ensure processing continues on errors. @@ -2064,14 +2845,14 @@ public void edsResponseErrorHandling_someResourcesFailedUnpack() { * Tests a subscribed EDS resource transitioned to and from the invalid state. * * @see - * A40-csds-support.md. + * A40-csds-support.md */ @Test public void edsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchEdsResource("A", edsResourceWatcher); - xdsClient.watchEdsResource("B", edsResourceWatcher); - xdsClient.watchEdsResource("C", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(EDS, "A"); @@ -2129,30 +2910,48 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { @Test public void edsResourceFound() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); // Client sent an ACK EDS request. call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - validateTestClusterLoadAssigment(edsUpdateCaptor.getValue()); + validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); + } + + @Test + public void wrappedEdsResourceFound() { + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); + call.sendResponse(EDS, mf.buildWrappedResource(testClusterLoadAssignment), VERSION_1, "0000"); + + // Client sent an ACK EDS request. + call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); + verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); + validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); } @Test + @SuppressWarnings("unchecked") public void cachedEdsResource_data() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); // Client sent an ACK EDS request. call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); // Add another watcher. - EdsResourceWatcher watcher = mock(EdsResourceWatcher.class); - xdsClient.watchEdsResource(EDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); verify(watcher).onChanged(edsUpdateCaptor.capture()); - validateTestClusterLoadAssigment(edsUpdateCaptor.getValue()); + validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -2160,12 +2959,14 @@ public void cachedEdsResource_data() { } @Test + @SuppressWarnings("unchecked") public void cachedEdsResource_absent() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); - EdsResourceWatcher watcher = mock(EdsResourceWatcher.class); - xdsClient.watchEdsResource(EDS_RESOURCE, watcher); + ResourceWatcher watcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); verify(watcher).onResourceDoesNotExist(EDS_RESOURCE); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); @@ -2174,7 +2975,8 @@ public void cachedEdsResource_absent() { @Test public void edsResourceUpdated() { - DiscoveryRpcCall call = startResourceWatcher(EDS, EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); // Initial EDS response. @@ -2182,7 +2984,7 @@ public void edsResourceUpdated() { call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); - validateTestClusterLoadAssigment(edsUpdate); + validateGoldenClusterLoadAssignment(edsUpdate); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -2209,14 +3011,39 @@ public void edsResourceUpdated() { } @Test + public void edsDuplicateLocalityInTheSamePriority() { + DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, + edsResourceWatcher); + verifyResourceMetadataRequested(EDS, EDS_RESOURCE); + + // Updated EDS response. + Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, + ImmutableList.of( + mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 1), + mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", + mf.buildLbEndpoint("172.44.2.3", 8080, "healthy", 10), 2, 1) + ), + ImmutableList.of())); + call.sendResponse(EDS, updatedClusterLoadAssignment, "0", "0001"); + String errorMsg = "EDS response ClusterLoadAssignment" + + " \'cluster-load-assignment.googleapis.com\' " + + "validation error: ClusterLoadAssignment has duplicate " + + "locality:Locality{region=region2, zone=zone2, subZone=subzone2} for priority:1"; + call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0001", NODE, ImmutableList.of( + errorMsg)); + } + + @Test + @SuppressWarnings("unchecked") public void edsResourceDeletedByCds() { String resource = "backend-service.googleapis.com"; - CdsResourceWatcher cdsWatcher = mock(CdsResourceWatcher.class); - EdsResourceWatcher edsWatcher = mock(EdsResourceWatcher.class); - xdsClient.watchCdsResource(resource, cdsWatcher); - xdsClient.watchEdsResource(resource, edsWatcher); - xdsClient.watchCdsResource(CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchEdsResource(EDS_RESOURCE, edsResourceWatcher); + ResourceWatcher cdsWatcher = mock(ResourceWatcher.class); + ResourceWatcher edsWatcher = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),resource, cdsWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),resource, edsWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, resource); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); @@ -2226,15 +3053,15 @@ public void edsResourceDeletedByCds() { DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); List clusters = ImmutableList.of( Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, null, true, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null )), Any.pack(mf.buildEdsCluster(CDS_RESOURCE, EDS_RESOURCE, "round_robin", null, null, false, - null, "envoy.transport_sockets.tls", null))); + null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); verify(cdsWatcher).onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(null); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(lrsServerInfo); + assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(EDS_RESOURCE); @@ -2275,16 +3102,20 @@ public void edsResourceDeletedByCds() { clusters = ImmutableList.of( Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, null, true, null, - "envoy.transport_sockets.tls", null)), // no change + "envoy.transport_sockets.tls", null, null)), // no change Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, - "envoy.transport_sockets.tls", null + "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, clusters, VERSION_2, "0001"); verify(cdsResourceWatcher, times(2)).onChanged(cdsUpdateCaptor.capture()); assertThat(cdsUpdateCaptor.getValue().edsServiceName()).isNull(); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); + // Note that the endpoint must be deleted even if the ignore_resource_deletion feature. + // This happens because the cluster CDS_RESOURCE is getting replaced, and not deleted. + verify(edsResourceWatcher, never()).onResourceDoesNotExist(EDS_RESOURCE); + verify(edsResourceWatcher, never()).onResourceDoesNotExist(resource); verifyNoMoreInteractions(cdsWatcher, edsWatcher); - verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); + verifyResourceMetadataAcked( + EDS, EDS_RESOURCE, clusterLoadAssignments.get(0), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( EDS, resource, clusterLoadAssignments.get(1), VERSION_1, TIME_INCREMENT * 2); // no change verifyResourceMetadataAcked(CDS, resource, clusters.get(0), VERSION_2, TIME_INCREMENT * 3); @@ -2293,20 +3124,21 @@ public void edsResourceDeletedByCds() { } @Test + @SuppressWarnings("unchecked") public void multipleEdsWatchers() { String edsResourceTwo = "cluster-load-assignment-bar.googleapis.com"; - EdsResourceWatcher watcher1 = mock(EdsResourceWatcher.class); - EdsResourceWatcher watcher2 = mock(EdsResourceWatcher.class); - xdsClient.watchEdsResource(EDS_RESOURCE, edsResourceWatcher); - xdsClient.watchEdsResource(edsResourceTwo, watcher1); - xdsClient.watchEdsResource(edsResourceTwo, watcher2); + ResourceWatcher watcher1 = mock(ResourceWatcher.class); + ResourceWatcher watcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(EDS, Arrays.asList(EDS_RESOURCE, edsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); verifyResourceMetadataRequested(EDS, edsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(edsResourceTwo); verify(watcher2).onResourceDoesNotExist(edsResourceTwo); @@ -2317,7 +3149,7 @@ public void multipleEdsWatchers() { call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); - validateTestClusterLoadAssigment(edsUpdate); + validateGoldenClusterLoadAssignment(edsUpdate); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked( EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -2366,7 +3198,8 @@ public void useIndependentRpcContext() { CancellableContext cancellableContext = Context.current().withCancellation(); Context prevContext = cancellableContext.attach(); try { - DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC // should not be impacted. @@ -2383,10 +3216,11 @@ public void useIndependentRpcContext() { @Test public void streamClosedAndRetryWithBackoff() { InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchCdsResource(CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchEdsResource(EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + rdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "", "", NODE); @@ -2395,14 +3229,15 @@ public void streamClosedAndRetryWithBackoff() { // Management server closes the RPC stream with an error. call.sendError(Status.UNKNOWN.asException()); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNKNOWN); + verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) + .onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); verify(rdsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNKNOWN); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); verify(cdsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNKNOWN); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); verify(edsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNKNOWN); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); // Retry after backoff. inOrder.verify(backoffPolicyProvider).get(); @@ -2418,15 +3253,16 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); // Management server becomes unreachable. - call.sendError(Status.UNAVAILABLE.asException()); + String errorMsg = "my fault"; + call.sendError(Status.UNAVAILABLE.withDescription(errorMsg).asException()); verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); verify(cdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); verify(edsResourceWatcher, times(2)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); // Retry after backoff. inOrder.verify(backoffPolicy1).nextBackoffNanos(); @@ -2452,18 +3288,20 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(RDS, RDS_RESOURCE, "5", "6764", NODE); call.sendError(Status.DEADLINE_EXCEEDED.asException()); - verify(ldsResourceWatcher, times(3)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); - verify(rdsResourceWatcher, times(3)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); + verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); verify(cdsResourceWatcher, times(3)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); verify(edsResourceWatcher, times(3)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); - // Reset backoff sequence and retry immediately. + // Reset backoff sequence and retry after backoff. inOrder.verify(backoffPolicyProvider).get(); - fakeClock.runDueTasks(); + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); + fakeClock.forwardNanos(20L); call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "63", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "5", "", NODE); @@ -2472,21 +3310,19 @@ public void streamClosedAndRetryWithBackoff() { // Management server becomes unreachable again. call.sendError(Status.UNAVAILABLE.asException()); - verify(ldsResourceWatcher, times(4)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); - verify(rdsResourceWatcher, times(4)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); + verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); verify(cdsResourceWatcher, times(4)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); verify(edsResourceWatcher, times(4)).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); // Retry after backoff. inOrder.verify(backoffPolicy2).nextBackoffNanos(); retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); - assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); - fakeClock.forwardNanos(20L); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(200L); + fakeClock.forwardNanos(200L); call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "63", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "5", "", NODE); @@ -2498,22 +3334,29 @@ public void streamClosedAndRetryWithBackoff() { @Test public void streamClosedAndRetryRaceWithAddRemoveWatchers() { - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), + LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendError(Status.UNAVAILABLE.asException()); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) + .onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); verify(rdsResourceWatcher).onError(errorCaptor.capture()); - assertThat(errorCaptor.getValue().getCode()).isEqualTo(Code.UNAVAILABLE); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); ScheduledTask retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(10L); - xdsClient.cancelLdsResourceWatch(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.cancelRdsResourceWatch(RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchCdsResource(CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchEdsResource(EDS_RESOURCE, edsResourceWatcher); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), + LDS_RESOURCE, ldsResourceWatcher); + xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), + RDS_RESOURCE, rdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), + CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), + EDS_RESOURCE, edsResourceWatcher); fakeClock.forwardNanos(10L); call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); @@ -2530,10 +3373,11 @@ public void streamClosedAndRetryRaceWithAddRemoveWatchers() { @Test public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedResources() { - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchRdsResource(RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchCdsResource(CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchEdsResource(EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); ScheduledTask ldsResourceTimeout = Iterables.getOnlyElement(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); @@ -2552,6 +3396,12 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe call.sendError(Status.UNAVAILABLE.asException()); assertThat(cdsResourceTimeout.isCancelled()).isTrue(); assertThat(edsResourceTimeout.isCancelled()).isTrue(); + verify(ldsResourceWatcher, never()).onError(errorCaptor.capture()); + verify(rdsResourceWatcher, never()).onError(errorCaptor.capture()); + verify(cdsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + verify(edsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); fakeClock.forwardNanos(10L); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).hasSize(0); @@ -2562,9 +3412,9 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe @Test public void reportLoadStatsToServer() { - xdsClient.watchLdsResource(LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); String clusterName = "cluster-foo.googleapis.com"; - ClusterDropStats dropStats = xdsClient.addClusterDropStats(lrsServerInfo, clusterName, null); + ClusterDropStats dropStats = xdsClient.addClusterDropStats(xdsServerInfo, clusterName, null); LrsRpcCall lrsCall = loadReportCalls.poll(); lrsCall.verifyNextReportClusters(Collections.emptyList()); // initial LRS request @@ -2589,8 +3439,9 @@ public void reportLoadStatsToServer() { @Test public void serverSideListenerFound() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = - startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); + XdsClientImplTestBase.DiscoveryRpcCall call = + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); @@ -2602,10 +3453,9 @@ public void serverSideListenerFound() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // Client sends an ACK LDS request. - call.verifyRequest( - ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); + call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); EnvoyServerProtoData.Listener parsedListener = ldsUpdateCaptor.getValue().listener(); assertThat(parsedListener.name()).isEqualTo(LISTENER_RESOURCE); @@ -2625,8 +3475,9 @@ public void serverSideListenerFound() { @Test public void serverSideListenerNotFound() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = - startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); + XdsClientImplTestBase.DiscoveryRpcCall call = + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); @@ -2638,13 +3489,12 @@ public void serverSideListenerNotFound() { Message listener = mf.buildListenerWithFilterChain( "grpc/server?xds.resource.listening_address=0.0.0.0:8000", 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // Client sends an ACK LDS request. - call.verifyRequest( - ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); + call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verifyNoInteractions(ldsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LISTENER_RESOURCE); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @@ -2652,8 +3502,9 @@ public void serverSideListenerNotFound() { @Test public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = - startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); + XdsClientImplTestBase.DiscoveryRpcCall call = + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); @@ -2665,24 +3516,22 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. String errorMsg = "LDS response Listener \'grpc/server?xds.resource.listening_address=" + "0.0.0.0:7000\' validation error: " + "common-tls-context is required in downstream-tls-context"; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(ldsResourceWatcher).onError(captor.capture()); - Status errorStatus = captor.getValue(); - assertThat(errorStatus.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); - assertThat(errorStatus.getDescription()).isEqualTo(errorMsg); + verify(ldsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); } @Test public void serverSideListenerResponseErrorHandling_badTransportSocketName() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = - startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); + XdsClientImplTestBase.DiscoveryRpcCall call = + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); @@ -2694,85 +3543,205 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. String errorMsg = "LDS response Listener \'grpc/server?xds.resource.listening_address=" + "0.0.0.0:7000\' validation error: " + "transport-socket with name envoy.transport_sockets.bad1 not supported."; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of( errorMsg)); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(ldsResourceWatcher).onError(captor.capture()); - Status errorStatus = captor.getValue(); - assertThat(errorStatus.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); - assertThat(errorStatus.getDescription()).isEqualTo(errorMsg); + verify(ldsResourceWatcher).onError(errorCaptor.capture()); + verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + } + + @Test + public void sendingToStoppedServer() throws Exception { + try { + // Establish the adsStream object + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + DiscoveryRpcCall unused = resourceDiscoveryCalls.take(); // clear this entry + + // Shutdown server and initiate a request + xdsServer.shutdownNow(); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + fakeClock.forwardTime(14, TimeUnit.SECONDS); + + // Restart the server + xdsServer = cleanupRule.register( + InProcessServerBuilder + .forName(serverName) + .addService(adsService) + .addService(lrsService) + .directExecutor() + .build() + .start()); + fakeClock.forwardTime(5, TimeUnit.SECONDS); + verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); + fakeClock.forwardTime(20, TimeUnit.SECONDS); // Trigger rpcRetryTimer + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); + if (call == null) { // The first rpcRetry may have happened before the channel was ready + fakeClock.forwardTime(50, TimeUnit.SECONDS); + call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); + } + + // NOTE: There is a ScheduledExecutorService that may get involved due to the reconnect + // so you cannot rely on the logic being single threaded. The timeout() in verifyRequest + // is therefore necessary to avoid flakiness. + // Send a response and do verifications + call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE); + verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 1, 0, 0); + } catch (Throwable t) { + throw t; // This allows putting a breakpoint here for debugging + } + } + + @Test + public void sendToBadUrl() throws Exception { + // Setup xdsClient to fail on stream creation + XdsClientImpl client = createXdsClient("some. garbage"); + + client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + client.shutdown(); } - private DiscoveryRpcCall startResourceWatcher( - ResourceType type, String name, ResourceWatcher watcher) { + @Test + public void sendToNonexistentHost() throws Exception { + // Setup xdsClient to fail on stream creation + XdsClientImpl client = createXdsClient("some.garbage"); + client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + + verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + fakeClock.forwardTime(50, TimeUnit.SECONDS); // Trigger rpcRetry if appropriate + assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + client.shutdown(); + } + + private XdsClientImpl createXdsClient(String serverUri) { + BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); + return new XdsClientImpl( + DEFAULT_XDS_CHANNEL_FACTORY, + bootstrapInfo, + Context.ROOT, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + tlsContextManager); + } + + private BootstrapInfo buildBootStrap(String serverUri) { + + ServerInfo xdsServerInfo = ServerInfo.create(serverUri, CHANNEL_CREDENTIALS, + ignoreResourceDeletion()); + + return Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "authority.xds.com", + AuthorityInfo.create( + "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of("cert-instance-name", + CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) + .build(); + } + + private DiscoveryRpcCall startResourceWatcher( + XdsResourceType type, String name, ResourceWatcher watcher) { FakeClock.TaskFilter timeoutTaskFilter; - switch (type) { - case LDS: + switch (type.typeName()) { + case "LDS": timeoutTaskFilter = LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; - xdsClient.watchLdsResource(name, (LdsResourceWatcher) watcher); + xdsClient.watchXdsResource(type, name, watcher); break; - case RDS: + case "RDS": timeoutTaskFilter = RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; - xdsClient.watchRdsResource(name, (RdsResourceWatcher) watcher); + xdsClient.watchXdsResource(type, name, watcher); break; - case CDS: + case "CDS": timeoutTaskFilter = CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; - xdsClient.watchCdsResource(name, (CdsResourceWatcher) watcher); + xdsClient.watchXdsResource(type, name, watcher); break; - case EDS: + case "EDS": timeoutTaskFilter = EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; - xdsClient.watchEdsResource(name, (EdsResourceWatcher) watcher); + xdsClient.watchXdsResource(type, name, watcher); break; - case UNKNOWN: default: throw new AssertionError("should never be here"); } + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(type, Collections.singletonList(name), "", "", NODE); ScheduledTask timeoutTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(timeoutTaskFilter)); assertThat(timeoutTask.getDelay(TimeUnit.SECONDS)) - .isEqualTo(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC); + .isEqualTo(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC); return call; } protected abstract static class DiscoveryRpcCall { - protected abstract void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, Node node); + protected void verifyRequest( + XdsResourceType type, List resources, String versionInfo, String nonce, + Node node) { + throw new UnsupportedOperationException(); + } protected void verifyRequest( - ResourceType type, String resource, String versionInfo, String nonce, Node node) { + XdsResourceType type, String resource, String versionInfo, String nonce, Node node) { verifyRequest(type, ImmutableList.of(resource), versionInfo, nonce, node); } - protected abstract void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, Node node, - List errorMessages); + protected void verifyRequestNack( + XdsResourceType type, List resources, String versionInfo, String nonce, + Node node, List errorMessages) { + throw new UnsupportedOperationException(); + } protected void verifyRequestNack( - ResourceType type, String resource, String versionInfo, String nonce, Node node, + XdsResourceType type, String resource, String versionInfo, String nonce, Node node, List errorMessages) { verifyRequestNack(type, ImmutableList.of(resource), versionInfo, nonce, node, errorMessages); } - protected abstract void verifyNoMoreRequest(); + protected void verifyNoMoreRequest() { + throw new UnsupportedOperationException(); + } - protected abstract void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce); + protected void sendResponse( + XdsResourceType type, List resources, String versionInfo, String nonce) { + throw new UnsupportedOperationException(); + } - protected void sendResponse(ResourceType type, Any resource, String versionInfo, String nonce) { + protected void sendResponse(XdsResourceType type, Any resource, String versionInfo, + String nonce) { sendResponse(type, ImmutableList.of(resource), versionInfo, nonce); } - protected abstract void sendError(Throwable t); + protected void sendError(Throwable t) { + throw new UnsupportedOperationException(); + } - protected abstract void sendCompleted(); + protected void sendCompleted() { + throw new UnsupportedOperationException(); + } } protected abstract static class LrsRpcCall { @@ -2780,15 +3749,21 @@ protected abstract static class LrsRpcCall { /** * Verifies a LRS request has been sent with ClusterStats of the given list of clusters. */ - protected abstract void verifyNextReportClusters(List clusters); + protected void verifyNextReportClusters(List clusters) { + throw new UnsupportedOperationException(); + } - protected abstract void sendResponse(List clusters, long loadReportIntervalNano); + protected void sendResponse(List clusters, long loadReportIntervalNano) { + throw new UnsupportedOperationException(); + } } protected abstract static class MessageFactory { /** Throws {@link InvalidProtocolBufferException} on {@link Any#unpack(Class)}. */ protected static final Any FAILING_ANY = Any.newBuilder().setTypeUrl("fake").build(); + protected abstract Any buildWrappedResource(Any originalResource); + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { return buildListenerWithApiListener( name, routeConfiguration, Collections.emptyList()); @@ -2827,7 +3802,7 @@ protected abstract Message buildVirtualHost( protected abstract Message buildEdsCluster(String clusterName, @Nullable String edsServiceName, String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, - @Nullable Message circuitBreakers); + @Nullable Message circuitBreakers, @Nullable Message outlierDetection); protected abstract Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, @@ -2881,4 +3856,83 @@ protected abstract Message buildHttpConnectionManagerFilter( protected abstract Message buildTerminalFilter(); } + + @Test + public void dropXdsV2Lds() { + startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + assertThat(resourceDiscoveryCallsV2).isEmpty(); + assertThat(loadReportCallsV2).isEmpty(); + } + + @Test + public void dropXdsV2Cds() { + startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + assertThat(resourceDiscoveryCallsV2).isEmpty(); + assertThat(loadReportCallsV2).isEmpty(); + } + + @Test + public void dropXdsV2Rds() { + startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); + assertThat(resourceDiscoveryCallsV2).isEmpty(); + assertThat(loadReportCallsV2).isEmpty(); + } + + @Test + public void dropXdsV2Eds() { + startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); + assertThat(resourceDiscoveryCallsV2).isEmpty(); + assertThat(loadReportCallsV2).isEmpty(); + } + + protected BindableService createAdsServiceV2() { + return new AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase() { + @Override + public StreamObserver streamAggregatedResources( + final StreamObserver responseObserver) { + assertThat(adsEnded.get()).isTrue(); // ensure previous call was ended + adsEnded.set(false); + @SuppressWarnings("unchecked") + StreamObserver requestObserver = + mock(StreamObserver.class); + DiscoveryRpcCall call = new DiscoveryRpcCall() {}; + resourceDiscoveryCallsV2.offer(call); + Context.current().addListener( + new Context.CancellationListener() { + @Override + public void cancelled(Context context) { + adsEnded.set(true); + } + }, MoreExecutors.directExecutor()); + return requestObserver; + } + }; + } + + protected BindableService createLrsServiceV2() { + return new LoadReportingServiceGrpc.LoadReportingServiceImplBase() { + @Override + public + StreamObserver + streamLoadStats( + StreamObserver + responseObserver) { + assertThat(lrsEnded.get()).isTrue(); + lrsEnded.set(false); + @SuppressWarnings("unchecked") + StreamObserver requestObserver = + mock(StreamObserver.class); + LrsRpcCall call = new LrsRpcCall() {}; + Context.current().addListener( + new Context.CancellationListener() { + @Override + public void cancelled(Context context) { + lrsEnded.set(true); + } + }, MoreExecutors.directExecutor()); + loadReportCallsV2.offer(call); + return requestObserver; + } + }; + } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java similarity index 94% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java index 6a75d9ab068..eba41dc4989 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java @@ -41,6 +41,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.ConfigSource; @@ -86,6 +87,7 @@ import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.envoyproxy.envoy.service.discovery.v3.Resource; import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc.LoadReportingServiceImplBase; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; @@ -97,7 +99,6 @@ import io.grpc.Context.CancellationListener; import io.grpc.Status; import io.grpc.stub.StreamObserver; -import io.grpc.xds.AbstractXdsClient.ResourceType; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -107,15 +108,27 @@ import java.util.Set; import javax.annotation.Nullable; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentMatcher; import org.mockito.InOrder; +import org.mockito.Mockito; /** - * Tests for {@link ClientXdsClient} with protocol version v3. + * Tests for {@link XdsClientImpl} with protocol version v3. */ -@RunWith(JUnit4.class) -public class ClientXdsClientV3Test extends ClientXdsClientTestBase { +@RunWith(Parameterized.class) +public class XdsClientImplV3Test extends XdsClientImplTestBase { + + /** Parameterized test cases. */ + @Parameters(name = "ignoreResourceDeletion={0}") + public static Iterable data() { + return ImmutableList.of(false, true); + } + + @Parameter + public boolean ignoreResourceDeletion; @Override protected BindableService createAdsService() { @@ -175,6 +188,11 @@ protected boolean useProtocolV3() { return true; } + @Override + protected boolean ignoreResourceDeletion() { + return ignoreResourceDeletion; + } + private static class DiscoveryRpcCallV3 extends DiscoveryRpcCall { StreamObserver requestObserver; StreamObserver responseObserver; @@ -187,17 +205,17 @@ private DiscoveryRpcCallV3(StreamObserver requestObserver, @Override protected void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node) { - verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( + verify(requestObserver, Mockito.timeout(2000)).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNode(), versionInfo, resources, type.typeUrl(), nonce, null, null))); } @Override protected void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node, List errorMessages) { - verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( + verify(requestObserver, Mockito.timeout(2000)).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNode(), versionInfo, resources, type.typeUrl(), nonce, Code.INVALID_ARGUMENT_VALUE, errorMessages))); } @@ -209,7 +227,7 @@ protected void verifyNoMoreRequest() { @Override protected void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce) { + XdsResourceType type, List resources, String versionInfo, String nonce) { DiscoveryResponse response = DiscoveryResponse.newBuilder() .setVersionInfo(versionInfo) @@ -261,6 +279,13 @@ protected void sendResponse(List clusters, long loadReportIntervalNano) private static class MessageFactoryV3 extends MessageFactory { + @Override + protected Any buildWrappedResource(Any originalResource) { + return Any.pack(Resource.newBuilder() + .setResource(originalResource) + .build()); + } + @SuppressWarnings("unchecked") @Override protected Message buildListenerWithApiListener( @@ -452,10 +477,10 @@ protected Message buildEdsCluster(String clusterName, @Nullable String edsServic String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, - @Nullable Message circuitBreakers) { + @Nullable Message circuitBreakers, @Nullable Message outlierDetection) { Cluster.Builder builder = initClusterBuilder( clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, - enableLrs, upstreamTlsContext, transportSocketName, circuitBreakers); + enableLrs, upstreamTlsContext, transportSocketName, circuitBreakers, outlierDetection); builder.setType(DiscoveryType.EDS); EdsClusterConfig.Builder edsClusterConfigBuilder = EdsClusterConfig.newBuilder(); edsClusterConfigBuilder.setEdsConfig( @@ -474,7 +499,7 @@ protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { Cluster.Builder builder = initClusterBuilder( clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, - enableLrs, upstreamTlsContext, "envoy.transport_sockets.tls", circuitBreakers); + enableLrs, upstreamTlsContext, "envoy.transport_sockets.tls", circuitBreakers, null); builder.setType(DiscoveryType.LOGICAL_DNS); builder.setLoadAssignment( ClusterLoadAssignment.newBuilder().addEndpoints( @@ -494,7 +519,7 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, ClusterConfig clusterConfig = ClusterConfig.newBuilder().addAllClusters(clusters).build(); CustomClusterType type = CustomClusterType.newBuilder() - .setName(ClientXdsClient.AGGREGATE_CLUSTER_TYPE_NAME) + .setName(XdsResourceType.AGGREGATE_CLUSTER_TYPE_NAME) .setTypedConfig(Any.pack(clusterConfig)) .build(); Cluster.Builder builder = Cluster.newBuilder().setName(clusterName).setClusterType(type); @@ -515,7 +540,7 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, - @Nullable Message circuitBreakers) { + @Nullable Message circuitBreakers, @Nullable Message outlierDetection) { Cluster.Builder builder = Cluster.newBuilder(); builder.setName(clusterName); if (lbPolicy.equals("round_robin")) { @@ -543,6 +568,9 @@ private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, if (circuitBreakers != null) { builder.setCircuitBreakers((CircuitBreakers) circuitBreakers); } + if (outlierDetection != null) { + builder.setOutlierDetection((OutlierDetection) outlierDetection); + } return builder; } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java deleted file mode 100644 index 9ea9b1f8dc2..00000000000 --- a/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import com.google.protobuf.Any; -import io.envoyproxy.envoy.config.core.v3.Address; -import io.envoyproxy.envoy.config.listener.v3.ApiListener; -import io.envoyproxy.envoy.config.listener.v3.FilterChain; -import io.envoyproxy.envoy.config.listener.v3.Listener; -import io.envoyproxy.envoy.config.route.v3.Route; -import io.envoyproxy.envoy.config.route.v3.RouteAction; -import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; -import io.envoyproxy.envoy.config.route.v3.RouteMatch; -import io.envoyproxy.envoy.config.route.v3.VirtualHost; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; -import io.grpc.xds.EnvoyProtoData.Node; -import java.util.List; - -/** - * Helper methods for building protobuf messages with custom data for xDS protocols. - */ -// TODO(chengyuanzhang, sanjaypujare): delete this class, should not dump everything here. -class XdsClientTestHelper { - static DiscoveryResponse buildDiscoveryResponse(String versionInfo, - List resources, String typeUrl, String nonce) { - return - DiscoveryResponse.newBuilder() - .setVersionInfo(versionInfo) - .setTypeUrl(typeUrl) - .addAllResources(resources) - .setNonce(nonce) - .build(); - } - - static io.envoyproxy.envoy.api.v2.DiscoveryResponse buildDiscoveryResponseV2(String versionInfo, - List resources, String typeUrl, String nonce) { - return - io.envoyproxy.envoy.api.v2.DiscoveryResponse.newBuilder() - .setVersionInfo(versionInfo) - .setTypeUrl(typeUrl) - .addAllResources(resources) - .setNonce(nonce) - .build(); - } - - static DiscoveryRequest buildDiscoveryRequest(Node node, String versionInfo, - List resourceNames, String typeUrl, String nonce) { - return - DiscoveryRequest.newBuilder() - .setVersionInfo(versionInfo) - .setNode(node.toEnvoyProtoNode()) - .setTypeUrl(typeUrl) - .addAllResourceNames(resourceNames) - .setResponseNonce(nonce) - .build(); - } - - static Listener buildListener(String name, com.google.protobuf.Any apiListener) { - return - Listener.newBuilder() - .setName(name) - .setAddress(Address.getDefaultInstance()) - .addFilterChains(FilterChain.getDefaultInstance()) - .setApiListener(ApiListener.newBuilder().setApiListener(apiListener)) - .build(); - } - - static io.envoyproxy.envoy.api.v2.Listener buildListenerV2( - String name, com.google.protobuf.Any apiListener) { - return - io.envoyproxy.envoy.api.v2.Listener.newBuilder() - .setName(name) - .setAddress(io.envoyproxy.envoy.api.v2.core.Address.getDefaultInstance()) - .addFilterChains(io.envoyproxy.envoy.api.v2.listener.FilterChain.getDefaultInstance()) - .setApiListener(io.envoyproxy.envoy.config.listener.v2.ApiListener.newBuilder() - .setApiListener(apiListener)) - .build(); - } - - static RouteConfiguration buildRouteConfiguration(String name, - List virtualHosts) { - return - RouteConfiguration.newBuilder() - .setName(name) - .addAllVirtualHosts(virtualHosts) - .build(); - } - - static io.envoyproxy.envoy.api.v2.RouteConfiguration buildRouteConfigurationV2(String name, - List virtualHosts) { - return - io.envoyproxy.envoy.api.v2.RouteConfiguration.newBuilder() - .setName(name) - .addAllVirtualHosts(virtualHosts) - .build(); - } - - static VirtualHost buildVirtualHost(List domains, String clusterName) { - return VirtualHost.newBuilder() - .setName("virtualhost00.googleapis.com") // don't care - .addAllDomains(domains) - .addRoutes( - Route.newBuilder() - .setRoute(RouteAction.newBuilder().setCluster(clusterName)) - .setMatch(RouteMatch.newBuilder().setPrefix(""))) - .build(); - } - - static io.envoyproxy.envoy.api.v2.route.VirtualHost buildVirtualHostV2( - List domains, String clusterName) { - return io.envoyproxy.envoy.api.v2.route.VirtualHost.newBuilder() - .setName("virtualhost00.googleapis.com") // don't care - .addAllDomains(domains) - .addRoutes( - io.envoyproxy.envoy.api.v2.route.Route.newBuilder() - .setRoute( - io.envoyproxy.envoy.api.v2.route.RouteAction.newBuilder() - .setCluster(clusterName)) - .setMatch(io.envoyproxy.envoy.api.v2.route.RouteMatch.newBuilder().setPrefix(""))) - .build(); - } -} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index e7d090b6cd1..f40a28158f1 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -18,13 +18,14 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector.NO_FILTER_CHAIN; -import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -43,13 +44,13 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; -import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProvider; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SslContextProvider; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -70,7 +71,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; - import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -83,6 +83,7 @@ public class XdsClientWrapperForServerSdsTestMisc { private static final int PORT = 7000; + private static final int START_WAIT_AFTER_LISTENER_MILLIS = 100; private EmbeddedChannel channel; private ChannelPipeline pipeline; @@ -158,15 +159,16 @@ public void run() { String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT); - EnvoyServerProtoData.Listener listener = + EnvoyServerProtoData.Listener tcpListener = EnvoyServerProtoData.Listener.create( "listener1", "10.1.2.3", ImmutableList.of(), null); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(tcpListener); xdsClient.ldsWatcher.onChanged(listenerUpdate); - start.get(5, TimeUnit.SECONDS); + verify(listener, timeout(5000)).onServing(); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); FilterChainSelector selector = selectorManager.getSelectorToUpdateSelector(); assertThat(getSslContextProviderSupplier(selector)).isNull(); } @@ -186,56 +188,9 @@ public void run() { }); String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsWatched); + verify(listener, timeout(5000)).onNotServing(any()); try { - start.get(5, TimeUnit.SECONDS); - fail("Start should throw exception"); - } catch (TimeoutException ex) { - assertThat(start.isDone()).isFalse(); - } - assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); - } - - @Test - public void registerServerWatcher_notifyInternalError() throws Exception { - final SettableFuture start = SettableFuture.create(); - Executors.newSingleThreadExecutor().execute(new Runnable() { - @Override - public void run() { - try { - start.set(xdsServerWrapper.start()); - } catch (Exception ex) { - start.setException(ex); - } - } - }); - xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onError(Status.INTERNAL); - try { - start.get(5, TimeUnit.SECONDS); - fail("Start should throw exception"); - } catch (TimeoutException ex) { - assertThat(start.isDone()).isFalse(); - } - assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); - } - - @Test - public void registerServerWatcher_notifyPermDeniedError() throws Exception { - final SettableFuture start = SettableFuture.create(); - Executors.newSingleThreadExecutor().execute(new Runnable() { - @Override - public void run() { - try { - start.set(xdsServerWrapper.start()); - } catch (Exception ex) { - start.setException(ex); - } - } - }); - xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); - try { - start.get(5, TimeUnit.SECONDS); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); fail("Start should throw exception"); } catch (TimeoutException ex) { assertThat(start.isDone()).isFalse(); @@ -321,23 +276,6 @@ public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } - @Test - public void releaseOldSupplierOnPermDeniedError_verifyClose() throws Exception { - SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); - when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) - .thenReturn(sslContextProvider1); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - localAddress = new InetSocketAddress(ipLocalAddress, PORT); - sendListenerUpdate(localAddress, tlsContext1, null, - tlsContextManager); - SslContextProviderSupplier returnedSupplier = - getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); - assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); - callUpdateSslContext(returnedSupplier); - xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); - verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); - } - @Test public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); @@ -380,7 +318,8 @@ public void run() { XdsServerTestHelper .generateListenerUpdate(xdsClient, ImmutableList.of(), tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); - start.get(5, TimeUnit.SECONDS); + verify(listener, timeout(5000)).onServing(); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); channel = new EmbeddedChannel() { diff --git a/xds/src/test/java/io/grpc/xds/XdsCredentialsRegistryTest.java b/xds/src/test/java/io/grpc/xds/XdsCredentialsRegistryTest.java new file mode 100644 index 00000000000..facaffc67a2 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsCredentialsRegistryTest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelCredentials; +import io.grpc.xds.XdsCredentialsProvider; +import io.grpc.xds.XdsCredentialsRegistry; +import io.grpc.xds.internal.GoogleDefaultXdsCredentialsProvider; +import io.grpc.xds.internal.InsecureXdsCredentialsProvider; +import io.grpc.xds.internal.TlsXdsCredentialsProvider; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link XdsCredentialsRegistry}. */ +@RunWith(JUnit4.class) +public class XdsCredentialsRegistryTest { + + @Test + public void register_unavailableProviderThrows() { + XdsCredentialsRegistry reg = new XdsCredentialsRegistry(); + try { + reg.register(new BaseCredsProvider(false, 5, "creds")); + fail("Should throw"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("isAvailable() returned false"); + } + assertThat(reg.providers()).isEmpty(); + } + + @Test + public void deregister() { + XdsCredentialsRegistry reg = new XdsCredentialsRegistry(); + String credsName = "sampleCredsName"; + XdsCredentialsProvider p1 = new BaseCredsProvider(true, 5, credsName); + XdsCredentialsProvider p2 = new BaseCredsProvider(true, 5, credsName); + XdsCredentialsProvider p3 = new BaseCredsProvider(true, 5, credsName); + reg.register(p1); + reg.register(p2); + reg.register(p3); + assertThat(reg.getProvider(credsName)).isSameInstanceAs(p1); + reg.deregister(p2); + assertThat(reg.getProvider(credsName)).isSameInstanceAs(p1); + reg.deregister(p1); + assertThat(reg.getProvider(credsName)).isSameInstanceAs(p3); + + } + + @Test + public void provider_sorted() { + XdsCredentialsRegistry reg = new XdsCredentialsRegistry(); + String credsName = "sampleCredsName"; + XdsCredentialsProvider p1 = new BaseCredsProvider(true, 5, credsName); + XdsCredentialsProvider p2 = new BaseCredsProvider(true, 3, credsName); + XdsCredentialsProvider p3 = new BaseCredsProvider(true, 8, credsName); + XdsCredentialsProvider p4 = new BaseCredsProvider(true, 3, credsName); + XdsCredentialsProvider p5 = new BaseCredsProvider(true, 8, credsName); + reg.register(p1); + reg.register(p2); + reg.register(p3); + reg.register(p4); + reg.register(p5); + assertThat(reg.getProvider(credsName)).isSameInstanceAs(p3); + } + + @Test + public void channelCredentials_successful() { + XdsCredentialsRegistry registry = new XdsCredentialsRegistry(); + String credsName = "sampleCredsName"; + + registry.register( + new BaseCredsProvider(true, 5, credsName) { + @Override + public ChannelCredentials newChannelCredentials(Map config) { + return new SampleChannelCredentials(config); + } + }); + + ImmutableMap sampleConfig = ImmutableMap.of("a", "b"); + ChannelCredentials creds = registry.providers().get(credsName) + .newChannelCredentials(sampleConfig); + assertSame(SampleChannelCredentials.class, creds.getClass()); + assertEquals(sampleConfig, ((SampleChannelCredentials)creds).getConfig()); + } + + @Test + public void channelCredentials_multiSuccessful() { + XdsCredentialsRegistry registry = new XdsCredentialsRegistry(); + String credsName1 = "sampleCreds1"; + String credsName2 = "sampleCreds2"; + registry.register( + new BaseCredsProvider(true, 5, credsName1) { + @Override + public ChannelCredentials newChannelCredentials(Map config) { + return null; + } + }); + + registry.register( + new BaseCredsProvider(true, 7, credsName2) { + @Override + public ChannelCredentials newChannelCredentials(Map config) { + return new SampleChannelCredentials(config); + } + }); + + assertThat(registry.getProvider(credsName1).newChannelCredentials(null)).isNull(); + assertThat(registry.getProvider(credsName1).getName()).isEqualTo(credsName1); + assertThat(registry.getProvider(credsName2).newChannelCredentials(null)).isNotNull(); + assertThat(registry.getProvider(credsName2).getName()).isEqualTo(credsName2); + } + + @Test + public void defaultRegistry_providers() { + Map providers = + XdsCredentialsRegistry.getDefaultRegistry().providers(); + assertThat(providers).hasSize(3); + assertThat(providers.get("google_default").getClass()) + .isEqualTo(GoogleDefaultXdsCredentialsProvider.class); + assertThat(providers.get("insecure").getClass()) + .isEqualTo(InsecureXdsCredentialsProvider.class); + assertThat(providers.get("tls").getClass()) + .isEqualTo(TlsXdsCredentialsProvider.class); + } + + @Test + public void getClassesViaHardcoded_classesPresent() throws Exception { + List> classes = XdsCredentialsRegistry.getHardCodedClasses(); + assertThat(classes).containsExactly( + GoogleDefaultXdsCredentialsProvider.class, + InsecureXdsCredentialsProvider.class, + TlsXdsCredentialsProvider.class); + } + + @Test + public void getProvider_null() { + try { + XdsCredentialsRegistry.getDefaultRegistry().getProvider(null); + fail("Should throw"); + } catch (NullPointerException e) { + assertThat(e).hasMessageThat().contains("name"); + } + } + + + private static class BaseCredsProvider extends XdsCredentialsProvider { + private final boolean isAvailable; + private final int priority; + private final String name; + + public BaseCredsProvider(boolean isAvailable, int priority, String name) { + this.isAvailable = isAvailable; + this.priority = priority; + this.name = name; + } + + @Override + protected String getName() { + return name; + } + + @Override + public boolean isAvailable() { + return isAvailable; + } + + @Override + public int priority() { + return priority; + } + + @Override + public ChannelCredentials newChannelCredentials(Map config) { + throw new UnsupportedOperationException(); + } + } + + private static class SampleChannelCredentials extends ChannelCredentials { + private final Map config; + + SampleChannelCredentials(Map config) { + this.config = config; + } + + public Map getConfig() { + return config; + } + + @Override + public ChannelCredentials withoutBearerTokens() { + return this; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index a17a043af5e..b6f8b3c3663 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -86,7 +86,9 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.io.IOException; import java.util.ArrayList; @@ -143,7 +145,7 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private final TestChannel channel = new TestChannel(); private BootstrapInfo bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))) + "td.googleapis.com", InsecureChannelCredentials.create()))) .node(Node.newBuilder().build()) .build(); private String expectedLdsResourceName = AUTHORITY; @@ -167,7 +169,8 @@ public void setUp() { FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( new FaultFilter(mockRandom, new AtomicLong()), RouterFilter.INSTANCE); - resolver = new XdsNameResolver(null, AUTHORITY, serviceConfigParser, syncContext, scheduler, + resolver = new XdsNameResolver(null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, null); } @@ -200,7 +203,8 @@ public ObjectPool getOrCreate() throws XdsInitializationException { throw new XdsInitializationException("Fail to read bootstrap file"); } }; - resolver = new XdsNameResolver(null, AUTHORITY, serviceConfigParser, syncContext, scheduler, + resolver = new XdsNameResolver(null, AUTHORITY, null, + serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); @@ -213,7 +217,7 @@ public ObjectPool getOrCreate() throws XdsInitializationException { @Test public void resolving_withTargetAuthorityNotFound() { resolver = new XdsNameResolver( - "notfound.google.com", AUTHORITY, serviceConfigParser, syncContext, scheduler, + "notfound.google.com", AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); @@ -227,14 +231,15 @@ public void resolving_withTargetAuthorityNotFound() { public void resolving_noTargetAuthority_templateWithoutXdstp() { bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))) + "td.googleapis.com", InsecureChannelCredentials.create()))) .node(Node.newBuilder().build()) .clientDefaultListenerResourceNameTemplate("%s/id=1") .build(); String serviceAuthority = "[::FFFF:129.144.52.38]:80"; expectedLdsResourceName = "[::FFFF:129.144.52.38]:80/id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, + null, serviceAuthority, null, serviceConfigParser, syncContext, + scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); @@ -244,7 +249,7 @@ public void resolving_noTargetAuthority_templateWithoutXdstp() { public void resolving_noTargetAuthority_templateWithXdstp() { bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))) + "td.googleapis.com", InsecureChannelCredentials.create()))) .node(Node.newBuilder().build()) .clientDefaultListenerResourceNameTemplate( "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/%s?id=1") @@ -254,7 +259,7 @@ public void resolving_noTargetAuthority_templateWithXdstp() { "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, serviceConfigParser, syncContext, scheduler, + null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); @@ -271,13 +276,13 @@ public void resolving_targetAuthorityInAuthoritiesMap() { .authorities( ImmutableMap.of(targetAuthority, AuthorityInfo.create( "xdstp://" + targetAuthority + "/envoy.config.listener.v3.Listener/%s?foo=1&bar=2", - ImmutableList.of(ServerInfo.create( + ImmutableList.of(ServerInfo.create( "td.googleapis.com", InsecureChannelCredentials.create(), true))))) .build(); expectedLdsResourceName = "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?bar=2&foo=1"; // query param canonified resolver = new XdsNameResolver( - "xds.authority.com", serviceAuthority, serviceConfigParser, syncContext, scheduler, + "xds.authority.com", serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); @@ -288,7 +293,7 @@ public void resolving_ldsResourceNotFound() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsResourceNotFound(); - assertEmptyResolutionResult(); + assertEmptyResolutionResult(expectedLdsResourceName); } @SuppressWarnings("unchecked") @@ -296,12 +301,12 @@ public void resolving_ldsResourceNotFound() { public void resolving_ldsResourceUpdateRdsName() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()); + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()); + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), + ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); @@ -310,7 +315,7 @@ public void resolving_ldsResourceUpdateRdsName() { VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route1), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); verify(mockListener).onResult(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( @@ -326,7 +331,7 @@ public void resolving_ldsResourceUpdateRdsName() { virtualHost = VirtualHost.create("virtualhost-alter", Collections.singletonList(AUTHORITY), Collections.singletonList(route2), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverRdsUpdate(alternativeRdsResource, Collections.singletonList(virtualHost)); // Two new service config updates triggered: // - with load balancing config being able to select cluster1 and cluster2 @@ -346,7 +351,7 @@ public void resolving_rdsResourceNotFound() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); - assertEmptyResolutionResult(); + assertEmptyResolutionResult(RDS_RESOURCE_NAME); } @SuppressWarnings("unchecked") @@ -354,8 +359,8 @@ public void resolving_rdsResourceNotFound() { public void resolving_ldsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()); + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); @@ -364,7 +369,7 @@ public void resolving_ldsResourceRevokedAndAddedBack() { VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); verify(mockListener).onResult(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( @@ -374,7 +379,7 @@ public void resolving_ldsResourceRevokedAndAddedBack() { reset(mockListener); xdsClient.deliverLdsResourceNotFound(); // revoke LDS resource assertThat(xdsClient.rdsResource).isNull(); // stop subscribing to stale RDS resource - assertEmptyResolutionResult(); + assertEmptyResolutionResult(expectedLdsResourceName); reset(mockListener); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); @@ -393,8 +398,8 @@ public void resolving_ldsResourceRevokedAndAddedBack() { public void resolving_rdsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()); + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); @@ -403,7 +408,7 @@ public void resolving_rdsResourceRevokedAndAddedBack() { VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); verify(mockListener).onResult(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( @@ -412,7 +417,7 @@ public void resolving_rdsResourceRevokedAndAddedBack() { reset(mockListener); xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); // revoke RDS resource - assertEmptyResolutionResult(); + assertEmptyResolutionResult(RDS_RESOURCE_NAME); // Simulate management server adds back the previously used RDS resource. reset(mockListener); @@ -431,7 +436,21 @@ public void resolving_encounterErrorLdsWatcherOnly() { verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("server unreachable"); + assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY + + ". xDS server returned: UNAVAILABLE: server unreachable"); + } + + @Test + public void resolving_translateErrorLds() { + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverError(Status.NOT_FOUND.withDescription("server unreachable")); + verify(mockListener).onError(errorCaptor.capture()); + Status error = errorCaptor.getValue(); + assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY + + ". xDS server returned: NOT_FOUND: server unreachable"); + assertThat(error.getCause()).isNull(); } @Test @@ -441,10 +460,69 @@ public void resolving_encounterErrorLdsAndRdsWatchers() { xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); xdsClient.deliverError(Status.UNAVAILABLE.withDescription("server unreachable")); verify(mockListener, times(2)).onError(errorCaptor.capture()); - for (Status error : errorCaptor.getAllValues()) { - assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("server unreachable"); - } + Status error = errorCaptor.getAllValues().get(0); + assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY + + ". xDS server returned: UNAVAILABLE: server unreachable"); + error = errorCaptor.getAllValues().get(1); + assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo("Unable to load RDS " + RDS_RESOURCE_NAME + + ". xDS server returned: UNAVAILABLE: server unreachable"); + } + + @SuppressWarnings("unchecked") + @Test + public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { + Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), + RouteAction.forCluster( + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); + VirtualHost virtualHost = + VirtualHost.create("virtualhost", Collections.singletonList("random"), + Collections.singletonList(route), + ImmutableMap.of()); + + resolver = new XdsNameResolver(null, AUTHORITY, "random", + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); + verify(mockListener).onResult(resolutionResultCaptor.capture()); + assertServiceConfigForLoadBalancingConfig( + Collections.singletonList(cluster1), + (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); + } + + @Test + public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() { + Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), + RouteAction.forCluster( + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); + VirtualHost virtualHost = + VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), + Collections.singletonList(route), + ImmutableMap.of()); + + resolver = new XdsNameResolver(null, AUTHORITY, "random", + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); + assertEmptyResolutionResult("random"); + } + + @Test + public void resolving_matchingVirtualHostNotFoundForOverrideAuthority() { + resolver = new XdsNameResolver(null, AUTHORITY, AUTHORITY, + serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate(0L, buildUnmatchedVirtualHosts()); + assertEmptyResolutionResult(expectedLdsResourceName); } @Test @@ -452,7 +530,7 @@ public void resolving_matchingVirtualHostNotFoundInLdsResource() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(0L, buildUnmatchedVirtualHosts()); - assertEmptyResolutionResult(); + assertEmptyResolutionResult(expectedLdsResourceName); } @Test @@ -461,25 +539,25 @@ public void resolving_matchingVirtualHostNotFoundInRdsResource() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, buildUnmatchedVirtualHosts()); - assertEmptyResolutionResult(); + assertEmptyResolutionResult(expectedLdsResourceName); } private List buildUnmatchedVirtualHosts() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()); + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()); + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + ImmutableMap.of()); return Arrays.asList( VirtualHost.create("virtualhost-foo", Collections.singletonList("hello.googleapis.com"), Collections.singletonList(route1), - ImmutableMap.of()), + ImmutableMap.of()), VirtualHost.create("virtualhost-bar", Collections.singletonList("hi.googleapis.com"), Collections.singletonList(route2), - ImmutableMap.of())); + ImmutableMap.of())); } @Test @@ -488,11 +566,11 @@ public void resolved_noTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset - ImmutableMap.of()); + cluster1, Collections.emptyList(), null, null), // per-route timeout unset + ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverLdsUpdate(0L, Collections.singletonList(virtualHost)); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); @@ -506,11 +584,11 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset - ImmutableMap.of()); + cluster1, Collections.emptyList(), null, null), // per-route timeout unset + ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), - ImmutableMap.of()); + ImmutableMap.of()); xdsClient.deliverLdsUpdate(TimeUnit.SECONDS.toNanos(5L), Collections.singletonList(virtualHost)); verify(mockListener).onResult(resolutionResultCaptor.capture()); @@ -523,7 +601,7 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); - resolver = new XdsNameResolver(null, AUTHORITY, realParser, syncContext, scheduler, + resolver = new XdsNameResolver(null, AUTHORITY, null, realParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); @@ -536,10 +614,10 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( cluster1, - Collections.emptyList(), + Collections.emptyList(), null, retryPolicy), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); @@ -592,12 +670,12 @@ public void resolved_simpleCallFailedToRoute_routeWithNonForwardingAction() { Arrays.asList( Route.forNonForwardingAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), - RouteAction.forCluster(cluster2, Collections.emptyList(), + RouteAction.forCluster(cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertThat(result.getAddresses()).isEmpty(); @@ -633,7 +711,7 @@ public void resolved_rpcHashingByHeader_withoutSubstitution() { HashPolicy.forHeader(false, "custom-key", null, null)), null, null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -667,7 +745,7 @@ public void resolved_rpcHashingByHeader_withSubstitution() { HashPolicy.forHeader(false, "custom-key", Pattern.compile("value"), "val")), null, null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -706,7 +784,7 @@ public void resolved_rpcHashingByChannelId() { Collections.singletonList(HashPolicy.forChannelId(false)), null, null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -719,14 +797,16 @@ public void resolved_rpcHashingByChannelId() { // Second call, with no custom header. startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), + Collections.emptyMap(), CallOptions.DEFAULT); long hash2 = testCall.callOptions.getOption(XdsNameResolver.RPC_HASH_KEY); // A different resolver/Channel. resolver.shutdown(); reset(mockListener); - resolver = new XdsNameResolver(null, AUTHORITY, serviceConfigParser, syncContext, scheduler, + when(mockRandom.nextLong()).thenReturn(123L); + resolver = new XdsNameResolver(null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); xdsClient = (FakeXdsClient) resolver.getXdsClient(); @@ -740,14 +820,14 @@ public void resolved_rpcHashingByChannelId() { Collections.singletonList(HashPolicy.forChannelId(false)), null, null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); configSelector = resolutionResultCaptor.getValue().getAttributes().get( InternalConfigSelector.KEY); // Third call, with no custom header. startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), + Collections.emptyMap(), CallOptions.DEFAULT); long hash3 = testCall.callOptions.getOption(XdsNameResolver.RPC_HASH_KEY); @@ -769,15 +849,15 @@ public void resolved_resourceUpdateAfterCallStarted() { Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - "another-cluster", Collections.emptyList(), + "another-cluster", Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); // Updated service config still contains cluster1 while it is removed resource. New calls no @@ -809,15 +889,15 @@ public void resolved_resourceUpdatedBeforeCallStarted() { Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - "another-cluster", Collections.emptyList(), + "another-cluster", Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); // Two consecutive service config updates: one for removing clcuster1, // one for adding "another=cluster". verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); @@ -845,15 +925,15 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - "another-cluster", Collections.emptyList(), + "another-cluster", Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); @@ -866,15 +946,15 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - "another-cluster", Collections.emptyList(), + "another-cluster", Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verifyNoMoreInteractions(mockListener); // no cluster added/deleted assertCallSelectClusterResult(call1, configSelector, "another-cluster", 15.0); } @@ -889,23 +969,23 @@ public void resolved_raceBetweenClusterReleasedAndResourceUpdateAddBackAgain() { Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); xdsClient.deliverLdsUpdate( Arrays.asList( Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); testCall.deliverErrorStatus(); verifyNoMoreInteractions(mockListener); } @@ -922,13 +1002,13 @@ public void resolved_simpleCallSucceeds_routeToWeightedCluster() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forWeightedClusters( Arrays.asList( - ClusterWeight.create(cluster1, 20, ImmutableMap.of()), + ClusterWeight.create(cluster1, 20, ImmutableMap.of()), ClusterWeight.create( - cluster2, 80, ImmutableMap.of())), - Collections.emptyList(), + cluster2, 80, ImmutableMap.of())), + Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertThat(result.getAddresses()).isEmpty(); @@ -954,10 +1034,10 @@ public void resolved_simpleCallSucceeds_routeToRls() { "rls-plugin-foo", RlsPluginConfig.create( ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), - Collections.emptyList(), + Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertThat(result.getAddresses()).isEmpty(); @@ -1001,11 +1081,11 @@ public void resolved_simpleCallSucceeds_routeToRls() { RlsPluginConfig.create( // changed ImmutableMap.of("lookupService", "rls-cbt-2.googleapis.com"))), - Collections.emptyList(), + Collections.emptyList(), // changed TimeUnit.SECONDS.toNanos(30L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); ResolutionResult result2 = resolutionResultCaptor.getValue(); @SuppressWarnings("unchecked") @@ -1038,11 +1118,16 @@ public void resolved_simpleCallSucceeds_routeToRls() { } @SuppressWarnings("unchecked") - private void assertEmptyResolutionResult() { + private void assertEmptyResolutionResult(String resource) { verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertThat(result.getAddresses()).isEmpty(); assertThat((Map) result.getServiceConfig().getConfig()).isEmpty(); + InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); + Result configResult = configSelector.selectConfig( + new PickSubchannelArgsImpl(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + assertThat(configResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(configResult.getStatus().getDescription()).contains(resource); } private void assertCallSelectClusterResult( @@ -1054,7 +1139,7 @@ private void assertCallSelectClusterResult( ClientInterceptor interceptor = result.getInterceptor(); ClientCall clientCall = interceptor.interceptCall( call.methodDescriptor, CallOptions.DEFAULT, channel); - clientCall.start(new NoopClientCallListener(), new Metadata()); + clientCall.start(new NoopClientCallListener<>(), new Metadata()); assertThat(testCall.callOptions.getOption(XdsNameResolver.CLUSTER_SELECTION_KEY)) .isEqualTo("cluster:" + expectedCluster); @SuppressWarnings("unchecked") @@ -1082,7 +1167,7 @@ private void assertCallSelectRlsPluginResult( ClientInterceptor interceptor = result.getInterceptor(); ClientCall clientCall = interceptor.interceptCall( call.methodDescriptor, CallOptions.DEFAULT, channel); - clientCall.start(new NoopClientCallListener(), new Metadata()); + clientCall.start(new NoopClientCallListener<>(), new Metadata()); assertThat(testCall.callOptions.getOption(XdsNameResolver.CLUSTER_SELECTION_KEY)) .isEqualTo("cluster_specifier_plugin:" + expectedPluginName); @SuppressWarnings("unchecked") @@ -1104,15 +1189,15 @@ private InternalConfigSelector resolveToClusters() { Route.forAction( RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()), + ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), - ImmutableMap.of()))); + ImmutableMap.of()))); verify(mockListener).onResult(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertThat(result.getAddresses()).isEmpty(); @@ -1247,7 +1332,7 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { 4, ImmutableList.of(Code.UNAVAILABLE, Code.CANCELLED), Durations.fromMillis(100), Durations.fromMillis(200), null); RetryPolicy retryPolicyWithEmptyStatusCodes = RetryPolicy.create( - 4, ImmutableList.of(), Durations.fromMillis(100), Durations.fromMillis(200), null); + 4, ImmutableList.of(), Durations.fromMillis(100), Durations.fromMillis(200), null); // timeout only String expectedServiceConfigJson = "{\n" @@ -1369,13 +1454,13 @@ public void findVirtualHostForHostName_exactMatchFirst() { List routes = Collections.emptyList(); VirtualHost vHost1 = VirtualHost.create("virtualhost01.googleapis.com", Arrays.asList("a.googleapis.com", "b.googleapis.com"), routes, - ImmutableMap.of()); + ImmutableMap.of()); VirtualHost vHost2 = VirtualHost.create("virtualhost02.googleapis.com", Collections.singletonList("*.googleapis.com"), routes, - ImmutableMap.of()); + ImmutableMap.of()); VirtualHost vHost3 = VirtualHost.create("virtualhost03.googleapis.com", Collections.singletonList("*"), routes, - ImmutableMap.of()); + ImmutableMap.of()); List virtualHosts = Arrays.asList(vHost1, vHost2, vHost3); assertThat(XdsNameResolver.findVirtualHostForHostName(virtualHosts, hostname)) .isEqualTo(vHost1); @@ -1387,13 +1472,13 @@ public void findVirtualHostForHostName_preferSuffixDomainOverPrefixDomain() { List routes = Collections.emptyList(); VirtualHost vHost1 = VirtualHost.create("virtualhost01.googleapis.com", Arrays.asList("*.googleapis.com", "b.googleapis.com"), routes, - ImmutableMap.of()); + ImmutableMap.of()); VirtualHost vHost2 = VirtualHost.create("virtualhost02.googleapis.com", Collections.singletonList("a.googleapis.*"), routes, - ImmutableMap.of()); + ImmutableMap.of()); VirtualHost vHost3 = VirtualHost.create("virtualhost03.googleapis.com", Collections.singletonList("*"), routes, - ImmutableMap.of()); + ImmutableMap.of()); List virtualHosts = Arrays.asList(vHost1, vHost2, vHost3); assertThat(XdsNameResolver.findVirtualHostForHostName(virtualHosts, hostname)) .isEqualTo(vHost1); @@ -1405,13 +1490,13 @@ public void findVirtualHostForHostName_asteriskMatchAnyDomain() { List routes = Collections.emptyList(); VirtualHost vHost1 = VirtualHost.create("virtualhost01.googleapis.com", Collections.singletonList("*"), routes, - ImmutableMap.of()); + ImmutableMap.of()); VirtualHost vHost2 = VirtualHost.create("virtualhost02.googleapis.com", Collections.singletonList("b.googleapis.com"), routes, - ImmutableMap.of()); + ImmutableMap.of()); List virtualHosts = Arrays.asList(vHost1, vHost2); assertThat(XdsNameResolver.findVirtualHostForHostName(virtualHosts, hostname)) - .isEqualTo(vHost1);; + .isEqualTo(vHost1); } @Test @@ -1431,7 +1516,7 @@ public void resolved_faultAbortInLdsUpdate() { InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header abort key provided in metadata, rpc should succeed ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcSucceeded(observer); // header abort http status key provided, rpc should fail observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1500,7 +1585,7 @@ public void resolved_faultAbortInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcFailed( observer, Status.UNAUTHENTICATED.withDescription( @@ -1518,7 +1603,7 @@ public void resolved_faultAbortInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcSucceeded(observer); } @@ -1537,7 +1622,7 @@ public void resolved_faultDelayInLdsUpdate() { InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header delay key provided in metadata, rpc should succeed immediately ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcSucceeded(observer); // header delay key provided, rpc should be delayed observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1577,7 +1662,7 @@ public void resolved_faultDelayInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcDelayed(observer, 5000L); // fixed delay, fix rate = 40% @@ -1590,7 +1675,7 @@ public void resolved_faultDelayInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcSucceeded(observer); } @@ -1612,15 +1697,15 @@ public void resolved_faultDelayWithMaxActiveStreamsInLdsUpdate() { // Send two calls, then the first call should delayed and the second call should not be delayed // because maxActiveFaults is exceeded. ClientCall.Listener observer1 = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); assertThat(testCall).isNull(); ClientCall.Listener observer2 = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcSucceeded(observer2); verifyRpcDelayed(observer1, 5000L); // Once all calls are finished, new call should be delayed. ClientCall.Listener observer3 = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcDelayed(observer3, 5000L); } @@ -1646,12 +1731,12 @@ public long nanoTime() { } }; ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT.withDeadline( + configSelector, Collections.emptyMap(), CallOptions.DEFAULT.withDeadline( Deadline.after(4000, TimeUnit.NANOSECONDS, fakeTicker))); assertThat(testCall).isNull(); verifyRpcDelayedThenAborted(observer, 4000L, Status.DEADLINE_EXCEEDED.withDescription( "Deadline exceeded after up to 5000 ns of fault-injected delay:" - + " Deadline exceeded after 0.000004000s. ")); + + " Deadline CallOptions will be exceeded in 0.000004000s. ")); } @Test @@ -1671,7 +1756,7 @@ public void resolved_faultAbortAndDelayInLdsUpdateInLdsUpdate() { ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcDelayedThenAborted( observer, 5000L, Status.UNAUTHENTICATED.withDescription( @@ -1700,7 +1785,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT); + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcFailed( observer, Status.INTERNAL.withDescription("RPC terminated due to fault injection")); @@ -1715,7 +1800,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcFailed( observer, Status.UNKNOWN.withDescription("RPC terminated due to fault injection")); @@ -1732,7 +1817,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, - Collections.emptyMap(), CallOptions.DEFAULT); + Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcFailed( observer, Status.UNAVAILABLE.withDescription("RPC terminated due to fault injection")); } @@ -1761,7 +1846,7 @@ public void resolved_faultConfigOverrideInLdsAndInRdsUpdate() { ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), - configSelector, Collections.emptyMap(), CallOptions.DEFAULT);; + configSelector, Collections.emptyMap(), CallOptions.DEFAULT); verifyRpcFailed( observer, Status.UNKNOWN.withDescription("RPC terminated due to fault injection")); } @@ -1822,7 +1907,7 @@ public void routeMatching_pathOnly() { RouteMatch routeMatch1 = RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", true), - Collections.emptyList(), null); + Collections.emptyList(), null); assertThat(XdsNameResolver.matchRoute(routeMatch1, "/FooService/barMethod", headers, random)) .isTrue(); assertThat(XdsNameResolver.matchRoute(routeMatch1, "/FooService/bazMethod", headers, random)) @@ -1831,7 +1916,7 @@ public void routeMatching_pathOnly() { RouteMatch routeMatch2 = RouteMatch.create( PathMatcher.fromPrefix("/FooService/", true), - Collections.emptyList(), null); + Collections.emptyList(), null); assertThat(XdsNameResolver.matchRoute(routeMatch2, "/FooService/barMethod", headers, random)) .isTrue(); assertThat(XdsNameResolver.matchRoute(routeMatch2, "/FooService/bazMethod", headers, random)) @@ -1842,7 +1927,7 @@ public void routeMatching_pathOnly() { RouteMatch routeMatch3 = RouteMatch.create( PathMatcher.fromRegEx(Pattern.compile(".*Foo.*")), - Collections.emptyList(), null); + Collections.emptyList(), null); assertThat(XdsNameResolver.matchRoute(routeMatch3, "/FooService/barMethod", headers, random)) .isTrue(); } @@ -1855,14 +1940,14 @@ public void routeMatching_pathOnly_caseInsensitive() { RouteMatch routeMatch1 = RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", false), - Collections.emptyList(), null); + Collections.emptyList(), null); assertThat(XdsNameResolver.matchRoute(routeMatch1, "/fooservice/barmethod", headers, random)) .isTrue(); RouteMatch routeMatch2 = RouteMatch.create( PathMatcher.fromPrefix("/FooService", false), - Collections.emptyList(), null); + Collections.emptyList(), null); assertThat(XdsNameResolver.matchRoute(routeMatch2, "/fooservice/barmethod", headers, random)) .isTrue(); } @@ -1989,8 +2074,8 @@ private class FakeXdsClient extends XdsClient { // Should never be subscribing to more than one LDS and RDS resource at any point of time. private String ldsResource; // should always be AUTHORITY private String rdsResource; - private LdsResourceWatcher ldsWatcher; - private RdsResourceWatcher rdsWatcher; + private ResourceWatcher ldsWatcher; + private ResourceWatcher rdsWatcher; @Override BootstrapInfo getBootstrapInfo() { @@ -1998,37 +2083,49 @@ BootstrapInfo getBootstrapInfo() { } @Override - void watchLdsResource(String resourceName, LdsResourceWatcher watcher) { - assertThat(ldsResource).isNull(); - assertThat(ldsWatcher).isNull(); - assertThat(resourceName).isEqualTo(expectedLdsResourceName); - ldsResource = resourceName; - ldsWatcher = watcher; - } - - @Override - void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { - assertThat(ldsResource).isNotNull(); - assertThat(ldsWatcher).isNotNull(); - assertThat(resourceName).isEqualTo(expectedLdsResourceName); - ldsResource = null; - ldsWatcher = null; - } - - @Override - void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { - assertThat(rdsResource).isNull(); - assertThat(rdsWatcher).isNull(); - rdsResource = resourceName; - rdsWatcher = watcher; + @SuppressWarnings("unchecked") + void watchXdsResource(XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher) { + + switch (resourceType.typeName()) { + case "LDS": + assertThat(ldsResource).isNull(); + assertThat(ldsWatcher).isNull(); + assertThat(resourceName).isEqualTo(expectedLdsResourceName); + ldsResource = resourceName; + ldsWatcher = (ResourceWatcher) watcher; + break; + case "RDS": + assertThat(rdsResource).isNull(); + assertThat(rdsWatcher).isNull(); + rdsResource = resourceName; + rdsWatcher = (ResourceWatcher) watcher; + break; + default: + } } @Override - void cancelRdsResourceWatch(String resourceName, RdsResourceWatcher watcher) { - assertThat(rdsResource).isNotNull(); - assertThat(rdsWatcher).isNotNull(); - rdsResource = null; - rdsWatcher = null; + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { + switch (type.typeName()) { + case "LDS": + assertThat(ldsResource).isNotNull(); + assertThat(ldsWatcher).isNotNull(); + assertThat(resourceName).isEqualTo(expectedLdsResourceName); + ldsResource = null; + ldsWatcher = null; + break; + case "RDS": + assertThat(rdsResource).isNotNull(); + assertThat(rdsWatcher).isNotNull(); + rdsResource = null; + rdsWatcher = null; + break; + default: + } } void deliverLdsUpdate(long httpMaxStreamDurationNano, List virtualHosts) { @@ -2040,7 +2137,7 @@ void deliverLdsUpdate(final List routes) { VirtualHost virtualHost = VirtualHost.create( "virtual-host", Collections.singletonList(expectedLdsResourceName), routes, - ImmutableMap.of()); + ImmutableMap.of()); ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( 0L, Collections.singletonList(virtualHost), null))); } @@ -2058,28 +2155,28 @@ void deliverLdsUpdateWithFaultInjection( new NamedFilterConfig(FAULT_FILTER_INSTANCE_NAME, httpFilterFaultConfig), new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); ImmutableMap overrideConfig = weightedClusterFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of( + ? ImmutableMap.of() + : ImmutableMap.of( FAULT_FILTER_INSTANCE_NAME, weightedClusterFaultConfig); ClusterWeight clusterWeight = ClusterWeight.create( cluster, 100, overrideConfig); overrideConfig = routeFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of(FAULT_FILTER_INSTANCE_NAME, routeFaultConfig); + ? ImmutableMap.of() + : ImmutableMap.of(FAULT_FILTER_INSTANCE_NAME, routeFaultConfig); Route route = Route.forAction( RouteMatch.create( - PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), + PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), RouteAction.forWeightedClusters( Collections.singletonList(clusterWeight), - Collections.emptyList(), + Collections.emptyList(), null, null), overrideConfig); overrideConfig = virtualHostFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of( + ? ImmutableMap.of() + : ImmutableMap.of( FAULT_FILTER_INSTANCE_NAME, virtualHostFaultConfig); VirtualHost virtualHost = VirtualHost.create( "virtual-host", @@ -2119,26 +2216,26 @@ void deliverRdsUpdateWithFaultInjection( return; } ImmutableMap overrideConfig = weightedClusterFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of( + ? ImmutableMap.of() + : ImmutableMap.of( FAULT_FILTER_INSTANCE_NAME, weightedClusterFaultConfig); ClusterWeight clusterWeight = ClusterWeight.create(cluster1, 100, overrideConfig); overrideConfig = routFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of(FAULT_FILTER_INSTANCE_NAME, routFaultConfig); + ? ImmutableMap.of() + : ImmutableMap.of(FAULT_FILTER_INSTANCE_NAME, routFaultConfig); Route route = Route.forAction( RouteMatch.create( - PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), + PathMatcher.fromPrefix("/", false), Collections.emptyList(), null), RouteAction.forWeightedClusters( Collections.singletonList(clusterWeight), - Collections.emptyList(), + Collections.emptyList(), null, null), overrideConfig); overrideConfig = virtualHostFaultConfig == null - ? ImmutableMap.of() - : ImmutableMap.of( + ? ImmutableMap.of() + : ImmutableMap.of( FAULT_FILTER_INSTANCE_NAME, virtualHostFaultConfig); VirtualHost virtualHost = VirtualHost.create( "virtual-host", diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index 36669537255..3ef23c11375 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -17,15 +17,15 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -57,13 +57,13 @@ import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; -import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.Matchers.HeaderMatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.grpc.xds.internal.sds.TlsContextManagerImpl; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SslContextProviderSupplier; +import io.grpc.xds.internal.security.TlsContextManagerImpl; import io.netty.handler.ssl.NotSslRecordException; import java.net.Inet4Address; import java.net.InetSocketAddress; diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index d67ed9d09fc..32cb3eb418f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -35,7 +35,7 @@ import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -103,7 +103,7 @@ private void verifyServer( assertThat(socketAddress.getPort()).isGreaterThan(-1); if (mockXdsServingStatusListener != null) { if (notServingStatus != null) { - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Throwable.class); verify(mockXdsServingStatusListener, times(1)).onNotServing(argCaptor.capture()); Throwable throwable = argCaptor.getValue(); assertThat(throwable).isInstanceOf(StatusException.class); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 15868ba414e..256e3f61fec 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -30,7 +30,8 @@ import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.VirtualHost.Route; -import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -54,7 +55,7 @@ public class XdsServerTestHelper { Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( Bootstrapper.ServerInfo.create( - SERVER_URI, InsecureChannelCredentials.create(), true))) + SERVER_URI, InsecureChannelCredentials.create()))) .node(BOOTSTRAP_NODE) .serverListenerResourceNameTemplate("grpc/server?udpa.resource.listening_address=%s") .build(); @@ -163,9 +164,9 @@ public XdsClient returnObject(Object object) { static final class FakeXdsClient extends XdsClient { boolean shutdown; SettableFuture ldsResource = SettableFuture.create(); - LdsResourceWatcher ldsWatcher; + ResourceWatcher ldsWatcher; CountDownLatch rdsCount = new CountDownLatch(1); - final Map rdsWatchers = new HashMap<>(); + final Map> rdsWatchers = new HashMap<>(); @Override public TlsContextManager getTlsContextManager() { @@ -178,28 +179,40 @@ public BootstrapInfo getBootstrapInfo() { } @Override - void watchLdsResource(String resourceName, LdsResourceWatcher watcher) { - assertThat(ldsWatcher).isNull(); - ldsWatcher = watcher; - ldsResource.set(resourceName); + @SuppressWarnings("unchecked") + void watchXdsResource(XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher) { + switch (resourceType.typeName()) { + case "LDS": + assertThat(ldsWatcher).isNull(); + ldsWatcher = (ResourceWatcher) watcher; + ldsResource.set(resourceName); + break; + case "RDS": + //re-register is not allowed. + assertThat(rdsWatchers.put(resourceName, (ResourceWatcher)watcher)).isNull(); + rdsCount.countDown(); + break; + default: + } } @Override - void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { - assertThat(ldsWatcher).isNotNull(); - ldsResource = null; - ldsWatcher = null; - } - - @Override - void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { - assertThat(rdsWatchers.put(resourceName, watcher)).isNull(); //re-register is not allowed. - rdsCount.countDown(); - } - - @Override - void cancelRdsResourceWatch(String resourceName, RdsResourceWatcher watcher) { - rdsWatchers.remove(resourceName); + void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { + switch (type.typeName()) { + case "LDS": + assertThat(ldsWatcher).isNotNull(); + ldsResource = null; + ldsWatcher = null; + break; + case "RDS": + rdsWatchers.remove(resourceName); + break; + default: + } } @Override @@ -213,7 +226,7 @@ boolean isShutDown() { } void deliverLdsUpdate(List filterChains, - FilterChain defaultFilterChain) { + FilterChain defaultFilterChain) { ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index f8b8ca2e105..6271ca791c6 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -26,6 +26,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -56,17 +57,16 @@ import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; -import io.grpc.xds.XdsClient.LdsResourceWatcher; -import io.grpc.xds.XdsClient.RdsResourceWatcher; -import io.grpc.xds.XdsClient.RdsUpdate; +import io.grpc.xds.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.XdsServerWrapper.ConfigApplyingInterceptor; import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; import io.grpc.xds.internal.Matchers.HeaderMatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -91,6 +91,8 @@ @RunWith(JUnit4.class) public class XdsServerWrapperTest { + private static final int START_WAIT_AFTER_LISTENER_MILLIS = 100; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -123,15 +125,34 @@ public void tearDown() { } @Test - public void testBootstrap_notV3() throws Exception { + @SuppressWarnings("unchecked") + public void testBootstrap() throws Exception { Bootstrapper.BootstrapInfo b = Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( - Bootstrapper.ServerInfo.create("uri", InsecureChannelCredentials.create(), false))) + Bootstrapper.ServerInfo.create("uri", InsecureChannelCredentials.create()))) .node(EnvoyProtoData.Node.newBuilder().setId("id").build()) .serverListenerResourceNameTemplate("grpc/server?udpa.resource.listening_address=%s") .build(); - verifyBootstrapFail(b); + XdsClient xdsClient = mock(XdsClient.class); + XdsListenerResource listenerResource = XdsListenerResource.getInstance(); + when(xdsClient.getBootstrapInfo()).thenReturn(b); + xdsServerWrapper = new XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + xdsServerWrapper.start(); + } catch (IOException ex) { + // ignore + } + } + }); + verify(xdsClient, timeout(5000)).watchXdsResource( + eq(listenerResource), + eq("grpc/server?udpa.resource.listening_address=[::FFFF:129.144.52.38]:80"), + any(ResourceWatcher.class)); } @Test @@ -139,7 +160,7 @@ public void testBootstrap_noTemplate() throws Exception { Bootstrapper.BootstrapInfo b = Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( - Bootstrapper.ServerInfo.create("uri", InsecureChannelCredentials.create(), true))) + Bootstrapper.ServerInfo.create("uri", InsecureChannelCredentials.create()))) .node(EnvoyProtoData.Node.newBuilder().setId("id").build()) .build(); verifyBootstrapFail(b); @@ -174,16 +195,18 @@ public void run() { } @Test + @SuppressWarnings("unchecked") public void testBootstrap_templateWithXdstp() throws Exception { Bootstrapper.BootstrapInfo b = Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( Bootstrapper.ServerInfo.create( - "uri", InsecureChannelCredentials.create(), true))) + "uri", InsecureChannelCredentials.create()))) .node(EnvoyProtoData.Node.newBuilder().setId("id").build()) .serverListenerResourceNameTemplate( "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/grpc/server/%s") .build(); XdsClient xdsClient = mock(XdsClient.class); + XdsListenerResource listenerResource = XdsListenerResource.getInstance(); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); @@ -197,10 +220,11 @@ public void run() { } } }); - verify(xdsClient, timeout(5000)).watchLdsResource( + verify(xdsClient, timeout(5000)).watchXdsResource( + eq(listenerResource), eq("xdstp://xds.authority.com/envoy.config.listener.v3.Listener/grpc/server/" + "%5B::FFFF:129.144.52.38%5D:80"), - any(LdsResourceWatcher.class)); + any(ResourceWatcher.class)); } @Test @@ -227,7 +251,8 @@ public void run() { xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); - start.get(5000, TimeUnit.MILLISECONDS); + verify(listener, timeout(5000)).onServing(); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); verify(mockServer).start(); xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); @@ -296,8 +321,9 @@ public void run() { }); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + verify(listener, timeout(5000)).onNotServing(any()); try { - start.get(5000, TimeUnit.MILLISECONDS); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); fail("server should not start() successfully."); } catch (TimeoutException ex) { // expect to block here. @@ -453,6 +479,47 @@ public void run() { verify(mockServer).start(); } + @Test + public void discoverState_restart_afterResourceNotExist() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + start.get(5000, TimeUnit.MILLISECONDS); + verify(listener).onServing(); + verify(mockServer).start(); + + // server shutdown after resourceDoesNotExist + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + verify(mockServer).shutdown(); + + // re-deliver lds resource + reset(mockServer); + reset(listener); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + verify(listener).onServing(); + verify(mockServer).start(); + } + @Test public void discoverState_rds() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -665,8 +732,9 @@ public void run() { }); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + verify(listener, timeout(5000)).onNotServing(any()); try { - start.get(5000, TimeUnit.MILLISECONDS); + start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); fail("server should not start()"); } catch (TimeoutException ex) { // expect to block here. @@ -680,7 +748,7 @@ public void run() { xdsClient.ldsWatcher.onError(Status.INTERNAL); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); - assertThat(xdsClient.rdsWatchers).isEmpty(); + ResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); verify(mockBuilder, times(1)).build(); verify(listener, times(2)).onNotServing(any(StatusException.class)); assertThat(sslSupplier0.isShutdown()).isFalse(); @@ -700,7 +768,6 @@ public void run() { assertThat(ex.getCause()).isInstanceOf(IOException.class); assertThat(ex.getCause().getMessage()).isEqualTo("error!"); } - RdsResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); @@ -734,7 +801,7 @@ public void run() { // not serving after serving xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); assertThat(xdsClient.rdsWatchers).isEmpty(); - verify(mockServer, times(3)).shutdown(); + verify(mockServer, times(2)).shutdown(); when(mockServer.isShutdown()).thenReturn(true); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); @@ -772,8 +839,9 @@ public void run() { assertThat(executor.numPendingTasks()).isEqualTo(1); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); - verify(mockServer, times(4)).shutdown(); + verify(mockServer, times(3)).shutdown(); verify(listener, times(4)).onNotServing(any(StatusException.class)); + verify(listener, times(1)).onNotServing(any(IOException.class)); when(mockServer.isShutdown()).thenReturn(true); assertThat(executor.numPendingTasks()).isEqualTo(0); assertThat(sslSupplier2.isShutdown()).isTrue(); @@ -799,7 +867,7 @@ public void run() { assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsServerWrapper.shutdown(); - verify(mockServer, times(5)).shutdown(); + verify(mockServer, times(4)).shutdown(); assertThat(sslSupplier3.isShutdown()).isTrue(); when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); diff --git a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java new file mode 100644 index 00000000000..9a1ce3350de --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java @@ -0,0 +1,199 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package io.grpc.xds; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.grpc.SynchronizationContext; +import io.grpc.stub.StreamObserver; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** +* A bidi-stream service that acts as a local xDS Control Plane. + * It accepts xDS config injection through a method call {@link #setXdsConfig}. Handling AdsStream + * response or updating xds config are run in syncContext. + * + *

    The service maintains lookup tables: + * Subscriber table: map from each resource type, to a map from each client to subscribed resource + * names set. + * Resources table: store the resources in raw proto message. + * + *

    xDS protocol requires version/nonce to avoid various race conditions. In this impl: + * Version stores the latest version number per each resource type. It is simply bumped up on each + * xds config set. + * Nonce stores the nonce number for each resource type and for each client. Incoming xDS requests + * share the same proto message type but may at different resources update phases: + * 1) Original: an initial xDS request. + * 2) NACK an xDS response. + * 3) ACK an xDS response. + * The service is capable of distinguish these cases when handling the request. + */ +final class XdsTestControlPlaneService extends + AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase { + private static final Logger logger = Logger.getLogger(XdsTestControlPlaneService.class.getName()); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + static final String ADS_TYPE_URL_LDS = + "type.googleapis.com/envoy.config.listener.v3.Listener"; + static final String ADS_TYPE_URL_RDS = + "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; + static final String ADS_TYPE_URL_CDS = + "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + static final String ADS_TYPE_URL_EDS = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + + private final Map> xdsResources = new HashMap<>(); + private ImmutableMap, Set>> subscribers + = ImmutableMap.of( + ADS_TYPE_URL_LDS, new HashMap, Set>(), + ADS_TYPE_URL_RDS, new HashMap, Set>(), + ADS_TYPE_URL_CDS, new HashMap, Set>(), + ADS_TYPE_URL_EDS, new HashMap, Set>() + ); + private final ImmutableMap xdsVersions = ImmutableMap.of( + ADS_TYPE_URL_LDS, new AtomicInteger(1), + ADS_TYPE_URL_RDS, new AtomicInteger(1), + ADS_TYPE_URL_CDS, new AtomicInteger(1), + ADS_TYPE_URL_EDS, new AtomicInteger(1) + ); + private final ImmutableMap, AtomicInteger>> + xdsNonces = ImmutableMap.of( + ADS_TYPE_URL_LDS, new HashMap, AtomicInteger>(), + ADS_TYPE_URL_RDS, new HashMap, AtomicInteger>(), + ADS_TYPE_URL_CDS, new HashMap, AtomicInteger>(), + ADS_TYPE_URL_EDS, new HashMap, AtomicInteger>() + ); + + + // treat all the resource types as state-of-the-world, send back all resources of a particular + // type when any of them change. + public void setXdsConfig(final String type, final Map resources) { + logger.log(Level.FINE, "setting config {0} {1}", new Object[]{type, resources}); + syncContext.execute(new Runnable() { + @Override + public void run() { + HashMap copyResources = new HashMap<>(resources); + xdsResources.put(type, copyResources); + String newVersionInfo = String.valueOf(xdsVersions.get(type).getAndDecrement()); + + for (Map.Entry, Set> entry : + subscribers.get(type).entrySet()) { + DiscoveryResponse response = generateResponse(type, newVersionInfo, + String.valueOf(xdsNonces.get(type).get(entry.getKey()).incrementAndGet()), + entry.getValue()); + entry.getKey().onNext(response); + } + } + }); + } + + @Override + public StreamObserver streamAggregatedResources( + final StreamObserver responseObserver) { + final StreamObserver requestObserver = + new StreamObserver() { + @Override + public void onNext(final DiscoveryRequest value) { + syncContext.execute(new Runnable() { + @Override + public void run() { + logger.log(Level.FINEST, "control plane received request {0}", value); + if (value.hasErrorDetail()) { + logger.log(Level.FINE, "control plane received nack resource {0}, error {1}", + new Object[]{value.getResourceNamesList(), value.getErrorDetail()}); + return; + } + String resourceType = value.getTypeUrl(); + if (!value.getResponseNonce().isEmpty() + && !String.valueOf(xdsNonces.get(resourceType)).equals(value.getResponseNonce())) { + logger.log(Level.FINE, "Resource nonce does not match, ignore."); + return; + } + Set requestedResourceNames = new HashSet<>(value.getResourceNamesList()); + if (subscribers.get(resourceType).containsKey(responseObserver) + && subscribers.get(resourceType).get(responseObserver) + .equals(requestedResourceNames)) { + logger.log(Level.FINEST, "control plane received ack for resource: {0}", + value.getResourceNamesList()); + return; + } + if (!xdsNonces.get(resourceType).containsKey(responseObserver)) { + xdsNonces.get(resourceType).put(responseObserver, new AtomicInteger(0)); + } + DiscoveryResponse response = generateResponse(resourceType, + String.valueOf(xdsVersions.get(resourceType)), + String.valueOf(xdsNonces.get(resourceType).get(responseObserver)), + requestedResourceNames); + responseObserver.onNext(response); + subscribers.get(resourceType).put(responseObserver, requestedResourceNames); + } + }); + } + + @Override + public void onError(Throwable t) { + logger.log(Level.FINE, "Control plane error: {0} ", t); + onCompleted(); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + for (String type : subscribers.keySet()) { + subscribers.get(type).remove(responseObserver); + xdsNonces.get(type).remove(responseObserver); + } + } + }; + return requestObserver; + } + + //must run in syncContext + private DiscoveryResponse generateResponse(String resourceType, String version, String nonce, + Set resourceNames) { + DiscoveryResponse.Builder responseBuilder = DiscoveryResponse.newBuilder() + .setTypeUrl(resourceType) + .setVersionInfo(version) + .setNonce(nonce); + for (String resourceName: resourceNames) { + if (xdsResources.containsKey(resourceType) + && xdsResources.get(resourceType).containsKey(resourceName)) { + responseBuilder.addResources(Any.pack(xdsResources.get(resourceType).get(resourceName), + resourceType)); + } + } + return responseBuilder.build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProviderTest.java new file mode 100644 index 00000000000..dd615809bc2 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/GoogleDefaultXdsCredentialsProviderTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.CompositeChannelCredentials; +import io.grpc.InternalServiceProviders; +import io.grpc.xds.XdsCredentialsProvider; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link GoogleDefaultXdsCredentialsProvider}. */ +@RunWith(JUnit4.class) +public class GoogleDefaultXdsCredentialsProviderTest { + private GoogleDefaultXdsCredentialsProvider provider = new GoogleDefaultXdsCredentialsProvider(); + + @Test + public void provided() { + for (XdsCredentialsProvider current + : InternalServiceProviders.getCandidatesViaServiceLoader( + XdsCredentialsProvider.class, getClass().getClassLoader())) { + if (current instanceof GoogleDefaultXdsCredentialsProvider) { + return; + } + } + fail("ServiceLoader unable to load GoogleDefaultXdsCredentialsProvider"); + } + + @Test + public void isAvailable() { + assertTrue(provider.isAvailable()); + } + + @Test + public void channelCredentials() { + assertSame(CompositeChannelCredentials.class, + provider.newChannelCredentials(null).getClass()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/InsecureXdsCredentialsProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/InsecureXdsCredentialsProviderTest.java new file mode 100644 index 00000000000..583255473eb --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/InsecureXdsCredentialsProviderTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.InsecureChannelCredentials; +import io.grpc.InternalServiceProviders; +import io.grpc.xds.XdsCredentialsProvider; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InsecureXdsCredentialsProvider}. */ +@RunWith(JUnit4.class) +public class InsecureXdsCredentialsProviderTest { + private InsecureXdsCredentialsProvider provider = new InsecureXdsCredentialsProvider(); + + @Test + public void provided() { + for (XdsCredentialsProvider current + : InternalServiceProviders.getCandidatesViaServiceLoader( + XdsCredentialsProvider.class, getClass().getClassLoader())) { + if (current instanceof InsecureXdsCredentialsProvider) { + return; + } + } + fail("ServiceLoader unable to load InsecureXdsCredentialsProvider"); + } + + @Test + public void isAvailable() { + assertTrue(provider.isAvailable()); + } + + @Test + public void channelCredentials() { + assertSame(InsecureChannelCredentials.class, + provider.newChannelCredentials(null).getClass()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java index 93a9b7087d6..4e5d278d50f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java @@ -163,6 +163,15 @@ public void headerMatcher() { assertThat(matcher.matches("1v2")).isFalse(); assertThat(matcher.matches(null)).isFalse(); + matcher = HeaderMatcher.forContains("version", "v1", false); + assertThat(matcher.matches("xv1")).isTrue(); + assertThat(matcher.matches("1vx")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); + matcher = HeaderMatcher.forContains("version", "v1", true); + assertThat(matcher.matches("xv1")).isFalse(); + assertThat(matcher.matches("1vx")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); + matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v2.*"), false); assertThat(matcher.matches("v2..")).isTrue(); assertThat(matcher.matches("v1")).isFalse(); @@ -180,5 +189,14 @@ public void headerMatcher() { assertThat(matcher.matches("1")).isTrue(); assertThat(matcher.matches("8080")).isFalse(); assertThat(matcher.matches(null)).isFalse(); + + matcher = HeaderMatcher.forString("version", StringMatcher.forExact("v1", true), false); + assertThat(matcher.matches("v1")).isTrue(); + assertThat(matcher.matches("v1x")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); + matcher = HeaderMatcher.forString("version", StringMatcher.forExact("v1", true), true); + assertThat(matcher.matches("v1x")).isTrue(); + assertThat(matcher.matches("v1")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java new file mode 100644 index 00000000000..3ba26bdb281 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/TlsXdsCredentialsProviderTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.InternalServiceProviders; +import io.grpc.TlsChannelCredentials; +import io.grpc.xds.XdsCredentialsProvider; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TlsXdsCredentialsProvider}. */ +@RunWith(JUnit4.class) +public class TlsXdsCredentialsProviderTest { + private TlsXdsCredentialsProvider provider = new TlsXdsCredentialsProvider(); + + @Test + public void provided() { + for (XdsCredentialsProvider current + : InternalServiceProviders.getCandidatesViaServiceLoader( + XdsCredentialsProvider.class, getClass().getClassLoader())) { + if (current instanceof TlsXdsCredentialsProvider) { + return; + } + } + fail("ServiceLoader unable to load TlsXdsCredentialsProvider"); + } + + @Test + public void isAvailable() { + assertTrue(provider.isAvailable()); + } + + @Test + public void channelCredentials() { + assertSame(TlsChannelCredentials.class, + provider.newChannelCredentials(null).getClass()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java index 410ffb1b462..44b3407ba0a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java @@ -49,7 +49,6 @@ import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PathMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PolicyMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.SourceIpMatcher; - import java.net.InetAddress; import java.net.InetSocketAddress; import java.security.Principal; diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java similarity index 89% rename from xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index adc96a36336..4f85afc2ead 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -32,12 +32,12 @@ import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.XdsInitializationException; -import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider; -import io.grpc.xds.internal.certprovider.CertificateProvider; -import io.grpc.xds.internal.certprovider.CertificateProviderProvider; -import io.grpc.xds.internal.certprovider.CertificateProviderRegistry; -import io.grpc.xds.internal.certprovider.CertificateProviderStore; -import io.grpc.xds.internal.certprovider.TestCertificateProvider; +import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; +import io.grpc.xds.internal.security.certprovider.CertificateProvider; +import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; +import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; +import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; import java.io.IOException; import org.junit.Assert; import org.junit.Before; @@ -53,7 +53,7 @@ public class ClientSslContextProviderFactoryTest { CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; - CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; + CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; ClientSslContextProviderFactory clientSslContextProviderFactory; @Before @@ -61,7 +61,7 @@ public void setUp() { certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderClientSslContextProviderFactory = - new CertProviderClientSslContextProvider.Factory(certificateProviderStore); + new CertProviderClientSslContextProviderFactory(certificateProviderStore); } @Test @@ -84,12 +84,14 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); // verify that bootstrapInfo is cached... sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); } @Test @@ -117,7 +119,8 @@ public void bothPresent_expectCertProviderClientSslContextProvider() bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -142,7 +145,8 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -152,6 +156,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder() .addAllMatchSubjectAltNames( @@ -174,7 +179,8 @@ public void createCertProviderClientSslContextProvider_withStaticContext() certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -203,7 +209,8 @@ public void createCertProviderClientSslContextProvider_2providers() bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[1]); } @@ -216,6 +223,7 @@ public void createNewCertProviderClientSslContextProvider_withSans() { createAndRegisterProviderProvider( certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder() .addAllMatchSubjectAltNames( @@ -238,7 +246,8 @@ public void createNewCertProviderClientSslContextProvider_withSans() { bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[1]); } @@ -248,6 +257,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder() .addAllMatchSubjectAltNames( @@ -270,7 +280,8 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java similarity index 99% rename from xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java rename to xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 840cced424f..728bb06efec 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; @@ -32,7 +32,7 @@ import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.internal.testing.TestUtils; import io.grpc.xds.EnvoyServerProtoData; -import io.grpc.xds.internal.sds.trust.CertificateUtils; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.handler.ssl.SslContext; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -424,7 +424,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ReferenceCountingMapTest.java similarity index 97% rename from xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/ReferenceCountingMapTest.java index b94aefd2151..d54a61b5510 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ReferenceCountingMapTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ReferenceCountingMapTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; @@ -24,7 +24,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import org.junit.Before; import org.junit.Rule; import org.junit.Test; diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java similarity index 89% rename from xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 502d2185a82..863e4dcfecf 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -14,15 +14,15 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -51,9 +51,9 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.TlsContextManager; -import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsProtocolNegotiator; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsHandler; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsProtocolNegotiator; +import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -83,9 +83,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link SdsProtocolNegotiators}. */ +/** Unit tests for {@link SecurityProtocolNegotiators}. */ @RunWith(JUnit4.class) -public class SdsProtocolNegotiatorsTest { +public class SecurityProtocolNegotiatorsTest { private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -156,8 +156,8 @@ public void clientSdsHandler_addLast() SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); - SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = - new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); + SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler = + new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); pipeline.addLast(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSdsHandler); assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it @@ -168,7 +168,7 @@ public void clientSdsHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -221,8 +221,8 @@ public SocketAddress remoteAddress() { "google_cloud_private_spiffe-server", true, true); TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer); - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, + SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = + new SecurityProtocolNegotiators.HandlerPickerHandler(grpcHandler, InternalProtocolNegotiators.serverPlaintext()); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); @@ -236,7 +236,7 @@ public SocketAddress remoteAddress() { pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNull(); - channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); + channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class); assertThat(channelHandlerCtx).isNotNull(); SslContextProviderSupplier sslContextProviderSupplier = @@ -245,7 +245,7 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -259,7 +259,7 @@ protected void onException(Throwable throwable) { Object fromFuture = future.get(2, TimeUnit.SECONDS); assertThat(fromFuture).isInstanceOf(SslContext.class); channel.runPendingTasks(); - channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); + channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class); assertThat(channelHandlerCtx).isNull(); // pipeline should only have SslHandler and ServerTlsHandler @@ -287,8 +287,8 @@ public SocketAddress localAddress() { }; pipeline = channel.pipeline(); - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler( + SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = + new SecurityProtocolNegotiators.HandlerPickerHandler( grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); @@ -313,8 +313,8 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { ChannelHandler mockChannelHandler = mock(ChannelHandler.class); ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class); when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler); - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler( + SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = + new SecurityProtocolNegotiators.HandlerPickerHandler( grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); @@ -333,8 +333,8 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { @Test public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler( + SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = + new SecurityProtocolNegotiators.HandlerPickerHandler( grpcHandler, null); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); @@ -368,8 +368,8 @@ public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); - SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = - new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); + SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler = + new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); pipeline.addLast(clientSdsHandler); channelHandlerCtx = pipeline.context(clientSdsHandler); @@ -381,7 +381,7 @@ public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java similarity index 86% rename from xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index 7623b614001..07648194f72 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -14,11 +14,11 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider; -import static io.grpc.xds.internal.sds.ClientSslContextProviderFactoryTest.verifyWatcher; +import static io.grpc.xds.internal.security.ClientSslContextProviderFactoryTest.createAndRegisterProviderProvider; +import static io.grpc.xds.internal.security.ClientSslContextProviderFactoryTest.verifyWatcher; import com.google.common.collect.ImmutableSet; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -29,10 +29,10 @@ import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.XdsInitializationException; -import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; -import io.grpc.xds.internal.certprovider.CertificateProvider; -import io.grpc.xds.internal.certprovider.CertificateProviderRegistry; -import io.grpc.xds.internal.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.CertProviderServerSslContextProviderFactory; +import io.grpc.xds.internal.security.certprovider.CertificateProvider; +import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; +import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -44,7 +44,7 @@ public class ServerSslContextProviderFactoryTest { CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; - CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory; + CertProviderServerSslContextProviderFactory certProviderServerSslContextProviderFactory; ServerSslContextProviderFactory serverSslContextProviderFactory; @Before @@ -52,7 +52,7 @@ public void setUp() { certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderServerSslContextProviderFactory = - new CertProviderServerSslContextProvider.Factory(certificateProviderStore); + new CertProviderServerSslContextProviderFactory(certificateProviderStore); } @Test @@ -76,12 +76,14 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); } @Test @@ -113,7 +115,8 @@ public void bothPresent_expectCertProviderServerSslContextProvider() bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -139,7 +142,8 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -149,6 +153,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder() .addAllMatchSubjectAltNames( @@ -172,7 +177,8 @@ public void createCertProviderServerSslContextProvider_withStaticContext() bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); } @@ -202,7 +208,8 @@ public void createCertProviderServerSslContextProvider_2providers() bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[1]); } @@ -215,6 +222,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); createAndRegisterProviderProvider( certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder() .addAllMatchSubjectAltNames( @@ -239,7 +247,8 @@ public void createNewCertProviderServerSslContextProvider_withSans() bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); - assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( + "CertProviderServerSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[1]); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java similarity index 94% rename from xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 19fd0e189c1..cdb534971d4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.eq; @@ -80,13 +80,14 @@ public void get_updateSecret() { .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(0)) .releaseClientSslContextProvider(any(SslContextProvider.class)); - ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSecret(mockSslContext); - verify(mockCallback, times(1)).updateSecret(eq(mockSslContext)); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); @@ -99,7 +100,8 @@ public void get_updateSecret() { public void get_onException() { prepareSupplier(); callUpdateSslContext(); - ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java similarity index 91% rename from xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 7634bfee376..4589e328d4a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -14,16 +14,16 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -34,7 +34,7 @@ import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; +import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java similarity index 91% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 111a44e3224..857d4b017c1 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -14,18 +14,18 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils.getCertFromResourceName; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils.getCertFromResourceName; import static org.junit.Assert.fail; import com.google.common.annotations.VisibleForTesting; @@ -36,8 +36,8 @@ import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -56,14 +56,14 @@ public class CertProviderClientSslContextProviderTest { CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; - private CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; + private CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; @Before public void setUp() throws Exception { certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderClientSslContextProviderFactory = - new CertProviderClientSslContextProvider.Factory(certificateProviderStore); + new CertProviderClientSslContextProviderFactory(certificateProviderStore); } /** Helper method to build CertProviderClientSslContextProvider. */ @@ -81,10 +81,11 @@ private CertProviderClientSslContextProvider getSslContextProvider( "root-default", alpnProtocols, staticCertValidationContext); - return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), - bootstrapInfo.certProviders()); + return (CertProviderClientSslContextProvider) + certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + bootstrapInfo.certProviders()); } /** Helper method to build CertProviderClientSslContextProvider. */ @@ -102,7 +103,8 @@ private CertProviderClientSslContextProvider getNewSslContextProvider( "root-default", alpnProtocols, staticCertValidationContext); - return certProviderClientSslContextProviderFactory.getProvider( + return (CertProviderClientSslContextProvider) + certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java similarity index 90% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java index 7cd3cd2a793..14d772c779b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java @@ -14,17 +14,17 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils.getCertFromResourceName; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.doChecksOnSslContext; +import static io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils.getCertFromResourceName; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -35,9 +35,9 @@ import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; -import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; +import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; import java.util.Arrays; import org.junit.Before; import org.junit.Test; @@ -50,14 +50,14 @@ public class CertProviderServerSslContextProviderTest { CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; - private CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory; + private CertProviderServerSslContextProviderFactory certProviderServerSslContextProviderFactory; @Before public void setUp() throws Exception { certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderServerSslContextProviderFactory = - new CertProviderServerSslContextProvider.Factory(certificateProviderStore); + new CertProviderServerSslContextProviderFactory(certificateProviderStore); } /** Helper method to build CertProviderServerSslContextProvider. */ @@ -77,10 +77,11 @@ private CertProviderServerSslContextProvider getSslContextProvider( alpnProtocols, staticCertValidationContext, requireClientCert); - return certProviderServerSslContextProviderFactory.getProvider( - downstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), - bootstrapInfo.certProviders()); + return (CertProviderServerSslContextProvider) + certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + bootstrapInfo.certProviders()); } /** Helper method to build CertProviderServerSslContextProvider. */ @@ -100,7 +101,8 @@ private CertProviderServerSslContextProvider getNewSslContextProvider( alpnProtocols, staticCertValidationContext, requireClientCert); - return certProviderServerSslContextProviderFactory.getProvider( + return (CertProviderServerSslContextProvider) + certProviderServerSslContextProviderFactory.getProvider( downstreamTlsContext, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); @@ -177,6 +179,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { new CertificateProvider.DistributorWatcher[1]; TestCertificateProvider.createAndRegisterProviderProvider( certificateProviderRegistry, watcherCaptor, "testca", 0); + @SuppressWarnings("deprecation") CertificateValidationContext staticCertValidationContext = CertificateValidationContext.newBuilder().addAllMatchSubjectAltNames(Arrays .asList(StringMatcher.newBuilder().setExact("foo.com").build(), diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java similarity index 99% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java index 33ec6b291ed..8f77de7b5e2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CommonCertProviderTestUtils.java similarity index 92% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/CommonCertProviderTestUtils.java index 0e60c4c6716..c62aa2d3a81 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CommonCertProviderTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CommonCertProviderTestUtils.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static java.nio.charset.StandardCharsets.UTF_8; @@ -22,8 +22,8 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.TimeProvider; import io.grpc.internal.testing.TestUtils; -import io.grpc.xds.internal.certprovider.FileWatcherCertificateProviderProvider.ScheduledExecutorServiceFactory; -import io.grpc.xds.internal.sds.trust.CertificateUtils; +import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider.ScheduledExecutorServiceFactory; +import io.grpc.xds.internal.security.trust.CertificateUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java similarity index 99% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java index d113b520057..9f7b13f86e9 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java similarity index 85% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index 4b22cfb4e34..d08be8cbcef 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -14,15 +14,15 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -31,17 +31,17 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import io.grpc.Status; import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.certprovider.CertificateProvider.DistributorWatcher; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.NoSuchFileException; import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; @@ -75,7 +75,7 @@ public class FileWatcherCertificateProviderTest { @Mock private CertificateProvider.Watcher mockWatcher; @Mock private ScheduledExecutorService timeService; - @Mock private TimeProvider timeProvider; + private final FakeTimeProvider timeProvider = new FakeTimeProvider(); @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); @@ -114,6 +114,8 @@ private void populateTarget( if (certFileSource != null) { certFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFileSource); Files.copy(Paths.get(certFileSource), Paths.get(certFile), REPLACE_EXISTING); + Files.setLastModifiedTime( + Paths.get(certFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); } if (deleteCurKey) { Files.delete(Paths.get(keyFile)); @@ -121,6 +123,8 @@ private void populateTarget( if (keyFileSource != null) { keyFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFileSource); Files.copy(Paths.get(keyFileSource), Paths.get(keyFile), REPLACE_EXISTING); + Files.setLastModifiedTime( + Paths.get(keyFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); } if (deleteCurRoot) { Files.delete(Paths.get(rootFile)); @@ -128,6 +132,8 @@ private void populateTarget( if (rootFileSource != null) { rootFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFileSource); Files.copy(Paths.get(rootFileSource), Paths.get(rootFile), REPLACE_EXISTING); + Files.setLastModifiedTime( + Paths.get(rootFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); } } @@ -166,7 +172,7 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - Thread.sleep(1000L); + timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); @@ -205,7 +211,7 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - Thread.sleep(1000L); + timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(null, SERVER_1_PEM_FILE); @@ -227,7 +233,7 @@ public void certAndKeyFileUpdateOnly() doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - Thread.sleep(1000L); + timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherUpdates(SERVER_0_PEM_FILE, null); @@ -242,8 +248,6 @@ public void getCertificate_initialMissingCertFile() throws IOException { .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); - when(timeProvider.currentTimeNanos()) - .thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos()); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem"); } @@ -285,12 +289,11 @@ public void getCertificate_missingRootFile() throws IOException, InterruptedExce provider.checkAndReloadCertificates(); reset(mockWatcher); - Thread.sleep(1000L); + timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + timeProvider.forwardTime( + CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), + TimeUnit.MILLISECONDS); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); } @@ -315,22 +318,18 @@ private void commonErrorTest( provider.checkAndReloadCertificates(); reset(mockWatcher); - Thread.sleep(1000L); + timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget( certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + timeProvider.forwardTime( + CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), + TimeUnit.MILLISECONDS); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null); - reset(mockWatcher, timeProvider); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - CERT0_EXPIRY_TIME_MILLIS - 590_000L)); + reset(mockWatcher); + timeProvider.forwardTime(20, TimeUnit.SECONDS); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( Status.Code.UNKNOWN, @@ -353,7 +352,7 @@ private void verifyWatcherErrorUpdates( if (code == null && throwableType == null && causeMessages == null) { verify(mockWatcher, never()).onError(any(Status.class)); } else { - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); verify(mockWatcher, times(1)).onError(statusCaptor.capture()); Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(code); @@ -376,7 +375,8 @@ private void verifyTimeServiceAndScheduledFuture() { private void verifyWatcherUpdates(String certPemFile, String rootPemFile) throws IOException, CertificateException { if (certPemFile != null) { - ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); + @SuppressWarnings("unchecked") + ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(List.class); verify(mockWatcher, times(1)) .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); List certChain = certChainCaptor.getValue(); @@ -388,7 +388,8 @@ private void verifyWatcherUpdates(String certPemFile, String rootPemFile) .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); } if (rootPemFile != null) { - ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); + @SuppressWarnings("unchecked") + ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(List.class); verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); List roots = rootsCaptor.getValue(); assertThat(roots).hasSize(1); @@ -450,4 +451,25 @@ public V get(long timeout, TimeUnit unit) { return null; } } + + /** + * Fake TimeProvider that roughly mirrors FakeClock. Not using FakeClock because it incorrectly + * fails to align the wall-time API TimeProvider.currentTimeNanos() with currentTimeMillis() and + * fixing it upsets a _lot_ of tests. + */ + static class FakeTimeProvider implements TimeProvider { + public long currentTimeNanos = TimeUnit.SECONDS.toNanos(1262332800); /* 2010-01-01 */ + + @Override public long currentTimeNanos() { + return currentTimeNanos; + } + + public void forwardTime(long duration, TimeUnit unit) { + currentTimeNanos += unit.toNanos(duration); + } + + public long currentTimeMillis() { + return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/TestCertificateProvider.java similarity index 98% rename from xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java rename to xds/src/test/java/io/grpc/xds/internal/security/certprovider/TestCertificateProvider.java index 9253d071fba..aba7a910813 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/TestCertificateProvider.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/TestCertificateProvider.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds.internal.certprovider; +package io.grpc.xds.internal.security.certprovider; public class TestCertificateProvider extends CertificateProvider { Object config; diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java similarity index 72% rename from xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index 47ac9e6bb42..30c5e542a80 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds.trust; +package io.grpc.xds.internal.security.trust; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import com.google.protobuf.ByteString; import io.envoyproxy.envoy.config.core.v3.DataSource; @@ -38,22 +38,22 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link SdsTrustManagerFactory}. */ +/** Unit tests for {@link XdsTrustManagerFactory}. */ @RunWith(JUnit4.class) -public class SdsTrustManagerFactoryTest { +public class XdsTrustManagerFactoryTest { @Test public void constructor_fromFile() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); assertThat(tms).hasLength(1); TrustManager myTm = tms[0]; - assertThat(myTm).isInstanceOf(SdsX509TrustManager.class); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) myTm; - X509Certificate[] acceptedIssuers = sdsX509TrustManager.getAcceptedIssuers(); + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) myTm; + X509Certificate[] acceptedIssuers = xdsX509TrustManager.getAcceptedIssuers(); assertThat(acceptedIssuers).isNotNull(); assertThat(acceptedIssuers).hasLength(1); X509Certificate caCert = acceptedIssuers[0]; @@ -64,16 +64,16 @@ public void constructor_fromFile() throws CertificateException, IOException, Cer @Test public void constructor_fromInlineBytes() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPathAsInlineBytes(CA_PEM_FILE)); + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPathAsInlineBytes(CA_PEM_FILE)); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); assertThat(tms).hasLength(1); TrustManager myTm = tms[0]; - assertThat(myTm).isInstanceOf(SdsX509TrustManager.class); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) myTm; - X509Certificate[] acceptedIssuers = sdsX509TrustManager.getAcceptedIssuers(); + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) myTm; + X509Certificate[] acceptedIssuers = xdsX509TrustManager.getAcceptedIssuers(); assertThat(acceptedIssuers).isNotNull(); assertThat(acceptedIssuers).hasLength(1); X509Certificate caCert = acceptedIssuers[0]; @@ -87,16 +87,16 @@ public void constructor_fromRootCert() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); assertThat(tms).hasLength(1); TrustManager myTm = tms[0]; - assertThat(myTm).isInstanceOf(SdsX509TrustManager.class); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) myTm; - X509Certificate[] acceptedIssuers = sdsX509TrustManager.getAcceptedIssuers(); + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) myTm; + X509Certificate[] acceptedIssuers = xdsX509TrustManager.getAcceptedIssuers(); assertThat(acceptedIssuers).isNotNull(); assertThat(acceptedIssuers).hasLength(1); X509Certificate caCert = acceptedIssuers[0]; @@ -110,12 +110,12 @@ public void constructorRootCert_checkServerTrusted() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "waterzooi.test.google.be"); - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); - sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + xdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); } @Test @@ -123,7 +123,7 @@ public void constructorRootCert_nonStaticContext_throwsException() throws CertificateException, IOException, CertStoreException { X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); try { - new SdsTrustManagerFactory( + new XdsTrustManagerFactory( new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { @@ -139,13 +139,13 @@ public void constructorRootCert_checkServerTrusted_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { - sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + xdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); Assert.fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected) @@ -160,13 +160,13 @@ public void constructorRootCert_checkClientTrusted_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { - sdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); + xdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); Assert.fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected) @@ -178,35 +178,35 @@ public void constructorRootCert_checkClientTrusted_throwsException() @Test public void checkServerTrusted_goodCert() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); - sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + xdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); } @Test public void checkClientTrusted_goodCert() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(CLIENT_PEM_FILE)); - sdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); + xdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); } @Test public void checkServerTrusted_badCert_throwsException() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(BAD_SERVER_PEM_FILE)); try { - sdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); + xdsX509TrustManager.checkServerTrusted(serverChain, "RSA"); Assert.fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected) @@ -218,13 +218,13 @@ public void checkServerTrusted_badCert_throwsException() @Test public void checkClientTrusted_badCert_throwsException() throws CertificateException, IOException, CertStoreException { - SdsTrustManagerFactory factory = - new SdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); - SdsX509TrustManager sdsX509TrustManager = (SdsX509TrustManager) factory.getTrustManagers()[0]; + XdsTrustManagerFactory factory = + new XdsTrustManagerFactory(getCertContextFromPath(CA_PEM_FILE)); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TestUtils.loadCert(BAD_CLIENT_PEM_FILE)); try { - sdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); + xdsX509TrustManager.checkClientTrusted(clientChain, "RSA"); Assert.fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected) @@ -256,7 +256,9 @@ private static final CertificateValidationContext buildStaticValidationContext( String... verifySans) { CertificateValidationContext.Builder builder = CertificateValidationContext.newBuilder(); for (String san : verifySans) { - builder.addMatchSubjectAltNames(StringMatcher.newBuilder().setExact(san)); + @SuppressWarnings("deprecation") + CertificateValidationContext.Builder unused = + builder.addMatchSubjectAltNames(StringMatcher.newBuilder().setExact(san)); } return builder.build(); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java similarity index 88% rename from xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java rename to xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 166b60f4caf..c6319ee5a9f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -14,13 +14,13 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds.trust; +package io.grpc.xds.internal.security.trust; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.doReturn; @@ -54,10 +54,10 @@ import org.mockito.junit.MockitoRule; /** - * Unit tests for {@link SdsX509TrustManager}. + * Unit tests for {@link XdsX509TrustManager}. */ @RunWith(JUnit4.class) -public class SdsX509TrustManagerTest { +public class XdsX509TrustManagerTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @@ -68,11 +68,11 @@ public class SdsX509TrustManagerTest { @Mock private SSLSession mockSession; - private SdsX509TrustManager trustManager; + private XdsX509TrustManager trustManager; @Test public void nullCertContextTest() throws CertificateException, IOException { - trustManager = new SdsX509TrustManager(null, mockDelegate); + trustManager = new XdsX509TrustManager(null, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -81,7 +81,7 @@ public void nullCertContextTest() throws CertificateException, IOException { @Test public void emptySanListContextTest() throws CertificateException, IOException { CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -90,9 +90,10 @@ public void emptySanListContextTest() throws CertificateException, IOException { @Test public void missingPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); try { trustManager.verifySubjectAltNameInChain(null); fail("no exception thrown"); @@ -104,9 +105,10 @@ public void missingPeerCerts() { @Test public void emptyArrayPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); try { trustManager.verifySubjectAltNameInChain(new X509Certificate[0]); fail("no exception thrown"); @@ -118,9 +120,10 @@ public void emptyArrayPeerCerts() { @Test public void noSansInPeerCerts() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(CLIENT_PEM_FILE)); try { @@ -138,9 +141,10 @@ public void oneSanInPeerCertsVerifies() throws CertificateException, IOException .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -154,9 +158,10 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() .setExact("waterZooi.test.Google.be") .setIgnoreCase(false) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -171,9 +176,10 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() public void oneSanInPeerCertsVerifies_ignoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("Waterzooi.Test.google.be").setIgnoreCase(true).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -186,9 +192,10 @@ public void oneSanInPeerCerts_prefix() throws CertificateException, IOException .setPrefix("waterzooi.") // test.google.be .setIgnoreCase(false) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -199,9 +206,10 @@ public void oneSanInPeerCertsPrefix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setPrefix("waterZooi.").setIgnoreCase(false).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -219,9 +227,10 @@ public void oneSanInPeerCerts_prefixIgnoreCase() throws CertificateException, IO .setPrefix("WaterZooi.") // test.google.be .setIgnoreCase(true) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -231,9 +240,10 @@ public void oneSanInPeerCerts_prefixIgnoreCase() throws CertificateException, IO public void oneSanInPeerCerts_suffix() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".google.be").setIgnoreCase(false).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -244,9 +254,10 @@ public void oneSanInPeerCertsSuffix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".gooGle.bE").setIgnoreCase(false).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -261,9 +272,10 @@ public void oneSanInPeerCertsSuffix_differentCase_expectException() public void oneSanInPeerCerts_suffixIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".GooGle.BE").setIgnoreCase(true).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -273,9 +285,10 @@ public void oneSanInPeerCerts_suffixIgnoreCase() throws CertificateException, IO public void oneSanInPeerCerts_substring() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.test.google").setIgnoreCase(false).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -286,9 +299,10 @@ public void oneSanInPeerCertsSubstring_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.Test.gooGle").setIgnoreCase(false).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -303,9 +317,10 @@ public void oneSanInPeerCertsSubstring_differentCase_expectException() public void oneSanInPeerCerts_substringIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooI.Test.Google").setIgnoreCase(true).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -318,9 +333,10 @@ public void oneSanInPeerCerts_safeRegex() throws CertificateException, IOExcepti .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{1}ooi\\.test\\.google\\.be")) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -333,9 +349,10 @@ public void oneSanInPeerCerts_safeRegex1() throws CertificateException, IOExcept .setSafeRegex( RegexMatcher.newBuilder().setRegex("no-match-string|\\*\\.test\\.youtube\\.com")) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -348,9 +365,10 @@ public void oneSanInPeerCerts_safeRegex_ipAddress() throws CertificateException, .setSafeRegex( RegexMatcher.newBuilder().setRegex("([[:digit:]]{1,3}\\.){3}[[:digit:]]{1,3}")) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -363,9 +381,10 @@ public void oneSanInPeerCerts_safeRegex_noMatch() throws CertificateException, I .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{2}ooi\\.test\\.google\\.be")) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -382,12 +401,13 @@ public void oneSanInPeerCertsVerifiesMultipleVerifySans() StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("waterzooi.test.google.be").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -397,9 +417,10 @@ public void oneSanInPeerCertsVerifiesMultipleVerifySans() public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -416,12 +437,13 @@ public void wildcardSanInPeerCertsVerifiesMultipleVerifySans() StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setSuffix("test.youTube.Com").setIgnoreCase(true).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should match suffix test.youTube.Com .build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -433,12 +455,13 @@ public void wildcardSanInPeerCertsVerifiesMultipleVerifySans1() StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setContains("est.Google.f").setIgnoreCase(true).build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should contain est.Google.f .build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -452,9 +475,10 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // sub.test.example.com. StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -469,12 +493,13 @@ public void wildcardSanInPeerCertsSubdomainMismatch() public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.1.3").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); @@ -484,12 +509,13 @@ public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOExc public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.2.3").build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate[] certs = CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE)); try { @@ -561,9 +587,10 @@ public void unsupportedAltNameType() throws CertificateException, IOException { .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); + @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new SdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); X509Certificate mockCert = mock(X509Certificate.class); when(mockCert.getSubjectAlternativeNames()) @@ -602,7 +629,7 @@ private SSLParameters buildTrustManagerAndGetSslParameters() throws CertificateException, IOException, CertStoreException { X509Certificate[] caCerts = CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE)); - trustManager = SdsTrustManagerFactory.createSdsX509TrustManager(caCerts, + trustManager = XdsTrustManagerFactory.createSdsX509TrustManager(caCerts, null); when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); diff --git a/xds/src/test/java/io/grpc/xds/OrcaMetricReportingServerInterceptorTest.java b/xds/src/test/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptorTest.java similarity index 83% rename from xds/src/test/java/io/grpc/xds/OrcaMetricReportingServerInterceptorTest.java rename to xds/src/test/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptorTest.java index 4a2074b91d8..7ec7ef7b5e3 100644 --- a/xds/src/test/java/io/grpc/xds/OrcaMetricReportingServerInterceptorTest.java +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaMetricReportingServerInterceptorTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import static com.google.common.truth.Truth.assertThat; @@ -69,7 +69,10 @@ public class OrcaMetricReportingServerInterceptorTest { private static final SimpleRequest REQUEST = SimpleRequest.newBuilder().setRequestMessage("Simple request").build(); - private final Map applicationMetrics = new HashMap<>(); + private final Map applicationUtilizationMetrics = new HashMap<>(); + private final Map applicationCostMetrics = new HashMap<>(); + private double cpuUtilizationMetrics = 0; + private double memoryUtilizationMetrics = 0; private final AtomicReference trailersCapture = new AtomicReference<>(); @@ -82,9 +85,16 @@ public void setUp() throws Exception { @Override public void unaryRpc( SimpleRequest request, StreamObserver responseObserver) { - for (Map.Entry entry : applicationMetrics.entrySet()) { - CallMetricRecorder.getCurrent().recordCallMetric(entry.getKey(), entry.getValue()); + for (Map.Entry entry : applicationUtilizationMetrics.entrySet()) { + CallMetricRecorder.getCurrent().recordUtilizationMetric(entry.getKey(), + entry.getValue()); } + for (Map.Entry entry : applicationCostMetrics.entrySet()) { + CallMetricRecorder.getCurrent().recordRequestCostMetric(entry.getKey(), + entry.getValue()); + } + CallMetricRecorder.getCurrent().recordCpuUtilizationMetric(cpuUtilizationMetrics); + CallMetricRecorder.getCurrent().recordMemoryUtilizationMetric(memoryUtilizationMetrics); SimpleResponse response = SimpleResponse.newBuilder().setResponseMessage("Simple response").build(); responseObserver.onNext(response); @@ -111,8 +121,7 @@ public void unaryRpc( @Test public void shareCallMetricRecorderInContext() throws IOException { - final CallMetricRecorder callMetricRecorder = - InternalCallMetricRecorder.newCallMetricRecorder(); + final CallMetricRecorder callMetricRecorder = new CallMetricRecorder(); ServerStreamTracer.Factory callMetricRecorderSharingStreamTracerFactory = new ServerStreamTracer.Factory() { @Override @@ -169,15 +178,24 @@ public void noTrailerReportIfNoRecordedMetrics() { @Test public void responseTrailersContainAllReportedMetrics() { - applicationMetrics.put("cost1", 1231.4543); - applicationMetrics.put("cost2", 0.1367); - applicationMetrics.put("cost3", 7614.145); + applicationCostMetrics.put("cost1", 1231.4543); + applicationCostMetrics.put("cost2", 0.1367); + applicationCostMetrics.put("cost3", 7614.145); + applicationUtilizationMetrics.put("util1", 0.1082); + applicationUtilizationMetrics.put("util2", 0.4936); + applicationUtilizationMetrics.put("util3", 0.5342); + cpuUtilizationMetrics = 0.3465; + memoryUtilizationMetrics = 0.764; ClientCalls.blockingUnaryCall(channelToUse, SIMPLE_METHOD, CallOptions.DEFAULT, REQUEST); Metadata receivedTrailers = trailersCapture.get(); OrcaLoadReport report = receivedTrailers.get(OrcaMetricReportingServerInterceptor.ORCA_ENDPOINT_LOAD_METRICS_KEY); + assertThat(report.getUtilizationMap()) + .containsExactly("util1", 0.1082, "util2", 0.4936, "util3", 0.5342); assertThat(report.getRequestCostMap()) .containsExactly("cost1", 1231.4543, "cost2", 0.1367, "cost3", 7614.145); + assertThat(report.getCpuUtilization()).isEqualTo(0.3465); + assertThat(report.getMemUtilization()).isEqualTo(0.764); } private static final class TrailersCapturingClientInterceptor implements ClientInterceptor { diff --git a/xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilTest.java similarity index 84% rename from xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java rename to xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilTest.java index 5f5bc5a69aa..5c8de10e2d0 100644 --- a/xds/src/test/java/io/grpc/xds/OrcaOobUtilTest.java +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -23,7 +23,9 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; +import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.inOrder; @@ -47,6 +49,7 @@ import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; @@ -59,12 +62,13 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.FakeClock; +import io.grpc.services.MetricReport; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.OrcaOobUtil.OrcaOobReportListener; -import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; -import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; -import io.grpc.xds.OrcaOobUtil.SubchannelImpl; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; +import io.grpc.xds.orca.OrcaOobUtil.OrcaReportingConfig; +import io.grpc.xds.orca.OrcaOobUtil.SubchannelImpl; import java.net.SocketAddress; import java.text.MessageFormat; import java.util.ArrayDeque; @@ -130,9 +134,10 @@ public void uncaughtException(Thread t, Throwable e) { @Mock private BackoffPolicy backoffPolicy1; @Mock private BackoffPolicy backoffPolicy2; private FakeSubchannel[] subchannels = new FakeSubchannel[NUM_SUBCHANNELS]; - private OrcaReportingHelperWrapper orcaHelperWrapper; - private OrcaReportingHelperWrapper parentHelperWrapper; - private OrcaReportingHelperWrapper childHelperWrapper; + private LoadBalancer.Helper orcaHelper; + private LoadBalancer.Helper parentHelper; + private LoadBalancer.Helper childHelper; + private Subchannel savedParentSubchannel; private static FakeSubchannel unwrap(Subchannel s) { return (FakeSubchannel) ((SubchannelImpl) s).delegate(); @@ -175,7 +180,6 @@ public void orcaReportingConfig_construct() { } @Before - @SuppressWarnings("unchecked") public void setUp() throws Exception { MockitoAnnotations.initMocks(this); @@ -202,42 +206,48 @@ public void setUp() throws Exception { when(backoffPolicy1.nextBackoffNanos()).thenReturn(11L, 21L); when(backoffPolicy2.nextBackoffNanos()).thenReturn(12L, 22L); - orcaHelperWrapper = - OrcaOobUtil.newOrcaReportingHelperWrapper( + orcaHelper = + OrcaOobUtil.newOrcaReportingHelper( origHelper, - mockOrcaListener0, backoffPolicyProvider, fakeClock.getStopwatchSupplier()); - parentHelperWrapper = - OrcaOobUtil.newOrcaReportingHelperWrapper( - origHelper, - mockOrcaListener1, - backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); - childHelperWrapper = - OrcaOobUtil.newOrcaReportingHelperWrapper( - parentHelperWrapper.asHelper(), - mockOrcaListener2, + parentHelper = + new ForwardingLoadBalancerHelper() { + @Override + protected Helper delegate() { + return orcaHelper; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args); + savedParentSubchannel = subchannel; + return subchannel; + } + }; + childHelper = + OrcaOobUtil.newOrcaReportingHelper( + parentHelper, backoffPolicyProvider, fakeClock.getStopwatchSupplier()); } @Test - @SuppressWarnings("unchecked") public void singlePolicyTypicalWorkflow() { - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); verify(origHelper, atLeast(0)).getSynchronizationContext(); verifyNoMoreInteractions(origHelper); // Calling createSubchannel() on orcaHelper correctly passes augmented CreateSubchannelArgs // to origHelper. - ArgumentCaptor createArgsCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor createArgsCaptor = + ArgumentCaptor.forClass(CreateSubchannelArgs.class); for (int i = 0; i < NUM_SUBCHANNELS; i++) { String subchannelAttrValue = "eag attr " + i; Attributes attrs = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); - assertThat(unwrap(createSubchannel(orcaHelperWrapper.asHelper(), i, attrs))) - .isSameInstanceAs(subchannels[i]); + Subchannel created = createSubchannel(orcaHelper, i, attrs); + assertThat(unwrap(created)).isSameInstanceAs(subchannels[i]); + setOrcaReportConfig(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) @@ -278,7 +288,9 @@ public void singlePolicyTypicalWorkflow() { OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); serverCall.responseObserver.onNext(report); assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); - verify(mockOrcaListener0, times(i + 1)).onLoadReport(eq(report)); + verify(mockOrcaListener0, times(i + 1)).onLoadReport( + argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher( + OrcaPerRequestUtil.fromOrcaLoadReport(report)))); } for (int i = 0; i < NUM_SUBCHANNELS; i++) { @@ -306,20 +318,20 @@ public void singlePolicyTypicalWorkflow() { @Test public void twoLevelPoliciesTypicalWorkflow() { - setOrcaReportConfig(childHelperWrapper, SHORT_INTERVAL_CONFIG); - setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); verify(origHelper, atLeast(0)).getSynchronizationContext(); verifyNoMoreInteractions(origHelper); // Calling createSubchannel() on child helper correctly passes augmented CreateSubchannelArgs // to origHelper. - ArgumentCaptor createArgsCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor createArgsCaptor = + ArgumentCaptor.forClass(CreateSubchannelArgs.class); for (int i = 0; i < NUM_SUBCHANNELS; i++) { String subchannelAttrValue = "eag attr " + i; Attributes attrs = Attributes.newBuilder().set(SUBCHANNEL_ATTR_KEY, subchannelAttrValue).build(); - assertThat(unwrap(createSubchannel(childHelperWrapper.asHelper(), i, attrs))) - .isSameInstanceAs(subchannels[i]); + Subchannel created = createSubchannel(childHelper, i, attrs); + assertThat(unwrap(((SubchannelImpl) created).delegate())).isSameInstanceAs(subchannels[i]); + OrcaOobUtil.setListener(created, mockOrcaListener1, SHORT_INTERVAL_CONFIG); verify(origHelper, times(i + 1)).createSubchannel(createArgsCaptor.capture()); assertThat(createArgsCaptor.getValue().getAddresses()).isEqualTo(eagLists[i]); assertThat(createArgsCaptor.getValue().getAttributes().get(SUBCHANNEL_ATTR_KEY)) @@ -363,8 +375,9 @@ public void twoLevelPoliciesTypicalWorkflow() { OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); serverCall.responseObserver.onNext(report); assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); - verify(mockOrcaListener1, times(i + 1)).onLoadReport(eq(report)); - verify(mockOrcaListener2, times(i + 1)).onLoadReport(eq(report)); + verify(mockOrcaListener1, times(i + 1)).onLoadReport( + argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher( + OrcaPerRequestUtil.fromOrcaLoadReport(report)))); } for (int i = 0; i < NUM_SUBCHANNELS; i++) { @@ -391,10 +404,9 @@ public void twoLevelPoliciesTypicalWorkflow() { } @Test - @SuppressWarnings("unchecked") public void orcReportingDisabledWhenServiceNotImplemented() { - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + final Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); FakeSubchannel subchannel = subchannels[0]; OpenRcaServiceImp orcaServiceImp = orcaServiceImps[0]; SubchannelStateListener mockStateListener = mockStateListeners[0]; @@ -421,15 +433,16 @@ public void orcReportingDisabledWhenServiceNotImplemented() { OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); serverCall.responseObserver.onNext(report); assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); - verify(mockOrcaListener0).onLoadReport(eq(report)); - + verify(mockOrcaListener0).onLoadReport( + argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher( + OrcaPerRequestUtil.fromOrcaLoadReport(report)))); verifyNoInteractions(backoffPolicyProvider); } @Test public void orcaReportingStreamClosedAndRetried() { - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + final Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); FakeSubchannel subchannel = subchannels[0]; OpenRcaServiceImp orcaServiceImp = orcaServiceImps[0]; SubchannelStateListener mockStateListener = mockStateListeners[0]; @@ -467,8 +480,9 @@ public void orcaReportingStreamClosedAndRetried() { OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); orcaServiceImp.calls.peek().responseObserver.onNext(report); assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report); - inOrder.verify(mockOrcaListener0).onLoadReport(eq(report)); - + inOrder.verify(mockOrcaListener0).onLoadReport( + argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher( + OrcaPerRequestUtil.fromOrcaLoadReport(report)))); // Server closes the ORCA reporting RPC after a response, will restart immediately. orcaServiceImp.calls.poll().responseObserver.onCompleted(); assertThat(subchannel.logs).containsExactly( @@ -494,14 +508,14 @@ public void orcaReportingStreamClosedAndRetried() { @Test public void reportingNotStartedUntilConfigured() { - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]) .onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); assertThat(orcaServiceImps[0].calls).isEmpty(); assertThat(subchannels[0].logs).isEmpty(); - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls).hasSize(1); assertLog(subchannels[0].logs, "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); @@ -509,10 +523,33 @@ public void reportingNotStartedUntilConfigured() { .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); } + @Test + public void updateListenerThrows() { + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); + deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); + verify(mockStateListeners[0]) + .onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); + + assertThat(orcaServiceImps[0].calls).hasSize(1); + assertLog(subchannels[0].logs, + "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); + assertThat(orcaServiceImps[0].calls.peek().request) + .isEqualTo(buildOrcaRequestFromConfig(SHORT_INTERVAL_CONFIG)); + assertThat(unwrap(created)).isSameInstanceAs(subchannels[0]); + try { + OrcaOobUtil.setListener(subchannels[0], mockOrcaListener1, MEDIUM_INTERVAL_CONFIG); + fail("Update orca listener on non-orca subchannel should fail"); + } catch (IllegalArgumentException ex) { + assertThat(ex.getMessage()).isEqualTo("Subchannel does not have orca Out-Of-Band " + + "stream enabled. Try to use a subchannel created by OrcaOobUtil.OrcaHelper."); + } + } + @Test public void updateReportingIntervalBeforeCreatingSubchannel() { - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); @@ -525,8 +562,8 @@ public void updateReportingIntervalBeforeCreatingSubchannel() { @Test public void updateReportingIntervalBeforeSubchannelReady() { - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); @@ -541,8 +578,9 @@ public void updateReportingIntervalBeforeSubchannelReady() { public void updateReportingIntervalWhenRpcActive() { // Sets report interval before creating a Subchannel, reporting starts right after suchannel // state becomes READY. - setOrcaReportConfig(orcaHelperWrapper, MEDIUM_INTERVAL_CONFIG); - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, + MEDIUM_INTERVAL_CONFIG); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); @@ -553,7 +591,7 @@ public void updateReportingIntervalWhenRpcActive() { .isEqualTo(buildOrcaRequestFromConfig(MEDIUM_INTERVAL_CONFIG)); // Make reporting less frequent. - setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener0, LONG_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); assertThat(orcaServiceImps[0].calls).hasSize(1); assertLog(subchannels[0].logs, @@ -562,12 +600,13 @@ public void updateReportingIntervalWhenRpcActive() { .isEqualTo(buildOrcaRequestFromConfig(LONG_INTERVAL_CONFIG)); // Configuring with the same report interval again does not restart ORCA RPC. - setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener0, LONG_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls.peek().cancelled).isFalse(); assertThat(subchannels[0].logs).isEmpty(); // Make reporting more frequent. - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener0, + SHORT_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); assertThat(orcaServiceImps[0].calls).hasSize(1); assertLog(subchannels[0].logs, @@ -578,8 +617,8 @@ public void updateReportingIntervalWhenRpcActive() { @Test public void updateReportingIntervalWhenRpcPendingRetry() { - createSubchannel(orcaHelperWrapper.asHelper(), 0, Attributes.EMPTY); - setOrcaReportConfig(orcaHelperWrapper, SHORT_INTERVAL_CONFIG); + Subchannel created = createSubchannel(orcaHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, SHORT_INTERVAL_CONFIG); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); @@ -599,7 +638,7 @@ public void updateReportingIntervalWhenRpcPendingRetry() { assertThat(orcaServiceImps[0].calls).isEmpty(); // Make reporting less frequent. - setOrcaReportConfig(orcaHelperWrapper, LONG_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener0, LONG_INTERVAL_CONFIG); // Retry task will be canceled and restarts new RPC immediately. assertThat(fakeClock.getPendingTasks()).isEmpty(); assertThat(orcaServiceImps[0].calls).hasSize(1); @@ -611,7 +650,7 @@ public void updateReportingIntervalWhenRpcPendingRetry() { @Test public void policiesReceiveSameReportIndependently() { - createSubchannel(childHelperWrapper.asHelper(), 0, Attributes.EMPTY); + Subchannel childSubchannel = createSubchannel(childHelper, 0, Attributes.EMPTY); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); // No helper sets ORCA reporting interval, so load reporting is not started. @@ -620,7 +659,7 @@ public void policiesReceiveSameReportIndependently() { assertThat(subchannels[0].logs).isEmpty(); // Parent helper requests ORCA reports with a certain interval, load reporting starts. - setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); + OrcaOobUtil.setListener(savedParentSubchannel, mockOrcaListener1, SHORT_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls).hasSize(1); assertLog(subchannels[0].logs, "DEBUG: Starting ORCA reporting for " + subchannels[0].getAllAddresses()); @@ -630,17 +669,18 @@ public void policiesReceiveSameReportIndependently() { orcaServiceImps[0].calls.peek().responseObserver.onNext(report); assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report); // Only parent helper's listener receives the report. - ArgumentCaptor parentReportCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor parentReportCaptor = ArgumentCaptor.forClass(MetricReport.class); verify(mockOrcaListener1).onLoadReport(parentReportCaptor.capture()); - assertThat(parentReportCaptor.getValue()).isEqualTo(report); + assertThat(OrcaPerRequestUtilTest.reportEqual(parentReportCaptor.getValue(), + OrcaPerRequestUtil.fromOrcaLoadReport(report))).isTrue(); verifyNoMoreInteractions(mockOrcaListener2); // Now child helper also wants to receive reports. - setOrcaReportConfig(childHelperWrapper, SHORT_INTERVAL_CONFIG); + OrcaOobUtil.setListener(childSubchannel, mockOrcaListener2, SHORT_INTERVAL_CONFIG); orcaServiceImps[0].calls.peek().responseObserver.onNext(report); assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report); // Both helper receives the same report instance. - ArgumentCaptor childReportCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor childReportCaptor = ArgumentCaptor.forClass(MetricReport.class); verify(mockOrcaListener1, times(2)) .onLoadReport(parentReportCaptor.capture()); verify(mockOrcaListener2) @@ -650,9 +690,9 @@ public void policiesReceiveSameReportIndependently() { @Test public void reportWithMostFrequentIntervalRequested() { - setOrcaReportConfig(parentHelperWrapper, SHORT_INTERVAL_CONFIG); - setOrcaReportConfig(childHelperWrapper, LONG_INTERVAL_CONFIG); - createSubchannel(childHelperWrapper.asHelper(), 0, Attributes.EMPTY); + Subchannel created = createSubchannel(childHelper, 0, Attributes.EMPTY); + OrcaOobUtil.setListener(created, mockOrcaListener0, LONG_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener1, SHORT_INTERVAL_CONFIG); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); verify(mockStateListeners[0]).onSubchannelState(eq(ConnectivityStateInfo.forNonError(READY))); assertThat(orcaServiceImps[0].calls).hasSize(1); @@ -663,14 +703,14 @@ public void reportWithMostFrequentIntervalRequested() { assertThat(Durations.toNanos(orcaServiceImps[0].calls.peek().request.getReportInterval())) .isEqualTo(SHORT_INTERVAL_CONFIG.getReportIntervalNanos()); - // Child helper wants reporting to be more frequent than its current setting while it is still + // Parent helper wants reporting to be more frequent than its current setting while it is still // less frequent than parent helper. Nothing should happen on existing RPC. - setOrcaReportConfig(childHelperWrapper, MEDIUM_INTERVAL_CONFIG); + OrcaOobUtil.setListener(savedParentSubchannel, mockOrcaListener0, MEDIUM_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls.peek().cancelled).isFalse(); assertThat(subchannels[0].logs).isEmpty(); // Parent helper wants reporting to be less frequent. - setOrcaReportConfig(parentHelperWrapper, MEDIUM_INTERVAL_CONFIG); + OrcaOobUtil.setListener(created, mockOrcaListener1, MEDIUM_INTERVAL_CONFIG); assertThat(orcaServiceImps[0].calls.poll().cancelled).isTrue(); assertThat(orcaServiceImps[0].calls).hasSize(1); assertLog(subchannels[0].logs, @@ -723,13 +763,10 @@ public void run() { } private void setOrcaReportConfig( - final OrcaReportingHelperWrapper helperWrapper, final OrcaReportingConfig config) { - syncContext.execute(new Runnable() { - @Override - public void run() { - helperWrapper.setReportingConfig(config); - } - }); + final Subchannel subchannel, + final OrcaOobReportListener listener, + final OrcaReportingConfig config) { + OrcaOobUtil.setListener(subchannel, listener, config); } private static final class OpenRcaServiceImp extends OpenRcaServiceGrpc.OpenRcaServiceImplBase { diff --git a/xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java b/xds/src/test/java/io/grpc/xds/orca/OrcaPerRequestUtilTest.java similarity index 77% rename from xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java rename to xds/src/test/java/io/grpc/xds/orca/OrcaPerRequestUtilTest.java index a6e7c6aca20..bbc61c3be4f 100644 --- a/xds/src/test/java/io/grpc/xds/OrcaPerRequestUtilTest.java +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaPerRequestUtilTest.java @@ -14,12 +14,12 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.orca; import static com.google.common.truth.Truth.assertThat; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -28,15 +28,18 @@ import static org.mockito.Mockito.when; import com.github.xds.data.orca.v3.OrcaLoadReport; +import com.google.common.base.Objects; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; -import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener; -import io.grpc.xds.OrcaPerRequestUtil.OrcaReportingTracerFactory; +import io.grpc.services.MetricReport; +import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaReportingTracerFactory; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -80,7 +83,8 @@ public void singlePolicyTypicalWorkflow() { OrcaPerRequestUtil.getInstance() .newOrcaClientStreamTracerFactory(fakeDelegateFactory, orcaListener1); ClientStreamTracer tracer = factory.newClientStreamTracer(STREAM_INFO, new Metadata()); - ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor streamInfoCaptor = + ArgumentCaptor.forClass(ClientStreamTracer.StreamInfo.class); verify(fakeDelegateFactory) .newClientStreamTracer(streamInfoCaptor.capture(), any(Metadata.class)); ClientStreamTracer.StreamInfo capturedInfo = streamInfoCaptor.getValue(); @@ -96,9 +100,31 @@ public void singlePolicyTypicalWorkflow() { OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, OrcaLoadReport.getDefaultInstance()); tracer.inboundTrailers(trailer); - ArgumentCaptor reportCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor reportCaptor = ArgumentCaptor.forClass(MetricReport.class); verify(orcaListener1).onLoadReport(reportCaptor.capture()); - assertThat(reportCaptor.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance()); + assertThat(reportEqual(reportCaptor.getValue(), + OrcaPerRequestUtil.fromOrcaLoadReport(OrcaLoadReport.getDefaultInstance()))).isTrue(); + } + + static final class MetricsReportMatcher implements ArgumentMatcher { + private MetricReport original; + + public MetricsReportMatcher(MetricReport report) { + this.original = report; + } + + @Override + public boolean matches(MetricReport argument) { + return reportEqual(original, argument); + } + } + + static boolean reportEqual(MetricReport a, + MetricReport b) { + return a.getCpuUtilization() == b.getCpuUtilization() + && a.getMemoryUtilization() == b.getMemoryUtilization() + && Objects.equal(a.getRequestCostMetrics(), b.getRequestCostMetrics()) + && Objects.equal(a.getUtilizationMetrics(), b.getUtilizationMetrics()); } /** @@ -118,7 +144,8 @@ public void twoLevelPoliciesTypicalWorkflow() { // Child factory will augment the StreamInfo and pass it to the parent factory. ClientStreamTracer childTracer = childFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); - ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor streamInfoCaptor = + ArgumentCaptor.forClass(ClientStreamTracer.StreamInfo.class); verify(parentFactory).newClientStreamTracer(streamInfoCaptor.capture(), any(Metadata.class)); ClientStreamTracer.StreamInfo parentStreamInfo = streamInfoCaptor.getValue(); assertThat(parentStreamInfo).isNotEqualTo(STREAM_INFO); @@ -136,11 +163,12 @@ public void twoLevelPoliciesTypicalWorkflow() { OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, OrcaLoadReport.getDefaultInstance()); childTracer.inboundTrailers(trailer); - ArgumentCaptor parentReportCap = ArgumentCaptor.forClass(null); - ArgumentCaptor childReportCap = ArgumentCaptor.forClass(null); + ArgumentCaptor parentReportCap = ArgumentCaptor.forClass(MetricReport.class); + ArgumentCaptor childReportCap = ArgumentCaptor.forClass(MetricReport.class); verify(orcaListener1).onLoadReport(parentReportCap.capture()); verify(orcaListener2).onLoadReport(childReportCap.capture()); - assertThat(parentReportCap.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance()); + assertThat(reportEqual(parentReportCap.getValue(), + OrcaPerRequestUtil.fromOrcaLoadReport(OrcaLoadReport.getDefaultInstance()))).isTrue(); assertThat(childReportCap.getValue()).isSameInstanceAs(parentReportCap.getValue()); } @@ -159,11 +187,12 @@ public void onlyParentPolicyReceivesReportsIfCreatesOwnTracer() { ClientStreamTracer parentTracer = parentFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); Metadata trailer = new Metadata(); + OrcaLoadReport report = OrcaLoadReport.getDefaultInstance(); trailer.put( - OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, - OrcaLoadReport.getDefaultInstance()); + OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, report); parentTracer.inboundTrailers(trailer); - verify(orcaListener1).onLoadReport(eq(OrcaLoadReport.getDefaultInstance())); + verify(orcaListener1).onLoadReport( + argThat(new MetricsReportMatcher(OrcaPerRequestUtil.fromOrcaLoadReport(report)))); verifyNoInteractions(childFactory); verifyNoInteractions(orcaListener2); } diff --git a/xds/src/test/java/io/grpc/xds/orca/OrcaServiceImplTest.java b/xds/src/test/java/io/grpc/xds/orca/OrcaServiceImplTest.java new file mode 100644 index 00000000000..a292df0f035 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaServiceImplTest.java @@ -0,0 +1,302 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import com.github.xds.data.orca.v3.OrcaLoadReport; +import com.github.xds.service.orca.v3.OpenRcaServiceGrpc; +import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Duration; +import io.grpc.BindableService; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.services.MetricRecorder; +import io.grpc.testing.GrpcCleanupRule; +import java.util.Iterator; +import java.util.Random; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class OrcaServiceImplTest { + @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + private ManagedChannel channel; + private Server oobServer; + private final FakeClock fakeClock = new FakeClock(); + private MetricRecorder defaultTestService; + private BindableService orcaServiceImpl; + private final Random random = new Random(); + @Mock + ClientCall.Listener listener; + + @Before + public void setup() throws Exception { + defaultTestService = MetricRecorder.newInstance(); + orcaServiceImpl = OrcaServiceImpl.createService(fakeClock.getScheduledExecutorService(), + defaultTestService, 1, TimeUnit.SECONDS); + startServerAndGetChannel(orcaServiceImpl); + } + + @After + public void teardown() throws Exception { + channel.shutdownNow(); + } + + private void startServerAndGetChannel(BindableService orcaService) throws Exception { + oobServer = grpcCleanup.register( + InProcessServerBuilder.forName("orca-service-test") + .addService(orcaService) + .directExecutor() + .build() + .start()); + channel = grpcCleanup.register( + InProcessChannelBuilder.forName("orca-service-test") + .directExecutor().build()); + } + + @Test + public void testReportingLifeCycle() { + defaultTestService.setCpuUtilizationMetric(0.1); + Iterator reports = OpenRcaServiceGrpc.newBlockingStub(channel) + .streamCoreMetrics(OrcaLoadReportRequest.newBuilder().build()); + assertThat(reports.next()).isEqualTo( + OrcaLoadReport.newBuilder().setCpuUtilization(0.1).build()); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(1); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(fakeClock.forwardTime(1, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(reports.next()).isEqualTo( + OrcaLoadReport.newBuilder().setCpuUtilization(0.1).build()); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + channel.shutdownNow(); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(0); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(0); + } + + @Test + @SuppressWarnings("unchecked") + public void testReportingLifeCycle_serverShutdown() { + ClientCall call = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + defaultTestService.putUtilizationMetric("buffer", 0.2); + call.start(listener, new Metadata()); + call.sendMessage(OrcaLoadReportRequest.newBuilder() + .setReportInterval(Duration.newBuilder().setSeconds(0).setNanos(500).build()).build()); + call.halfClose(); + call.request(1); + OrcaLoadReport expect = OrcaLoadReport.newBuilder().putUtilization("buffer", 0.2).build(); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(1); + verify(listener).onMessage(eq(expect)); + reset(listener); + oobServer.shutdownNow(); + assertThat(fakeClock.forwardTime(1, TimeUnit.SECONDS)).isEqualTo(0); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(0); + ArgumentCaptor callCloseCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener).onClose(callCloseCaptor.capture(), any()); + assertThat(callCloseCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + } + + @Test + @SuppressWarnings("unchecked") + public void testRequestIntervalLess() { + ClientCall call = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + defaultTestService.putUtilizationMetric("buffer", 0.2); + call.start(listener, new Metadata()); + call.sendMessage(OrcaLoadReportRequest.newBuilder() + .setReportInterval(Duration.newBuilder().setSeconds(0).setNanos(500).build()).build()); + call.halfClose(); + call.request(1); + OrcaLoadReport expect = OrcaLoadReport.newBuilder().putUtilization("buffer", 0.2).build(); + verify(listener).onMessage(eq(expect)); + reset(listener); + defaultTestService.removeUtilizationMetric("buffer0"); + assertThat(fakeClock.forwardTime(500, TimeUnit.NANOSECONDS)).isEqualTo(0); + verifyNoInteractions(listener); + assertThat(fakeClock.forwardTime(1, TimeUnit.SECONDS)).isEqualTo(1); + call.request(1); + verify(listener).onMessage(eq(expect)); + } + + @Test + @SuppressWarnings("unchecked") + public void testRequestIntervalGreater() { + ClientCall call = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + defaultTestService.putUtilizationMetric("buffer", 0.2); + call.start(listener, new Metadata()); + call.sendMessage(OrcaLoadReportRequest.newBuilder() + .setReportInterval(Duration.newBuilder().setSeconds(10).build()).build()); + call.halfClose(); + call.request(1); + OrcaLoadReport expect = OrcaLoadReport.newBuilder().putUtilization("buffer", 0.2).build(); + verify(listener).onMessage(eq(expect)); + reset(listener); + defaultTestService.removeUtilizationMetric("buffer0"); + assertThat(fakeClock.forwardTime(1, TimeUnit.SECONDS)).isEqualTo(0); + verifyNoInteractions(listener); + assertThat(fakeClock.forwardTime(9, TimeUnit.SECONDS)).isEqualTo(1); + call.request(1); + verify(listener).onMessage(eq(expect)); + } + + @Test + @SuppressWarnings("unchecked") + public void testRequestIntervalDefault() throws Exception { + defaultTestService = MetricRecorder.newInstance(); + oobServer.shutdownNow(); + startServerAndGetChannel(OrcaServiceImpl.createService( + fakeClock.getScheduledExecutorService(), defaultTestService)); + ClientCall call = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + defaultTestService.putUtilizationMetric("buffer", 0.2); + call.start(listener, new Metadata()); + call.sendMessage(OrcaLoadReportRequest.newBuilder() + .setReportInterval(Duration.newBuilder().setSeconds(10).build()).build()); + call.halfClose(); + call.request(1); + OrcaLoadReport expect = OrcaLoadReport.newBuilder().putUtilization("buffer", 0.2).build(); + verify(listener).onMessage(eq(expect)); + reset(listener); + defaultTestService.removeUtilizationMetric("buffer0"); + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(0); + verifyNoInteractions(listener); + assertThat(fakeClock.forwardTime(20, TimeUnit.SECONDS)).isEqualTo(1); + call.request(1); + verify(listener).onMessage(eq(expect)); + } + + @Test + public void testMultipleClients() { + ClientCall call = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + defaultTestService.putUtilizationMetric("omg", 100); + call.start(listener, new Metadata()); + call.sendMessage(OrcaLoadReportRequest.newBuilder().build()); + call.halfClose(); + call.request(1); + OrcaLoadReport expect = OrcaLoadReport.newBuilder().putUtilization("omg", 100).build(); + verify(listener).onMessage(eq(expect)); + defaultTestService.setMemoryUtilizationMetric(0.5); + ClientCall call2 = channel.newCall( + OpenRcaServiceGrpc.getStreamCoreMetricsMethod(), CallOptions.DEFAULT); + call2.start(listener, new Metadata()); + call2.sendMessage(OrcaLoadReportRequest.newBuilder().build()); + call2.halfClose(); + call2.request(1); + expect = OrcaLoadReport.newBuilder(expect).setMemUtilization(0.5).build(); + verify(listener).onMessage(eq(expect)); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(2); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(2); + channel.shutdownNow(); + assertThat(fakeClock.forwardTime(1, TimeUnit.SECONDS)).isEqualTo(0); + assertThat(((OrcaServiceImpl)orcaServiceImpl).clientCount.get()).isEqualTo(0); + ArgumentCaptor callCloseCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener, times(2)).onClose(callCloseCaptor.capture(), any()); + assertThat(callCloseCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + } + + @Test + public void testApis() throws Exception { + ImmutableMap firstUtilization = ImmutableMap.of("util", 0.1); + OrcaLoadReport goldenReport = OrcaLoadReport.newBuilder() + .setCpuUtilization(random.nextDouble()) + .setMemUtilization(random.nextDouble()) + .putAllUtilization(firstUtilization) + .putUtilization("queue", 1.0) + .build(); + defaultTestService.setCpuUtilizationMetric(goldenReport.getCpuUtilization()); + defaultTestService.setMemoryUtilizationMetric(goldenReport.getMemUtilization()); + defaultTestService.setAllUtilizationMetrics(firstUtilization); + defaultTestService.putUtilizationMetric("queue", 1.0); + Iterator reports = OpenRcaServiceGrpc.newBlockingStub(channel) + .streamCoreMetrics(OrcaLoadReportRequest.newBuilder().build()); + assertThat(reports.next()).isEqualTo(goldenReport); + + defaultTestService.clearCpuUtilizationMetric(); + defaultTestService.clearMemoryUtilizationMetric(); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + goldenReport = OrcaLoadReport.newBuilder() + .putAllUtilization(firstUtilization) + .putUtilization("queue", 1.0) + .putUtilization("util", 0.1) + .build(); + assertThat(reports.next()).isEqualTo(goldenReport); + defaultTestService.removeUtilizationMetric("util-not-exist"); + defaultTestService.removeUtilizationMetric("queue-not-exist"); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + assertThat(reports.next()).isEqualTo(goldenReport); + + CyclicBarrier barrier = new CyclicBarrier(2); + new Thread(new Runnable() { + @Override + public void run() { + try { + barrier.await(); + } catch (Exception ex) { + throw new AssertionError(ex); + } + defaultTestService.removeUtilizationMetric("util"); + defaultTestService.setMemoryUtilizationMetric(0.4); + defaultTestService.setAllUtilizationMetrics(firstUtilization); + try { + barrier.await(); + } catch (Exception ex) { + throw new AssertionError(ex); + } + } + }).start(); + barrier.await(); + defaultTestService.setMemoryUtilizationMetric(0.4); + defaultTestService.removeUtilizationMetric("util"); + defaultTestService.setAllUtilizationMetrics(firstUtilization); + barrier.await(); + goldenReport = OrcaLoadReport.newBuilder() + .putAllUtilization(firstUtilization) + .setMemUtilization(0.4) + .build(); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + assertThat(reports.next()).isEqualTo(goldenReport); + } +} diff --git a/xds/third_party/envoy/LICENSE b/xds/third_party/envoy/LICENSE index 1e2bdc6ae7b..d6456956733 100644 --- a/xds/third_party/envoy/LICENSE +++ b/xds/third_party/envoy/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner]. + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xds/third_party/envoy/NOTICE b/xds/third_party/envoy/NOTICE index 9a9b287cbe8..8604a8bbd60 100644 --- a/xds/third_party/envoy/NOTICE +++ b/xds/third_party/envoy/NOTICE @@ -1,4 +1,4 @@ Envoy -Copyright 2016-2019 Envoy Project Authors +Copyright The Envoy Project Authors Licensed under Apache License 2.0. See LICENSE for terms. diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index c77ee9272e0..13813fb9f38 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -17,8 +17,8 @@ set -e BRANCH=main -# import VERSION from one of the google internal CLs -VERSION=c223756b0856f734a6a5cff2d0b95388cd2583d4 +# import VERSION from the google internal copybara_version.txt for Envoy +VERSION=2f99e0c9f83b6c91b42d215a148ed49ce0f174fd GIT_REPO="https://github.com/envoyproxy/envoy.git" GIT_BASE_DIR=envoy SOURCE_PROTO_BASE_DIR=envoy/api @@ -129,6 +129,10 @@ envoy/extensions/filters/http/fault/v3/fault.proto envoy/extensions/filters/http/rbac/v3/rbac.proto envoy/extensions/filters/http/router/v3/router.proto envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto +envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto +envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto +envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto envoy/extensions/transport_sockets/tls/v3/cert.proto envoy/extensions/transport_sockets/tls/v3/common.proto envoy/extensions/transport_sockets/tls/v3/secret.proto diff --git a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump.proto b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump.proto index ddafb56b393..336d5b13eec 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump.proto @@ -13,6 +13,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.admin.v3"; option java_outer_classname = "ConfigDumpProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/admin/v3;adminv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: ConfigDump] diff --git a/xds/third_party/envoy/src/main/proto/envoy/annotations/deprecation.proto b/xds/third_party/envoy/src/main/proto/envoy/annotations/deprecation.proto index ce02ab98a8d..c9a96f1ae27 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/annotations/deprecation.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/annotations/deprecation.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package envoy.annotations; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/annotations"; import "google/protobuf/descriptor.proto"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/annotations/resource.proto b/xds/third_party/envoy/src/main/proto/envoy/annotations/resource.proto index bd794c68dd7..3877afc7fe3 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/annotations/resource.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/annotations/resource.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package envoy.annotations; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/annotations"; import "google/protobuf/descriptor.proto"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/cert.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/cert.proto index 6a9cbddd250..81e2672d9b6 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/cert.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/cert.proto @@ -11,5 +11,6 @@ import public "envoy/api/v2/auth/tls.proto"; option java_package = "io.envoyproxy.envoy.api.v2.auth"; option java_outer_classname = "CertProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.transport_sockets.tls.v3"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/common.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/common.proto index c8122f40102..cd55ccd4dbd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/common.proto @@ -17,6 +17,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.auth"; option java_outer_classname = "CommonProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.transport_sockets.tls.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; @@ -41,8 +42,7 @@ message TlsParameters { TLSv1_3 = 4; } - // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for clients and ``TLSv1_0`` for - // servers. + // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for both clients and servers. TlsProtocol tls_minimum_protocol_version = 1 [(validate.rules).enum = {defined_only: true}]; // Maximum TLS protocol version. By default, it's ``TLSv1_2`` for clients and ``TLSv1_3`` for diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/secret.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/secret.proto index 3a6d8cf7dcb..4a4ab3bf169 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/secret.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/secret.proto @@ -13,6 +13,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.api.v2.auth"; option java_outer_classname = "SecretProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.transport_sockets.tls.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/tls.proto index 201973a2b9d..911ada77d3f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/auth/tls.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.auth"; option java_outer_classname = "TlsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.transport_sockets.tls.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; @@ -81,10 +82,9 @@ message DownstreamTlsContext { bool disable_stateless_session_resumption = 7; } - // If specified, session_timeout will change maximum lifetime (in seconds) of TLS session - // Currently this value is used as a hint to `TLS session ticket lifetime (for TLSv1.2) - // ` - // only seconds could be specified (fractional seconds are going to be ignored). + // If specified, ``session_timeout`` will change the maximum lifetime (in seconds) of the TLS session. + // Currently this value is used as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. + // Only seconds can be specified (fractional seconds are ignored). google.protobuf.Duration session_timeout = 6 [(validate.rules).duration = { lt {seconds: 4294967296} gte {} diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cds.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cds.proto index 0b657a0fa45..38f7c3c19ec 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cds.proto @@ -15,6 +15,7 @@ import public "envoy/api/v2/cluster.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "CdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.cluster.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster.proto index fab95f71b76..b1b6751de4b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster.proto @@ -27,6 +27,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "ClusterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.cluster.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/circuit_breaker.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/circuit_breaker.proto index 510619b2642..c45409bcc73 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/circuit_breaker.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/circuit_breaker.proto @@ -14,8 +14,9 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.cluster"; option java_outer_classname = "CircuitBreakerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/cluster"; option csharp_namespace = "Envoy.Api.V2.ClusterNS"; -option ruby_package = "Envoy.Api.V2.ClusterNS"; +option ruby_package = "Envoy::Api::V2::ClusterNS"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.cluster.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/filter.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/filter.proto index b87ad79d8f3..1609be4ca25 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/filter.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/filter.proto @@ -11,8 +11,9 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.cluster"; option java_outer_classname = "FilterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/cluster"; option csharp_namespace = "Envoy.Api.V2.ClusterNS"; -option ruby_package = "Envoy.Api.V2.ClusterNS"; +option ruby_package = "Envoy::Api::V2::ClusterNS"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.cluster.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/outlier_detection.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/outlier_detection.proto index 6cf35e41ff1..ec8c6ee7311 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/outlier_detection.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/cluster/outlier_detection.proto @@ -12,8 +12,9 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.cluster"; option java_outer_classname = "OutlierDetectionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/cluster"; option csharp_namespace = "Envoy.Api.V2.ClusterNS"; -option ruby_package = "Envoy.Api.V2.ClusterNS"; +option ruby_package = "Envoy::Api::V2::ClusterNS"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.cluster.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/address.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/address.proto index fdcb4e7d94f..3399538be10 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/address.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/address.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "AddressProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/backoff.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/backoff.proto index e45c71e39be..845dfce39e0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/backoff.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/backoff.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "BackoffProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/base.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/base.proto index 32cd90b4ee1..94b346bc3e8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/base.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/base.proto @@ -21,6 +21,7 @@ import public "envoy/api/v2/core/socket_option.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "BaseProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/config_source.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/config_source.proto index 6cf44dbe9bb..b3b400ae64c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/config_source.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/config_source.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "ConfigSourceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/event_service_config.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/event_service_config.proto index f822f8c6b63..12ec25d4d41 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/event_service_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/event_service_config.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "EventServiceConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/grpc_service.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/grpc_service.proto index dd789644e1d..faafb7f0f7f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/grpc_service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/grpc_service.proto @@ -17,6 +17,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "GrpcServiceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/health_check.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/health_check.proto index bc4ae3e5c86..347ac9c96b9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/health_check.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/health_check.proto @@ -21,6 +21,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "HealthCheckProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/http_uri.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/http_uri.proto index cd1a0660e33..cb95125b90c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/http_uri.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/http_uri.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "HttpUriProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/protocol.proto index ae1a86424cf..3b7fe358964 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/protocol.proto @@ -12,6 +12,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "ProtocolProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/socket_option.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/socket_option.proto index 39678ad1b8b..da8140596dd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/socket_option.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/core/socket_option.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.core"; option java_outer_classname = "SocketOptionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.core.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/discovery.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/discovery.proto index da2690f867f..fc5370688a7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/discovery.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/discovery.proto @@ -13,6 +13,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "DiscoveryProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.discovery.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/eds.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/eds.proto index d757f17fc2f..4bd92355551 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/eds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/eds.proto @@ -15,6 +15,7 @@ import public "envoy/api/v2/endpoint.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "EdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.endpoint.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint.proto index 70bac3c6c4f..13e90521b63 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "EndpointProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.endpoint.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint.proto index 247c9ae265a..2c2e9daa5c0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint.proto @@ -7,3 +7,4 @@ import public "envoy/api/v2/endpoint/endpoint_components.proto"; option java_package = "io.envoyproxy.envoy.api.v2.endpoint"; option java_outer_classname = "EndpointProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint_components.proto index 78d45e2e08d..86a533bf0e6 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/endpoint_components.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.endpoint"; option java_outer_classname = "EndpointComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.endpoint.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/load_report.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/load_report.proto index 928aed6102d..09dda612e4b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/load_report.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/endpoint/load_report.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.endpoint"; option java_outer_classname = "LoadReportProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.endpoint.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/lds.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/lds.proto index 01d9949777d..9c66e5426da 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/lds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/lds.proto @@ -15,6 +15,7 @@ import public "envoy/api/v2/listener.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "LdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.listener.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener.proto index 1fdd202de42..139816dc286 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener.proto @@ -20,6 +20,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "ListenerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.listener.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener.proto index 273b29cb5dd..d007ba51c1f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener.proto @@ -7,5 +7,6 @@ import public "envoy/api/v2/listener/listener_components.proto"; option java_package = "io.envoyproxy.envoy.api.v2.listener"; option java_outer_classname = "ListenerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/listener"; option csharp_namespace = "Envoy.Api.V2.ListenerNS"; -option ruby_package = "Envoy.Api.V2.ListenerNS"; +option ruby_package = "Envoy::Api::V2::ListenerNS"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto index 08738962c5e..4ebae87f5db 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto @@ -18,8 +18,9 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.listener"; option java_outer_classname = "ListenerComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/listener"; option csharp_namespace = "Envoy.Api.V2.ListenerNS"; -option ruby_package = "Envoy.Api.V2.ListenerNS"; +option ruby_package = "Envoy::Api::V2::ListenerNS"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.listener.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/udp_listener_config.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/udp_listener_config.proto index d4d29531f3a..d1642ab4213 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/udp_listener_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/udp_listener_config.proto @@ -11,8 +11,9 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.api.v2.listener"; option java_outer_classname = "UdpListenerConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/listener"; option csharp_namespace = "Envoy.Api.V2.ListenerNS"; -option ruby_package = "Envoy.Api.V2.ListenerNS"; +option ruby_package = "Envoy::Api::V2::ListenerNS"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.listener.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/rds.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/rds.proto index faa5fdcf319..2ac30541aef 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/rds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/rds.proto @@ -15,6 +15,7 @@ import public "envoy/api/v2/route.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "RdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.route.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route.proto index 549f134a7f4..4f9e40a4409 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "RouteProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.route.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route.proto index ec13e9e5c80..0c52d051dd0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route.proto @@ -7,3 +7,4 @@ import public "envoy/api/v2/route/route_components.proto"; option java_package = "io.envoyproxy.envoy.api.v2.route"; option java_outer_classname = "RouteProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/route"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route_components.proto index d73fbb8674c..062e73231d6 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/route/route_components.proto @@ -22,6 +22,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2.route"; option java_outer_classname = "RouteComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2/route"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.route.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; @@ -675,8 +676,8 @@ message RouteAction { message FilterState { // The name of the Object in the per-request filterState, which is an - // Envoy::Http::Hashable object. If there is no data associated with the key, - // or the stored object is not Envoy::Http::Hashable, no hash will be produced. + // Envoy::Hashable object. If there is no data associated with the key, + // or the stored object is not Envoy::Hashable, no hash will be produced. string key = 1 [(validate.rules).string = {min_bytes: 1}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/scoped_route.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/scoped_route.proto index 0841bd08723..f3902d9d9e7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/scoped_route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/scoped_route.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "ScopedRouteProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.route.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/srds.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/srds.proto index 0edb99a1ecc..4f0ecab7657 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/srds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/srds.proto @@ -15,6 +15,7 @@ import public "envoy/api/v2/scoped_route.proto"; option java_package = "io.envoyproxy.envoy.api.v2"; option java_outer_classname = "SrdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/api/v2;apiv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.route.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto index bb53286380c..b851949692f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto @@ -17,6 +17,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.accesslog.v3"; option java_outer_classname = "AccesslogProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3;accesslogv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common access log types] @@ -29,9 +30,7 @@ message AccessLog { reserved "config"; - // The name of the access log extension to instantiate. - // The name must match one of the compiled in loggers. - // See the :ref:`extensions listed in typed_config below ` for the default list of available loggers. + // The name of the access log extension configuration. string name = 1; // Filter which is used to determine if the access log needs to be written. @@ -83,6 +82,7 @@ message AccessLogFilter { GrpcStatusFilter grpc_status_filter = 10; // Extension filter. + // [#extension-category: envoy.access_loggers.extension_filters] ExtensionFilter extension_filter = 11; // Metadata Filter @@ -110,7 +110,7 @@ message ComparisonFilter { Op op = 1 [(validate.rules).enum = {defined_only: true}]; // Value to compare against. - core.v3.RuntimeUInt32 value = 2; + core.v3.RuntimeUInt32 value = 2 [(validate.rules).message = {required: true}]; } // Filters on HTTP response/status code. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto index 0e8de366333..bde4d5c3965 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto @@ -32,6 +32,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.bootstrap.v3"; option java_outer_classname = "BootstrapProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/bootstrap/v3;bootstrapv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Bootstrap] @@ -40,7 +41,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // ` for more detail. // Bootstrap :ref:`configuration overview `. -// [#next-free-field: 33] +// [#next-free-field: 34] message Bootstrap { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Bootstrap"; @@ -248,9 +249,6 @@ message Bootstrap { // when :ref:`dns_resolvers ` and // :ref:`use_tcp_for_dns_lookups ` are // specified. - // Setting this value causes failure if the - // ``envoy.restart_features.use_apple_api_for_dns_lookups`` runtime value is true during - // server startup. Apple' API only uses UDP for DNS resolution. // This field is deprecated in favor of *dns_resolution_config* // which aggregates all of the DNS resolver configuration in a single message. bool use_tcp_for_dns_lookups = 20 @@ -260,23 +258,22 @@ message Bootstrap { // This may be overridden on a per-cluster basis in cds_config, when // :ref:`dns_resolution_config ` // is specified. - // *dns_resolution_config* will be deprecated once - // :ref:'typed_dns_resolver_config ' - // is fully supported. - core.v3.DnsResolutionConfig dns_resolution_config = 30; + // This field is deprecated in favor of + // :ref:`typed_dns_resolver_config `. + core.v3.DnsResolutionConfig dns_resolution_config = 30 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, // or any other DNS resolver types and the related parameters. - // For example, an object of :ref:`DnsResolutionConfig ` - // can be packed into this *typed_dns_resolver_config*. This configuration will replace the - // :ref:'dns_resolution_config ' - // configuration eventually. - // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // For example, an object of + // :ref:`CaresDnsResolverConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration replaces the + // :ref:`dns_resolution_config ` + // configuration. // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, - // this configuration is optional. - // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // when *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. // When *typed_dns_resolver_config* is missing, the default behavior is in place. - // [#not-implemented-hide:] + // [#extension-category: envoy.network.dns_resolver] core.v3.TypedExtensionConfig typed_dns_resolver_config = 31; // Specifies optional bootstrap extensions to be instantiated at startup time. @@ -329,11 +326,15 @@ message Bootstrap { // // Note that the 'set-cookie' header cannot be registered as inline header. repeated CustomInlineHeader inline_headers = 32; + + // Optional path to a file with performance tracing data created by "Perfetto" SDK in binary + // ProtoBuf format. The default value is "envoy.pftrace". + string perf_tracing_file_path = 33; } // Administration interface :ref:`operations documentation // `. -// [#next-free-field: 6] +// [#next-free-field: 7] message Admin { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Admin"; @@ -359,6 +360,10 @@ message Admin { // Additional socket options that may not be present in Envoy source code or // precompiled binaries. repeated core.v3.SocketOption socket_options = 4; + + // Indicates whether :ref:`global_downstream_max_connections ` + // should apply to the admin interface or not. + bool ignore_global_conn_limit = 6; } // Cluster manager :ref:`architecture overview `. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/aggregate/v2alpha/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/aggregate/v2alpha/cluster.proto index a0fdadd7572..3a6506eb8dc 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/aggregate/v2alpha/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/aggregate/v2alpha/cluster.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.cluster.aggregate.v2alpha"; option java_outer_classname = "ClusterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/cluster/aggregate/v2alpha"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.clusters.aggregate.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/circuit_breaker.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/circuit_breaker.proto index 82cd329b91a..fe798ceb090 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/circuit_breaker.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/circuit_breaker.proto @@ -14,6 +14,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.cluster.v3"; option java_outer_classname = "CircuitBreakerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3;clusterv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Circuit breakers] @@ -59,10 +60,12 @@ message CircuitBreakers { // The maximum number of pending requests that Envoy will allow to the // upstream cluster. If not specified, the default is 1024. + // This limit is applied as a connection limit for non-HTTP traffic. google.protobuf.UInt32Value max_pending_requests = 3; // The maximum number of parallel requests that Envoy will make to the // upstream cluster. If not specified, the default is 1024. + // This limit does not apply to non-HTTP traffic. google.protobuf.UInt32Value max_requests = 4; // The maximum number of parallel retries that Envoy will allow to the @@ -102,4 +105,17 @@ message CircuitBreakers { // :ref:`RoutingPriority`, the default values // are used. repeated Thresholds thresholds = 1; + + // Optional per-host limits which apply to each individual host in a cluster. + // + // .. note:: + // currently only the :ref:`max_connections + // ` field is supported for per-host limits. + // + // If multiple per-host :ref:`Thresholds` + // are defined with the same :ref:`RoutingPriority`, + // the first one in the list is used. If no per-host Thresholds are defined for a given + // :ref:`RoutingPriority`, + // the cluster will not have per-host limits. + repeated Thresholds per_host_thresholds = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto index d6213d6fe94..84bab4673b4 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto @@ -32,6 +32,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.cluster.v3"; option java_outer_classname = "ClusterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3;clusterv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Cluster configuration] @@ -43,7 +44,7 @@ message ClusterCollection { } // Configuration for a single upstream cluster. -// [#next-free-field: 56] +// [#next-free-field: 57] message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster"; @@ -112,9 +113,9 @@ message Cluster { // Use the new :ref:`load_balancing_policy // ` field to determine the LB policy. - // [#next-major-version: In the v3 API, we should consider deprecating the lb_policy field - // and instead using the new load_balancing_policy field as the one and only mechanism for - // configuring this.] + // This has been deprecated in favor of using the :ref:`load_balancing_policy + // ` field without + // setting any value in :ref:`lb_policy`. LOAD_BALANCING_POLICY_CONFIG = 7; } @@ -123,15 +124,28 @@ message Cluster { // only perform a lookup for addresses in the IPv6 family. If AUTO is // specified, the DNS resolver will first perform a lookup for addresses in // the IPv6 family and fallback to a lookup for addresses in the IPv4 family. + // This is semantically equivalent to a non-existent V6_PREFERRED option. + // AUTO is a legacy name that is more opaque than + // necessary and will be deprecated in favor of V6_PREFERRED in a future major version of the API. + // If V4_PREFERRED is specified, the DNS resolver will first perform a lookup for addresses in the + // IPv4 family and fallback to a lookup for addresses in the IPv6 family. i.e., the callback + // target will only get v6 addresses if there were NO v4 addresses to return. + // If ALL is specified, the DNS resolver will perform a lookup for both IPv4 and IPv6 families, + // and return all resolved addresses. When this is used, Happy Eyeballs will be enabled for + // upstream connections. Refer to :ref:`Happy Eyeballs Support ` + // for more information. // For cluster types other than // :ref:`STRICT_DNS` and // :ref:`LOGICAL_DNS`, // this setting is // ignored. + // [#next-major-version: deprecate AUTO in favor of a V6_PREFERRED option.] enum DnsLookupFamily { AUTO = 0; V4_ONLY = 1; V6_ONLY = 2; + V4_PREFERRED = 3; + ALL = 4; } enum ClusterProtocolSelection { @@ -337,6 +351,40 @@ message Cluster { bool list_as_any = 7; } + // Configuration for :ref:`slow start mode `. + message SlowStartConfig { + // Represents the size of slow start window. + // If set, the newly created host remains in slow start mode starting from its creation time + // for the duration of slow start window. + google.protobuf.Duration slow_start_window = 1; + + // This parameter controls the speed of traffic increase over the slow start window. Defaults to 1.0, + // so that endpoint would get linearly increasing amount of traffic. + // When increasing the value for this parameter, the speed of traffic ramp-up increases non-linearly. + // The value of aggression parameter should be greater than 0.0. + // By tuning the parameter, is possible to achieve polynomial or exponential shape of ramp-up curve. + // + // During slow start window, effective weight of an endpoint would be scaled with time factor and aggression: + // `new_weight = weight * max(min_weight_percent, time_factor ^ (1 / aggression))`, + // where `time_factor=(time_since_start_seconds / slow_start_time_seconds)`. + // + // As time progresses, more and more traffic would be sent to endpoint, which is in slow start window. + // Once host exits slow start, time_factor and aggression no longer affect its weight. + core.v3.RuntimeDouble aggression = 2; + + // Configures the minimum percentage of origin weight that avoids too small new weight, + // which may cause endpoints in slow start mode receive no traffic in slow start window. + // If not specified, the default is 10%. + type.v3.Percent min_weight_percent = 3; + } + + // Specific configuration for the RoundRobin load balancing policy. + message RoundRobinLbConfig { + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + SlowStartConfig slow_start_config = 1; + } + // Specific configuration for the LeastRequest load balancing policy. message LeastRequestLbConfig { option (udpa.annotations.versioning).previous_message_type = @@ -370,6 +418,10 @@ message Cluster { // .. note:: // This setting only takes effect if all host weights are not equal. core.v3.RuntimeDouble active_request_bias = 2; + + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + SlowStartConfig slow_start_config = 3; } // Specific configuration for the :ref:`RingHash` @@ -424,9 +476,8 @@ message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster.OriginalDstLbConfig"; - // When true, :ref:`x-envoy-original-dst-host - // ` can be used to override destination - // address. + // When true, a HTTP header can be used to override the original dst address. The default header is + // :ref:`x-envoy-original-dst-host `. // // .. attention:: // @@ -438,10 +489,14 @@ message Cluster { // // If the header appears multiple times only the first value is used. bool use_http_header = 1; + + // The http header to override destination address if :ref:`use_http_header `. + // is set to true. If the value is empty, :ref:`x-envoy-original-dst-host ` will be used. + string http_header_name = 2; } // Common configuration for all load balancer implementations. - // [#next-free-field: 8] + // [#next-free-field: 9] message CommonLbConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster.CommonLbConfig"; @@ -550,6 +605,14 @@ message Cluster { // Common Configuration for all consistent hashing load balancers (MaglevLb, RingHashLb, etc.) ConsistentHashingLbConfig consistent_hashing_lb_config = 7; + + // This controls what hosts are considered valid when using + // :ref:`host overrides `, which is used by some + // filters to modify the load balancing decision. + // + // If this is unset then [UNKNOWN, HEALTHY, DEGRADED] will be applied by default. If this is + // set with an empty set of statuses then host overrides will be ignored by the load balancing. + core.v3.HealthStatusSet override_host_status = 8; } message RefreshRate { @@ -690,11 +753,9 @@ message Cluster { // emitting stats for the cluster and access logging the cluster name. This will appear as // additional information in configuration dumps of a cluster's current status as // :ref:`observability_name ` - // and as an additional tag "upstream_cluster.name" while tracing. Note: access logging using - // this field is presently enabled with runtime feature - // `envoy.reloadable_features.use_observable_cluster_name`. Any ``:`` in the name will be - // converted to ``_`` when emitting statistics. This should not be confused with :ref:`Router - // Filter Header `. + // and as an additional tag "upstream_cluster.name" while tracing. Note: Any ``:`` in the name + // will be converted to ``_`` when emitting statistics. This should not be confused with + // :ref:`Router Filter Header `. string alt_stat_name = 28 [(udpa.annotations.field_migrate).rename = "observability_name"]; oneof cluster_discovery_type { @@ -859,41 +920,34 @@ message Cluster { // :ref:`STRICT_DNS` // and :ref:`LOGICAL_DNS` // this setting is ignored. - // Setting this value causes failure if the - // ``envoy.restart_features.use_apple_api_for_dns_lookups`` runtime value is true during - // server startup. Apple's API only allows overriding DNS resolvers via system settings. // This field is deprecated in favor of *dns_resolution_config* // which aggregates all of the DNS resolver configuration in a single message. repeated core.v3.Address dns_resolvers = 18 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Always use TCP queries instead of UDP queries for DNS lookups. - // Setting this value causes failure if the - // ``envoy.restart_features.use_apple_api_for_dns_lookups`` runtime value is true during - // server startup. Apple' API only uses UDP for DNS resolution. // This field is deprecated in favor of *dns_resolution_config* // which aggregates all of the DNS resolver configuration in a single message. bool use_tcp_for_dns_lookups = 45 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // DNS resolution configuration which includes the underlying dns resolver addresses and options. - // *dns_resolution_config* will be deprecated once - // :ref:'typed_dns_resolver_config ' - // is fully supported. - core.v3.DnsResolutionConfig dns_resolution_config = 53; + // This field is deprecated in favor of + // :ref:`typed_dns_resolver_config `. + core.v3.DnsResolutionConfig dns_resolution_config = 53 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, // or any other DNS resolver types and the related parameters. - // For example, an object of :ref:`DnsResolutionConfig ` - // can be packed into this *typed_dns_resolver_config*. This configuration will replace the - // :ref:'dns_resolution_config ' - // configuration eventually. - // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // For example, an object of + // :ref:`CaresDnsResolverConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration replaces the + // :ref:`dns_resolution_config ` + // configuration. // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, - // this configuration is optional. - // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // when *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. // When *typed_dns_resolver_config* is missing, the default behavior is in place. - // [#not-implemented-hide:] + // [#extension-category: envoy.network.dns_resolver] core.v3.TypedExtensionConfig typed_dns_resolver_config = 55; // Optional configuration for having cluster readiness block on warm-up. Currently, only applicable for @@ -951,6 +1005,9 @@ message Cluster { // Optional configuration for the LeastRequest load balancing policy. LeastRequestLbConfig least_request_lb_config = 37; + + // Optional configuration for the RoundRobin load balancing policy. + RoundRobinLbConfig round_robin_lb_config = 56; } // Common configuration for all load balancer implementations. @@ -1007,9 +1064,8 @@ message Cluster { // servers of this cluster. repeated Filter filters = 40; - // New mechanism for LB policy configuration. Used only if the - // :ref:`lb_policy` field has the value - // :ref:`LOAD_BALANCING_POLICY_CONFIG`. + // If this field is set and is supported by the client, it will supersede the value of + // :ref:`lb_policy`. LoadBalancingPolicy load_balancing_policy = 41; // [#not-implemented-hide:] @@ -1126,6 +1182,11 @@ message UpstreamConnectionOptions { // If set then set SO_KEEPALIVE on the socket to enable TCP Keepalives. core.v3.TcpKeepalive tcp_keepalive = 1; + + // If enabled, associates the interface name of the local address with the upstream connection. + // This can be used by extensions during processing of requests. The association mechanism is + // implementation specific. Defaults to false due to performance concerns. + bool set_local_interface_name_on_upstream_connections = 2; } message TrackClusterStats { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/filter.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/filter.proto index 7d11b87bcd5..c6b8722b923 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/filter.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/filter.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.cluster.v3"; option java_outer_classname = "FilterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3;clusterv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Upstream filters] @@ -19,12 +20,12 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message Filter { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.cluster.Filter"; - // The name of the filter to instantiate. The name must match a - // supported upstream filter. Note that Envoy's :ref:`downstream network - // filters ` are not valid upstream filters. + // The name of the filter configuration. string name = 1 [(validate.rules).string = {min_len: 1}]; // Filter specific configuration which depends on the filter being // instantiated. See the supported filters for further documentation. + // Note that Envoy's :ref:`downstream network + // filters ` are not valid upstream filters. google.protobuf.Any typed_config = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto index b19e95db99b..85438863143 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto @@ -12,13 +12,14 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.cluster.v3"; option java_outer_classname = "OutlierDetectionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3;clusterv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Outlier detection] // See the :ref:`architecture overview ` for // more information on outlier detection. -// [#next-free-field: 22] +// [#next-free-field: 23] message OutlierDetection { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.cluster.OutlierDetection"; @@ -154,4 +155,10 @@ message OutlierDetection { // for more information. If not specified, the default value (300000ms or 300s) or // :ref:`base_ejection_time` value is applied, whatever is larger. google.protobuf.Duration max_ejection_time = 21 [(validate.rules).duration = {gt {}}]; + + // The maximum amount of jitter to add to the ejection time, in order to prevent + // a 'thundering herd' effect where all proxies try to reconnect to host at the same time. + // See :ref:`max_ejection_time_jitter` + // Defaults to 0s. + google.protobuf.Duration max_ejection_time_jitter = 22; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto index 06876d5f8e4..3f1b6fe3dc8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "AddressProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Network addresses] @@ -30,9 +31,9 @@ message Pipe { uint32 mode = 2 [(validate.rules).uint32 = {lte: 511}]; } -// [#not-implemented-hide:] The address represents an envoy internal listener. -// TODO(lambdai): Make this address available for listener and endpoint. -// TODO(asraa): When address available, remove workaround from test/server/server_fuzz_test.cc:30. +// The address represents an envoy internal listener. +// [#comment: TODO(lambdai): Make this address available for listener and endpoint. +// TODO(asraa): When address available, remove workaround from test/server/server_fuzz_test.cc:30.] message EnvoyInternalAddress { oneof address_name_specifier { option (validate.required) = true; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/backoff.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/backoff.proto index 3ffa97bb029..1899d1abf11 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/backoff.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/backoff.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "BackoffProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Backoff Strategy] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto index d6c507b8dec..f8d94a49dd6 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto @@ -23,6 +23,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "BaseProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common types] @@ -296,6 +297,15 @@ message RuntimeFeatureFlag { string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; } +// Query parameter name/value pair. +message QueryParameter { + // The key of the query parameter. Case sensitive. + string key = 1 [(validate.rules).string = {min_len: 1}]; + + // The value of the query parameter. + string value = 2; +} + // Header name/value pair. message HeaderValue { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HeaderValue"; @@ -320,12 +330,33 @@ message HeaderValueOption { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HeaderValueOption"; + // Describes the supported actions types for header append action. + enum HeaderAppendAction { + // This action will append the specified value to the existing values if the header + // already exists. If the header doesn't exist then this will add the header with + // specified key and value. + APPEND_IF_EXISTS_OR_ADD = 0; + + // This action will add the header if it doesn't already exist. If the header + // already exists then this will be a no-op. + ADD_IF_ABSENT = 1; + + // This action will overwrite the specified value by discarding any existing values if + // the header already exists. If the header doesn't exist then this will add the header + // with specified key and value. + OVERWRITE_IF_EXISTS_OR_ADD = 2; + } + // Header name/value pair that this option applies to. HeaderValue header = 1 [(validate.rules).message = {required: true}]; // Should the value be appended? If true (default), the value is appended to // existing values. Otherwise it replaces any existing values. google.protobuf.BoolValue append = 2; + + // [#not-implemented-hide:] Describes the action taken to append/overwrite the given value for an existing header + // or to only add this header if it's absent. Value defaults to :ref:`APPEND_IF_EXISTS_OR_ADD`. + HeaderAppendAction append_action = 3 [(validate.rules).enum = {defined_only: true}]; } // Wrapper for a set of headers. @@ -342,7 +373,7 @@ message WatchedDirectory { string path = 1 [(validate.rules).string = {min_len: 1}]; } -// Data source consisting of either a file or an inline value. +// Data source consisting of a file, an inline value, or an environment variable. message DataSource { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.DataSource"; @@ -357,6 +388,9 @@ message DataSource { // String inlined in the configuration. string inline_string = 3; + + // Environment variable data source. + string environment_variable = 4 [(validate.rules).string = {min_len: 1}]; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto index 43519c010b7..a49a05de8d4 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto @@ -2,8 +2,11 @@ syntax = "proto3"; package envoy.config.core.v3; +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/grpc_service.proto"; +import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -17,6 +20,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ConfigSourceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Configuration sources] @@ -38,7 +42,7 @@ enum ApiVersion { // API configuration source. This identifies the API type and cluster that Envoy // will use to fetch an xDS API. -// [#next-free-field: 9] +// [#next-free-field: 10] message ApiConfigSource { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.ApiConfigSource"; @@ -106,6 +110,16 @@ message ApiConfigSource { // Skip the node identifier in subsequent discovery requests for streaming gRPC config types. bool set_node_on_first_message_only = 7; + + // A list of config validators that will be executed when a new update is + // received from the ApiConfigSource. Note that each validator handles a + // specific xDS service type, and only the validators corresponding to the + // type url (in `:ref: DiscoveryResponse` or `:ref: DeltaDiscoveryResponse`) + // will be invoked. + // If the validator returns false or throws an exception, the config will be rejected by + // the client, and a NACK will be sent. + // [#extension-category: envoy.config.validators] + repeated TypedExtensionConfig config_validators = 9; } // Aggregated Discovery Service (ADS) options. This is currently empty, but when @@ -142,13 +156,49 @@ message RateLimitSettings { google.protobuf.DoubleValue fill_rate = 2 [(validate.rules).double = {gt: 0.0}]; } +// Local filesystem path configuration source. +message PathConfigSource { + // Path on the filesystem to source and watch for configuration updates. + // When sourcing configuration for a :ref:`secret `, + // the certificate and key files are also watched for updates. + // + // .. note:: + // + // The path to the source must exist at config load time. + // + // .. note:: + // + // If `watched_directory` is *not* configured, Envoy will watch the file path for *moves.* + // This is because in general only moves are atomic. The same method of swapping files as is + // demonstrated in the :ref:`runtime documentation ` can be + // used here also. If `watched_directory` is configured, no watch will be placed directly on + // this path. Instead, the configured `watched_directory` will be used to trigger reloads of + // this path. This is required in certain deployment scenarios. See below for more information. + string path = 1 [(validate.rules).string = {min_len: 1}]; + + // If configured, this directory will be watched for *moves.* When an entry in this directory is + // moved to, the `path` will be reloaded. This is required in certain deployment scenarios. + // + // Specifically, if trying to load an xDS resource using a + // `Kubernetes ConfigMap `_, the + // following configuration might be used: + // 1. Store xds.yaml inside a ConfigMap. + // 2. Mount the ConfigMap to `/config_map/xds` + // 3. Configure path `/config_map/xds/xds.yaml` + // 4. Configure watched directory `/config_map/xds` + // + // The above configuration will ensure that Envoy watches the owning directory for moves which is + // required due to how Kubernetes manages ConfigMap symbolic links during atomic updates. + WatchedDirectory watched_directory = 2; +} + // Configuration for :ref:`listeners `, :ref:`clusters // `, :ref:`routes // `, :ref:`endpoints // ` etc. may either be sourced from the // filesystem or from an xDS API source. Filesystem configs are watched with // inotify for updates. -// [#next-free-field: 8] +// [#next-free-field: 9] message ConfigSource { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.ConfigSource"; @@ -161,20 +211,11 @@ message ConfigSource { oneof config_source_specifier { option (validate.required) = true; - // Path on the filesystem to source and watch for configuration updates. - // When sourcing configuration for :ref:`secret `, - // the certificate and key files are also watched for updates. - // - // .. note:: - // - // The path to the source must exist at config load time. - // - // .. note:: - // - // Envoy will only watch the file path for *moves.* This is because in general only moves - // are atomic. The same method of swapping files as is demonstrated in the - // :ref:`runtime documentation ` can be used here also. - string path = 1; + // Deprecated in favor of `path_config_source`. Use that field instead. + string path = 1 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Local filesystem path configuration source. + PathConfigSource path_config_source = 8; // API configuration source. ApiConfigSource api_config_source = 2; @@ -211,3 +252,32 @@ message ConfigSource { // turn expect to be delivered. ApiVersion resource_api_version = 6 [(validate.rules).enum = {defined_only: true}]; } + +// Configuration source specifier for a late-bound extension configuration. The +// parent resource is warmed until all the initial extension configurations are +// received, unless the flag to apply the default configuration is set. +// Subsequent extension updates are atomic on a per-worker basis. Once an +// extension configuration is applied to a request or a connection, it remains +// constant for the duration of processing. If the initial delivery of the +// extension configuration fails, due to a timeout for example, the optional +// default configuration is applied. Without a default configuration, the +// extension is disabled, until an extension configuration is received. The +// behavior of a disabled extension depends on the context. For example, a +// filter chain with a disabled extension filter rejects all incoming streams. +message ExtensionConfigSource { + ConfigSource config_source = 1 [(validate.rules).any = {required: true}]; + + // Optional default configuration to use as the initial configuration if + // there is a failure to receive the initial extension configuration or if + // `apply_default_config_without_warming` flag is set. + google.protobuf.Any default_config = 2; + + // Use the default config as the initial configuration without warming and + // waiting for the first discovery response. Requires the default configuration + // to be supplied. + bool apply_default_config_without_warming = 3; + + // A set of permitted extension type URLs. Extension configuration updates are rejected + // if they do not match any type URL in the set. + repeated string type_urls = 4 [(validate.rules).repeated = {min_items: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/event_service_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/event_service_config.proto index b3552e3975a..68c8df4076e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/event_service_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/event_service_config.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "EventServiceConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#not-implemented-hide:] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/extension.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/extension.proto index ba66da6a8e3..80afce693cd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/extension.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/extension.proto @@ -2,8 +2,6 @@ syntax = "proto3"; package envoy.config.core.v3; -import "envoy/config/core/v3/config_source.proto"; - import "google/protobuf/any.proto"; import "udpa/annotations/status.proto"; @@ -12,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ExtensionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Extension configuration] @@ -24,38 +23,10 @@ message TypedExtensionConfig { string name = 1 [(validate.rules).string = {min_len: 1}]; // The typed config for the extension. The type URL will be used to identify - // the extension. In the case that the type URL is *udpa.type.v1.TypedStruct*, - // the inner type URL of *TypedStruct* will be utilized. See the + // the extension. In the case that the type URL is *xds.type.v3.TypedStruct* + // (or, for historical reasons, *udpa.type.v1.TypedStruct*), the inner type + // URL of *TypedStruct* will be utilized. See the // :ref:`extension configuration overview // ` for further details. google.protobuf.Any typed_config = 2 [(validate.rules).any = {required: true}]; } - -// Configuration source specifier for a late-bound extension configuration. The -// parent resource is warmed until all the initial extension configurations are -// received, unless the flag to apply the default configuration is set. -// Subsequent extension updates are atomic on a per-worker basis. Once an -// extension configuration is applied to a request or a connection, it remains -// constant for the duration of processing. If the initial delivery of the -// extension configuration fails, due to a timeout for example, the optional -// default configuration is applied. Without a default configuration, the -// extension is disabled, until an extension configuration is received. The -// behavior of a disabled extension depends on the context. For example, a -// filter chain with a disabled extension filter rejects all incoming streams. -message ExtensionConfigSource { - ConfigSource config_source = 1 [(validate.rules).any = {required: true}]; - - // Optional default configuration to use as the initial configuration if - // there is a failure to receive the initial extension configuration or if - // `apply_default_config_without_warming` flag is set. - google.protobuf.Any default_config = 2; - - // Use the default config as the initial configuration without warming and - // waiting for the first discovery response. Requires the default configuration - // to be supplied. - bool apply_default_config_without_warming = 3; - - // A set of permitted extension type URLs. Extension configuration updates are rejected - // if they do not match any type URL in the set. - repeated string type_urls = 4 [(validate.rules).repeated = {min_items: 1}]; -} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto index a7f29c8f529..4fb60955805 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto @@ -18,6 +18,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "GrpcServiceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: gRPC services] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto index 304297e7c01..83cce7ccdb5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto @@ -20,6 +20,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "HealthCheckProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Health check] @@ -53,6 +54,12 @@ enum HealthStatus { DEGRADED = 5; } +message HealthStatusSet { + // An order-independent set of health status. + repeated HealthStatus statuses = 1 + [(validate.rules).repeated = {items {enum {defined_only: true}}}]; +} + // [#next-free-field: 25] message HealthCheck { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HealthCheck"; @@ -73,7 +80,7 @@ message HealthCheck { } } - // [#next-free-field: 12] + // [#next-free-field: 13] message HttpHealthCheck { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HealthCheck.HttpHealthCheck"; @@ -118,6 +125,18 @@ message HealthCheck { // range are required. Only statuses in the range [100, 600) are allowed. repeated type.v3.Int64Range expected_statuses = 9; + // Specifies a list of HTTP response statuses considered retriable. If provided, responses in this range + // will count towards the configured :ref:`unhealthy_threshold `, + // but will not result in the host being considered immediately unhealthy. Ranges follow half-open semantics of + // :ref:`Int64Range `. The start and end of each range are required. + // Only statuses in the range [100, 600) are allowed. The :ref:`expected_statuses ` + // field takes precedence for any range overlaps with this field i.e. if status code 200 is both retriable and expected, a 200 response will + // be considered a successful health check. By default all responses not in + // :ref:`expected_statuses ` will result in + // the host being considered immediately unhealthy i.e. if status code 200 is expected and there are no configured retriable statuses, any + // non-200 response will result in the host being marked unhealthy. + repeated type.v3.Int64Range retriable_statuses = 12; + // Use specified application protocol for health checks. type.v3.CodecClientType codec_client_type = 10 [(validate.rules).enum = {defined_only: true}]; @@ -173,6 +192,12 @@ message HealthCheck { // the :ref:`hostname ` field. string authority = 2 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + + // Specifies a list of key-value pairs that should be added to the metadata of each GRPC call + // that is sent to the health checked cluster. For more information, including details on header value syntax, + // see the documentation on :ref:`custom request headers + // `. + repeated HeaderValueOption initial_metadata = 3 [(validate.rules).repeated = {max_items: 1000}]; } // Custom health check. @@ -243,8 +268,10 @@ message HealthCheck { uint32 interval_jitter_percent = 18; // The number of unhealthy health checks required before a host is marked - // unhealthy. Note that for *http* health checking if a host responds with 503 - // this threshold is ignored and the host is considered unhealthy immediately. + // unhealthy. Note that for *http* health checking if a host responds with a code not in + // :ref:`expected_statuses ` + // or :ref:`retriable_statuses `, + // this threshold is ignored and the host is considered immediately unhealthy. google.protobuf.UInt32Value unhealthy_threshold = 4 [(validate.rules).message = {required: true}]; // The number of healthy health checks required before a host is marked diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/http_uri.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/http_uri.proto index 5d1fc239e07..ec0f71f9055 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/http_uri.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/http_uri.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "HttpUriProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP Service URI ] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto index 8f2347eb551..f18a2053d9d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto @@ -8,6 +8,8 @@ import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/annotations/v3/status.proto"; + import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -16,6 +18,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ProtocolProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Protocol options] @@ -26,11 +29,38 @@ message TcpProtocolOptions { "envoy.api.v2.core.TcpProtocolOptions"; } +// Config for keepalive probes in a QUIC connection. +// Note that QUIC keep-alive probing packets work differently from HTTP/2 keep-alive PINGs in a sense that the probing packet +// itself doesn't timeout waiting for a probing response. Quic has a shorter idle timeout than TCP, so it doesn't rely on such probing to discover dead connections. If the peer fails to respond, the connection will idle timeout eventually. Thus, they are configured differently from :ref:`connection_keepalive `. +message QuicKeepAliveSettings { + // The max interval for a connection to send keep-alive probing packets (with PING or PATH_RESPONSE). The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout while not less than 1s to avoid throttling the connection or flooding the peer with probes. + // + // If :ref:`initial_interval ` is absent or zero, a client connection will use this value to start probing. + // + // If zero, disable keepalive probing. + // If absent, use the QUICHE default interval to probe. + google.protobuf.Duration max_interval = 1 [(validate.rules).duration = { + lte {} + gte {seconds: 1} + }]; + + // The interval to send the first few keep-alive probing packets to prevent connection from hitting the idle timeout. Subsequent probes will be sent, each one with an interval exponentially longer than previous one, till it reaches :ref:`max_interval `. And the probes afterwards will always use :ref:`max_interval `. + // + // The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout and smaller than max_interval to take effect. + // + // If absent or zero, disable keepalive probing for a server connection. For a client connection, if :ref:`max_interval ` is also zero, do not keepalive, otherwise use max_interval or QUICHE default to probe all the time. + google.protobuf.Duration initial_interval = 2 [(validate.rules).duration = { + lte {} + gte {seconds: 1} + }]; +} + // QUIC protocol options which apply to both downstream and upstream connections. +// [#next-free-field: 6] message QuicProtocolOptions { // Maximum number of streams that the client can negotiate per connection. 100 // if not specified. - google.protobuf.UInt32Value max_concurrent_streams = 1; + google.protobuf.UInt32Value max_concurrent_streams = 1 [(validate.rules).uint32 = {gte: 1}]; // `Initial stream-level flow-control receive window // `_ size. Valid values range from @@ -53,6 +83,17 @@ message QuicProtocolOptions { // window size now, so it's also the minimum. google.protobuf.UInt32Value initial_connection_window_size = 3 [(validate.rules).uint32 = {lte: 25165824 gte: 1}]; + + // The number of timeouts that can occur before port migration is triggered for QUIC clients. + // This defaults to 1. If set to 0, port migration will not occur on path degrading. + // Timeout here refers to QUIC internal path degrading timeout mechanism, such as PTO. + // This has no effect on server sessions. + google.protobuf.UInt32Value num_timeouts_to_trigger_port_migration = 4 + [(validate.rules).uint32 = {lte: 5 gte: 0}]; + + // Probes the peer at the configured interval to solicit traffic, i.e. ACK or PATH_RESPONSE, from the peer to push back connection idle timeout. + // If absent, use the default keepalive behavior of which a client connection sends PINGs every 15s, and a server connection doesn't do anything. + QuicKeepAliveSettings connection_keepalive = 5; } message UpstreamHttpProtocolOptions { @@ -60,15 +101,26 @@ message UpstreamHttpProtocolOptions { "envoy.api.v2.core.UpstreamHttpProtocolOptions"; // Set transport socket `SNI `_ for new - // upstream connections based on the downstream HTTP host/authority header, as seen by the - // :ref:`router filter `. + // upstream connections based on the downstream HTTP host/authority header or any other arbitrary + // header when :ref:`override_auto_sni_header ` + // is set, as seen by the :ref:`router filter `. bool auto_sni = 1; // Automatic validate upstream presented certificate for new upstream connections based on the - // downstream HTTP host/authority header, as seen by the - // :ref:`router filter `. - // This field is intended to set with `auto_sni` field. + // downstream HTTP host/authority header or any other arbitrary header when :ref:`override_auto_sni_header ` + // is set, as seen by the :ref:`router filter `. + // This field is intended to be set with `auto_sni` field. bool auto_san_validation = 2; + + // An optional alternative to the host/authority header to be used for setting the SNI value. + // It should be a valid downstream HTTP header, as seen by the + // :ref:`router filter `. + // If unset, host/authority header will be used for populating the SNI. If the specified header + // is not found or the value is empty, host/authority header will be used instead. + // This field is intended to be set with `auto_sni` and/or `auto_san_validation` fields. + // If none of these fields are set then setting this would be a no-op. + string override_auto_sni_header = 3 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME ignore_empty: true}]; } // Configures the alternate protocols cache which tracks alternate protocols that can be used to @@ -76,6 +128,24 @@ message UpstreamHttpProtocolOptions { // HTTP Alternative Services and https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-04 // for the "HTTPS" DNS resource record. message AlternateProtocolsCacheOptions { + // Allows pre-populating the cache with HTTP/3 alternate protocols entries with a 7 day lifetime. + // This will cause Envoy to attempt HTTP/3 to those upstreams, even if the upstreams have not + // advertised HTTP/3 support. These entries will be overwritten by alt-svc + // response headers or cached values. + // As with regular cached entries, if the origin response would result in clearing an existing + // alternate protocol cache entry, pre-populated entries will also be cleared. + // Adding a cache entry with hostname=foo.com port=123 is the equivalent of getting + // response headers + // alt-svc: h3=:"123"; ma=86400" in a response to a request to foo.com:123 + message AlternateProtocolsCacheEntry { + // The host name for the alternate protocol entry. + string hostname = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME ignore_empty: true}]; + + // The port for the alternate protocol entry. + uint32 port = 2 [(validate.rules).uint32 = {lt: 65535 gt: 0}]; + } + // The name of the cache. Multiple named caches allow independent alternate protocols cache // configurations to operate within a single Envoy process using different configurations. All // alternate protocols cache options with the same name *must* be equal in all fields when @@ -91,6 +161,16 @@ message AlternateProtocolsCacheOptions { // it is possible for the maximum entries in the cache to go slightly above the configured // value depending on timing. This is similar to how other circuit breakers work. google.protobuf.UInt32Value max_entries = 2 [(validate.rules).uint32 = {gt: 0}]; + + // Allows configuring a persistent + // :ref:`key value store ` to flush + // alternate protocols entries to disk. + // This function is currently only supported if concurrency is 1 + // Cached entries will take precedence over pre-populated entries below. + TypedExtensionConfig key_value_store_config = 3; + + // Allows pre-populating the cache with entries, as described above. + repeated AlternateProtocolsCacheEntry prepopulated_entries = 4; } // [#next-free-field: 7] @@ -112,7 +192,7 @@ message HttpProtocolOptions { // is incremented for each rejected request. REJECT_REQUEST = 1; - // Drop the header with name containing underscores. The header is dropped before the filter chain is + // Drop the client header with name containing underscores. The header is dropped before the filter chain is // invoked and as such filters will not see dropped headers. The // "httpN.dropped_headers_with_underscores" is incremented for each dropped header. DROP_HEADER = 2; @@ -138,10 +218,10 @@ message HttpProtocolOptions { // The maximum duration of a connection. The duration is defined as a period since a connection // was established. If not set, there is no max duration. When max_connection_duration is reached - // the connection will be closed. Drain sequence will occur prior to closing the connection if - // if's applicable. See :ref:`drain_timeout + // and if there are no active streams, the connection will be closed. If the connection is a + // downstream connection and there are any active streams, the drain sequence will kick-in, + // and the connection will be force-closed after the drain period. See :ref:`drain_timeout // `. - // Note: not implemented for upstream connections. google.protobuf.Duration max_connection_duration = 3; // The maximum number of headers. If unconfigured, the default @@ -156,6 +236,8 @@ message HttpProtocolOptions { // Action to take when a client request with a header name containing underscore characters is received. // If this setting is not specified, the value defaults to ALLOW. // Note: upstream responses are not affected by this setting. + // Note: this only affects client headers. It does not affect headers added + // by Envoy filters and does not have any impact if added to cluster config. HeadersWithUnderscoresAction headers_with_underscores_action = 5; // Optional maximum requests for both upstream and downstream connections. @@ -232,7 +314,7 @@ message Http1ProtocolOptions { // Allows Envoy to process requests/responses with both `Content-Length` and `Transfer-Encoding` // headers set. By default such messages are rejected, but if option is enabled - Envoy will // remove Content-Length header and process message. - // See `RFC7230, sec. 3.3.3 ` for details. + // See `RFC7230, sec. 3.3.3 `_ for details. // // .. attention:: // Enabling this option might lead to request smuggling vulnerability, especially if traffic @@ -254,7 +336,9 @@ message KeepaliveSettings { google.protobuf.Duration interval = 1 [(validate.rules).duration = {gte {nanos: 1000000}}]; // How long to wait for a response to a keepalive PING. If a response is not received within this - // time period, the connection will be aborted. + // time period, the connection will be aborted. Note that in order to prevent the influence of + // Head-of-line (HOL) blocking the timeout period is extended when *any* frame is received on + // the connection, under the assumption that if a frame is received the connection is healthy. google.protobuf.Duration timeout = 2 [(validate.rules).duration = { required: true gte {nanos: 1000000} @@ -270,6 +354,8 @@ message KeepaliveSettings { // If this is zero, this type of PING will not be sent. // If an interval ping is outstanding, a second ping will not be sent as the // interval ping will determine if the connection is dead. + // + // The same feature for HTTP/3 is given by inheritance from QUICHE which uses :ref:`connection idle_timeout ` and the current PTO of the connection to decide whether to probe before sending a new request. google.protobuf.Duration connection_idle_interval = 4 [(validate.rules).duration = {gte {nanos: 1000000}}]; } @@ -349,8 +435,6 @@ message Http2ProtocolOptions { // be written into the socket). Exceeding this limit triggers flood mitigation and connection is // terminated. The ``http2.outbound_flood`` stat tracks the number of terminated connections due // to flood mitigation. The default limit is 10000. - // NOTE: flood and abuse mitigation for upstream connections is presently enabled by the - // `envoy.reloadable_features.upstream_http2_flood_checks` flag. google.protobuf.UInt32Value max_outbound_frames = 7 [(validate.rules).uint32 = {gte: 1}]; // Limit the number of pending outbound downstream frames of types PING, SETTINGS and RST_STREAM, @@ -358,8 +442,6 @@ message Http2ProtocolOptions { // this limit triggers flood mitigation and connection is terminated. The // ``http2.outbound_control_flood`` stat tracks the number of terminated connections due to flood // mitigation. The default limit is 1000. - // NOTE: flood and abuse mitigation for upstream connections is presently enabled by the - // `envoy.reloadable_features.upstream_http2_flood_checks` flag. google.protobuf.UInt32Value max_outbound_control_frames = 8 [(validate.rules).uint32 = {gte: 1}]; // Limit the number of consecutive inbound frames of types HEADERS, CONTINUATION and DATA with an @@ -368,8 +450,6 @@ message Http2ProtocolOptions { // stat tracks the number of connections terminated due to flood mitigation. // Setting this to 0 will terminate connection upon receiving first frame with an empty payload // and no end stream flag. The default limit is 1. - // NOTE: flood and abuse mitigation for upstream connections is presently enabled by the - // `envoy.reloadable_features.upstream_http2_flood_checks` flag. google.protobuf.UInt32Value max_consecutive_inbound_frames_with_empty_payload = 9; // Limit the number of inbound PRIORITY frames allowed per each opened stream. If the number @@ -383,8 +463,6 @@ message Http2ProtocolOptions { // `opened_streams` is incremented when Envoy send the HEADERS frame for a new stream. The // ``http2.inbound_priority_frames_flood`` stat tracks // the number of connections terminated due to flood mitigation. The default limit is 100. - // NOTE: flood and abuse mitigation for upstream connections is presently enabled by the - // `envoy.reloadable_features.upstream_http2_flood_checks` flag. google.protobuf.UInt32Value max_inbound_priority_frames_per_stream = 10; // Limit the number of inbound WINDOW_UPDATE frames allowed per DATA frame sent. If the number @@ -401,8 +479,6 @@ message Http2ProtocolOptions { // flood mitigation. The default max_inbound_window_update_frames_per_data_frame_sent value is 10. // Setting this to 1 should be enough to support HTTP/2 implementations with basic flow control, // but more complex implementations that try to estimate available bandwidth require at least 2. - // NOTE: flood and abuse mitigation for upstream connections is presently enabled by the - // `envoy.reloadable_features.upstream_http2_flood_checks` flag. google.protobuf.UInt32Value max_inbound_window_update_frames_per_data_frame_sent = 11 [(validate.rules).uint32 = {gte: 1}]; @@ -473,6 +549,7 @@ message GrpcProtocolOptions { } // A message which allows using HTTP/3. +// [#next-free-field: 6] message Http3ProtocolOptions { QuicProtocolOptions quic_protocol_options = 1; @@ -483,6 +560,14 @@ message Http3ProtocolOptions { // If set, this overrides any HCM :ref:`stream_error_on_invalid_http_messaging // `. google.protobuf.BoolValue override_stream_error_on_invalid_http_message = 2; + + // Allows proxying Websocket and other upgrades over HTTP/3 CONNECT using + // the header mechanisms from the `HTTP/2 extended connect RFC + // `_ + // and settings `proposed for HTTP/3 + // `_ + // Note that HTTP/3 CONNECT is not yet an RFC. + bool allow_extended_connect = 5 [(xds.annotations.v3.field_status).work_in_progress = true]; } // A message to control transformations to the :scheme header diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto index 40b33f33ff5..9cfdbe5f66c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto @@ -7,6 +7,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ProxyProtocolProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Proxy Protocol] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/resolver.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/resolver.proto index 21d40425f7a..f4d103ab038 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/resolver.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/resolver.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ResolverProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Resolver] @@ -17,9 +18,6 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Configuration of DNS resolver option flags which control the behavior of the DNS resolver. message DnsResolverOptions { // Use TCP for all DNS queries instead of the default protocol UDP. - // Setting this value causes failure if the - // ``envoy.restart_features.use_apple_api_for_dns_lookups`` runtime value is true during - // server startup. Apple's API only uses UDP for DNS resolution. bool use_tcp_for_dns_lookups = 1; // Do not use the default search domains; only query hostnames as-is or as aliases. @@ -31,9 +29,6 @@ message DnsResolutionConfig { // A list of dns resolver addresses. If specified, the DNS client library will perform resolution // via the underlying DNS resolvers. Otherwise, the default system resolvers // (e.g., /etc/resolv.conf) will be used. - // Setting this value causes failure if the - // ``envoy.restart_features.use_apple_api_for_dns_lookups`` runtime value is true during - // server startup. Apple's API only allows overriding DNS resolvers via system settings. repeated Address resolvers = 1 [(validate.rules).repeated = {min_items: 1}]; // Configuration of DNS resolver option flags which control the behavior of the DNS resolver. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto index b22169b86ae..e7605fb6889 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto @@ -9,12 +9,33 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "SocketOptionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Socket Option ] // Generic socket option message. This would be used to set socket options that // might not exist in upstream kernels or precompiled Envoy binaries. +// +// For example: +// +// .. code-block:: json +// +// { +// "description": "support tcp keep alive", +// "state": 0, +// "level": 1, +// "name": 9, +// "int_value": 1, +// } +// +// 1 means SOL_SOCKET and 9 means SO_KEEPALIVE on Linux. +// With the above configuration, `TCP Keep-Alives `_ +// can be enabled in socket with Linux, which can be used in +// :ref:`listener's` or +// :ref:`admin's ` socket_options etc. +// +// It should be noted that the name or level may have different values on different platforms. // [#next-free-field: 7] message SocketOption { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.SocketOption"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto index b2a1c5e13ee..7259725e05a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto @@ -14,6 +14,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "SubstitutionFormatStringProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Substitution format string] @@ -105,7 +106,8 @@ message SubstitutionFormatString { // // content_type: "text/html; charset=UTF-8" // - string content_type = 4; + string content_type = 4 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; // Specifies a collection of Formatter plugins that can be called from the access log configuration. // See the formatters extensions documentation for details. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/udp_socket_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/udp_socket_config.proto index 00033eabdb8..ec9f77f0687 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/udp_socket_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/udp_socket_config.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "UdpSocketConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: UDP socket config] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto index b22a644eeae..7edfb66c9a8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.endpoint.v3"; option java_outer_classname = "EndpointProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3;endpointv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Endpoint configuration] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto index 0a9aac105e7..49f38211e80 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto @@ -9,7 +9,6 @@ import "envoy/config/core/v3/health_check.proto"; import "google/protobuf/wrappers.proto"; -import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -17,6 +16,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.endpoint.v3"; option java_outer_classname = "EndpointComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3;endpointv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Endpoints] @@ -122,9 +122,8 @@ message LedsClusterLocalityConfig { } // A group of endpoints belonging to a Locality. -// One can have multiple LocalityLbEndpoints for a locality, but this is -// generally only done if the different groups need to have different load -// balancing weights or different priorities. +// One can have multiple LocalityLbEndpoints for a locality, but only if +// they have different priorities. // [#next-free-field: 9] message LocalityLbEndpoints { option (udpa.annotations.versioning).previous_message_type = diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto index c114fa72662..85ecae7f2d2 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.endpoint.v3"; option java_outer_classname = "LoadReportProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3;endpointv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Load Report] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/filter/accesslog/v2/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/filter/accesslog/v2/accesslog.proto index 25d27bfbd10..7f38515421c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/filter/accesslog/v2/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/filter/accesslog/v2/accesslog.proto @@ -16,6 +16,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.filter.accesslog.v2"; option java_outer_classname = "AccesslogProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/filter/accesslog/v2;accesslogv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.accesslog.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/filter/fault/v2/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/config/filter/fault/v2/fault.proto index 016140d10f8..d23e50b1916 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/filter/fault/v2/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/filter/fault/v2/fault.proto @@ -14,6 +14,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.filter.fault.v2"; option java_outer_classname = "FaultProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/filter/fault/v2;faultv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.filters.common.fault.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/fault/v2/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/fault/v2/fault.proto index cb99b0d71bb..109dfb4cfbe 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/fault/v2/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/fault/v2/fault.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.filter.http.fault.v2"; option java_outer_classname = "FaultProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/filter/http/fault/v2;faultv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.filters.http.fault.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/router/v2/router.proto b/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/router/v2/router.proto index c95500cf816..e47e73f8c7a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/router/v2/router.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/filter/http/router/v2/router.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.filter.http.router.v2"; option java_outer_classname = "RouterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/filter/http/router/v2;routerv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.filters.http.router.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto index 3e7a4dc1776..6286e979a1f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto @@ -24,6 +24,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.filter.network.http_connection_manager.v2"; option java_outer_classname = "HttpConnectionManagerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2;http_connection_managerv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.filters.network.http_connection_manager.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v2/api_listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v2/api_listener.proto index 6709d5fe0b5..ae47c7d338a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v2/api_listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v2/api_listener.proto @@ -10,6 +10,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v2"; option java_outer_classname = "ApiListenerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v2;listenerv2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.config.listener.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto index 77db7caaff5..a3610e65688 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto @@ -10,6 +10,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v3"; option java_outer_classname = "ApiListenerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3;listenerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: API listener] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto index a5cd4bfe976..d8982b0a97a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto @@ -13,7 +13,9 @@ import "envoy/config/listener/v3/udp_listener_config.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/annotations/v3/status.proto"; import "xds/core/v3/collection_entry.proto"; +import "xds/type/matcher/v3/matcher.proto"; import "envoy/annotations/deprecation.proto"; import "udpa/annotations/security.proto"; @@ -24,6 +26,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v3"; option java_outer_classname = "ListenerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3;listenerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Listener configuration] @@ -35,7 +38,7 @@ message ListenerCollection { repeated xds.core.v3.CollectionEntry entries = 1; } -// [#next-free-field: 30] +// [#next-free-field: 33] message Listener { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Listener"; @@ -103,7 +106,8 @@ message Listener { // The address that the listener should listen on. In general, the address must be unique, though // that is governed by the bind rules of the OS. E.g., multiple listeners can listen on port 0 on // Linux as the actual port will be allocated by the OS. - core.v3.Address address = 2 [(validate.rules).message = {required: true}]; + // Required unless *api_listener* or *listener_specifier* is populated. + core.v3.Address address = 2; // Optional prefix to use on listener stats. If empty, the stats will be rooted at // `listener.

    .`. If non-empty, stats will be rooted at @@ -119,6 +123,25 @@ message Listener { // :ref:`FAQ entry `. repeated FilterChain filter_chains = 3; + // :ref:`Matcher API ` resolving the filter chain name from the + // network properties. This matcher is used as a replacement for the filter chain match condition + // :ref:`filter_chain_match + // `. If specified, all + // :ref:`filter_chains ` must have a + // non-empty and unique :ref:`name ` field + // and not specify :ref:`filter_chain_match + // ` field. + // + // .. note:: + // + // Once matched, each connection is permanently bound to its filter chain. + // If the matcher changes but the filter chain remains the same, the + // connections bound to the filter chain are not drained. If, however, the + // filter chain is removed or structurally modified, then the drain for its + // connections is initiated. + xds.type.matcher.v3.Matcher filter_chain_matcher = 32 + [(xds.annotations.v3.field_status).work_in_progress = true]; + // If a connection is redirected using *iptables*, the port on which the proxy // receives it might be different from the original destination address. When this flag is set to // true, the listener hands off redirected connections to the listener associated with the @@ -153,7 +176,6 @@ message Listener { // UDP Listener filters can be specified when the protocol in the listener socket address in // :ref:`protocol ` is :ref:`UDP // `. - // UDP listeners currently support a single filter. repeated ListenerFilter listener_filters = 9; // The timeout to wait for all listener filters to complete operation. If the timeout is reached, @@ -315,4 +337,12 @@ message Listener { // [#not-implemented-hide:] InternalListenerConfig internal_listener = 27; } + + // Enable MPTCP (multi-path TCP) on this listener. Clients will be allowed to establish + // MPTCP connections. Non-MPTCP clients will fall back to regular TCP. + bool enable_mptcp = 30; + + // Whether the listener should limit connections based upon the value of + // :ref:`global_downstream_max_connections `. + bool ignore_global_conn_limit = 31; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto index e737b14b174..aed27c37148 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto @@ -4,13 +4,15 @@ package envoy.config.listener.v3; import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; -import "envoy/config/core/v3/extension.proto"; +import "envoy/config/core/v3/config_source.proto"; import "envoy/type/v3/range.proto"; import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/annotations/v3/status.proto"; + import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -19,6 +21,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v3"; option java_outer_classname = "ListenerComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3;listenerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Listener components] @@ -32,8 +35,7 @@ message Filter { reserved "config"; - // The name of the filter to instantiate. The name must match a - // :ref:`supported filter `. + // The name of the filter configuration. string name = 1 [(validate.rules).string = {min_len: 1}]; oneof config_type { @@ -258,10 +260,11 @@ message FilterChain { // establishment, the connection is summarily closed. google.protobuf.Duration transport_socket_connect_timeout = 9; - // [#not-implemented-hide:] The unique name (or empty) by which this filter chain is known. If no - // name is provided, Envoy will allocate an internal UUID for the filter chain. If the filter - // chain is to be dynamically updated or removed via FCDS a unique name must be provided. - string name = 7; + // The unique name (or empty) by which this filter chain is known. + // Note: :ref:`filter_chain_matcher + // ` + // requires that filter chains are uniquely named within a listener. + string name = 7 [(xds.annotations.v3.field_status).work_in_progress = true]; // [#not-implemented-hide:] The configuration to specify whether the filter chain will be built on-demand. // If this field is not empty, the filter chain will be built on-demand. @@ -333,6 +336,7 @@ message ListenerFilterChainMatchPredicate { } } +// [#next-free-field: 6] message ListenerFilter { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.ListenerFilter"; @@ -341,8 +345,7 @@ message ListenerFilter { reserved "config"; - // The name of the filter to instantiate. The name must match a - // :ref:`supported filter `. + // The name of the filter configuration. string name = 1 [(validate.rules).string = {min_len: 1}]; oneof config_type { @@ -350,6 +353,12 @@ message ListenerFilter { // instantiated. See the supported filters for further documentation. // [#extension-category: envoy.filters.listener,envoy.filters.udp_listener] google.protobuf.Any typed_config = 3; + + // Configuration source specifier for an extension configuration discovery + // service. In case of a failure and without the default configuration, the + // listener closes the connections. + // [#not-implemented-hide:] + core.v3.ExtensionConfigSource config_discovery = 5; } // Optional match predicate used to disable the filter. The filter is enabled when this field is empty. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto index 1432e1911b5..89dc34a06b8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto @@ -16,6 +16,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v3"; option java_outer_classname = "QuicConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3;listenerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: QUIC listener config] @@ -29,11 +30,14 @@ message QuicProtocolOptions { core.v3.QuicProtocolOptions quic_protocol_options = 1; // Maximum number of milliseconds that connection will be alive when there is - // no network activity. 300000ms if not specified. + // no network activity. + // + // If it is less than 1ms, Envoy will use 1ms. 300000ms if not specified. google.protobuf.Duration idle_timeout = 2; // Connection timeout in milliseconds before the crypto handshake is finished. - // 20000ms if not specified. + // + // If it is less than 5000ms, Envoy will use 5000ms. 20000ms if not specified. google.protobuf.Duration crypto_handshake_timeout = 3; // Runtime flag that controls whether the listener is enabled or not. If not specified, defaults diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/udp_listener_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/udp_listener_config.proto index 57088ac5fe1..f3f03d23ed7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/udp_listener_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/udp_listener_config.proto @@ -11,6 +11,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.config.listener.v3"; option java_outer_classname = "UdpListenerConfigProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3;listenerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: UDP listener config] @@ -33,10 +34,6 @@ message UdpListenerConfig { // Configuration for QUIC protocol. If empty, QUIC will not be enabled on this listener. Set // to the default object to enable QUIC without modifying any additional options. - // - // .. warning:: - // QUIC support is currently alpha and should be used with caution. Please - // see :ref:`here ` for details. QuicProtocolOptions quic_options = 7; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto index d442cffe36a..17ae761ea3c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.metrics.v3"; option java_outer_classname = "StatsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/metrics/v3;metricsv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Stats] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto index 85fa761dbdd..3868df23482 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto @@ -14,6 +14,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.overload.v3"; option java_outer_classname = "OverloadProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/overload/v3;overloadv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Overload Manager] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v2/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v2/rbac.proto index 943ac33e085..941d6217720 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v2/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v2/rbac.proto @@ -16,6 +16,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.rbac.v2"; option java_outer_classname = "RbacProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v2;rbacv2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Role Based Access Control (RBAC)] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto index d66f9be2b49..8abde899d7e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.rbac.v3; import "envoy/config/core/v3/address.proto"; +import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/path.proto"; @@ -21,6 +22,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.rbac.v3"; option java_outer_classname = "RbacProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3;rbacv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Role Based Access Control (RBAC)] @@ -146,7 +148,7 @@ message Policy { } // Permission defines an action (or actions) that a principal can take. -// [#next-free-field: 12] +// [#next-free-field: 13] message Permission { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Permission"; @@ -218,6 +220,10 @@ message Permission { // Please refer to :ref:`this FAQ entry ` to learn to // setup SNI. type.matcher.v3.StringMatcher requested_server_name = 9; + + // Extension for configuring custom matchers for RBAC. + // [#extension-category: envoy.rbac.matchers] + core.v3.TypedExtensionConfig matcher = 12; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto index e2bf52165be..8579f0af7c8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto @@ -4,7 +4,6 @@ package envoy.config.route.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/config_source.proto"; -import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "google/protobuf/wrappers.proto"; @@ -16,13 +15,14 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.route.v3"; option java_outer_classname = "RouteProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/route/v3;routev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP route configuration] // * Routing :ref:`architecture overview ` // * HTTP :ref:`router filter ` -// [#next-free-field: 13] +// [#next-free-field: 15] message RouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.RouteConfiguration"; @@ -121,17 +121,20 @@ message RouteConfiguration { // google.protobuf.UInt32Value max_direct_response_body_size_bytes = 11; - // [#not-implemented-hide:] // A list of plugins and their configurations which may be used by a - // :ref:`envoy_v3_api_field_config.route.v3.RouteAction.cluster_specifier_plugin` + // :ref:`cluster specifier plugin name ` // within the route. All *extension.name* fields in this list must be unique. repeated ClusterSpecifierPlugin cluster_specifier_plugins = 12; -} -// Configuration for a cluster specifier plugin. -message ClusterSpecifierPlugin { - // The name of the plugin and its opaque configuration. - core.v3.TypedExtensionConfig extension = 1; + // Specify a set of default request mirroring policies which apply to all routes under its virtual hosts. + // Note that policies are not merged, the most specific non-empty one becomes the mirror policies. + repeated RouteAction.RequestMirrorPolicy request_mirror_policies = 13; + + // By default, port in :authority header (if any) is used in host matching. + // With this option enabled, Envoy will ignore the port number in the :authority header (if any) when picking VirtualHost. + // NOTE: this option will not strip the port number (if any) contained in route config + // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`.domains field. + bool ignore_port_in_host_matching = 14; } message Vhds { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto index dfb8b8ed1a1..b3ec0c594ac 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto @@ -5,6 +5,7 @@ package envoy.config.route.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/proxy_protocol.proto"; +import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/regex.proto"; import "envoy/type/matcher/v3/string.proto"; import "envoy/type/metadata/v3/metadata.proto"; @@ -16,6 +17,9 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/annotations/v3/status.proto"; +import "xds/type/matcher/v3/matcher.proto"; + import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; @@ -25,6 +29,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.route.v3"; option java_outer_classname = "RouteComponentsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/route/v3;routev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP route components] @@ -36,7 +41,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // host header. This allows a single listener to service multiple top level domain path trees. Once // a virtual host is selected based on the domain, the routes are processed in order to see which // upstream cluster to route to or whether to perform a redirect. -// [#next-free-field: 21] +// [#next-free-field: 23] message VirtualHost { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.VirtualHost"; @@ -86,8 +91,15 @@ message VirtualHost { // The list of routes that will be matched, in order, for incoming requests. // The first route that matches will be used. + // Only one of this and `matcher` can be specified. repeated Route routes = 3; + // [#next-major-version: This should be included in a oneof with routes wrapped in a message.] + // The match tree to use when resolving route actions for incoming requests. Only one of this and `routes` + // can be specified. + xds.type.matcher.v3.Matcher matcher = 21 + [(xds.annotations.v3.field_status).work_in_progress = true]; + // Specifies the type of TLS enforcement the virtual host expects. If this option is not // specified, there is no TLS requirement for the virtual host. TlsRequirementType require_tls = 4 [(validate.rules).enum = {defined_only: true}]; @@ -186,6 +198,11 @@ message VirtualHost { // If set and a route-specific limit is not set, the bytes actually buffered will be the minimum // value of this and the listener per_connection_buffer_limit_bytes. google.protobuf.UInt32Value per_request_buffer_limit_bytes = 18; + + // Specify a set of default request mirroring policies for every route under this virtual host. + // It takes precedence over the route config mirror policy entirely. + // That is, policies are not merged, the most specific non-empty one becomes the mirror policies. + repeated RouteAction.RequestMirrorPolicy request_mirror_policies = 22; } // A filter-defined action type. @@ -311,7 +328,7 @@ message Route { message WeightedCluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster"; - // [#next-free-field: 12] + // [#next-free-field: 13] message ClusterWeight { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster.ClusterWeight"; @@ -320,9 +337,31 @@ message WeightedCluster { reserved "per_filter_config"; + // Only one of *name* and *cluster_header* may be specified. + // [#next-major-version: Need to add back the validation rule: (validate.rules).string = {min_len: 1}] // Name of the upstream cluster. The cluster must exist in the // :ref:`cluster manager configuration `. - string name = 1 [(validate.rules).string = {min_len: 1}]; + string name = 1 [(udpa.annotations.field_migrate).oneof_promotion = "cluster_specifier"]; + + // Only one of *name* and *cluster_header* may be specified. + // [#next-major-version: Need to add back the validation rule: (validate.rules).string = {min_len: 1 }] + // Envoy will determine the cluster to route to by reading the value of the + // HTTP header named by cluster_header from the request headers. If the + // header is not found or the referenced cluster does not exist, Envoy will + // return a 404 response. + // + // .. attention:: + // + // Internally, Envoy always uses the HTTP/2 *:authority* header to represent the HTTP/1 + // *Host* header. Thus, if attempting to match on *Host*, match on *:authority* instead. + // + // .. note:: + // + // If the header appears multiple times only the first value is used. + string cluster_header = 12 [ + (validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}, + (udpa.annotations.field_migrate).oneof_promotion = "cluster_specifier" + ]; // An integer between 0 and :ref:`total_weight // `. When a request matches the route, @@ -403,9 +442,31 @@ message WeightedCluster { // configuration file will be used as the default weight. See the :ref:`runtime documentation // ` for how key names map to the underlying implementation. string runtime_key_prefix = 2; + + oneof random_value_specifier { + // Specifies the header name that is used to look up the random value passed in the request header. + // This is used to ensure consistent cluster picking across multiple proxy levels for weighted traffic. + // If header is not present or invalid, Envoy will fall back to use the internally generated random value. + // This header is expected to be single-valued header as we only want to have one selected value throughout + // the process for the consistency. And the value is a unsigned number between 0 and UINT64_MAX. + string header_name = 4; + } +} + +// Configuration for a cluster specifier plugin. +message ClusterSpecifierPlugin { + // The name of the plugin and its opaque configuration. + core.v3.TypedExtensionConfig extension = 1 [(validate.rules).message = {required: true}]; + + // If is_optional is not set or is set to false and the plugin defined by this message is not a + // supported type, the containing resource is NACKed. If is_optional is set to true, the resource + // would not be NACKed for this reason. In this case, routes referencing this plugin's name would + // not be treated as an illegal configuration, but would result in a failure if the route is + // selected. + bool is_optional = 2; } -// [#next-free-field: 13] +// [#next-free-field: 15] message RouteMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteMatch"; @@ -470,6 +531,17 @@ message RouteMatch { // Note that CONNECT support is currently considered alpha in Envoy. // [#comment: TODO(htuch): Replace the above comment with an alpha tag.] ConnectMatcher connect_matcher = 12; + + // If specified, the route is a path-separated prefix rule meaning that the + // ``:path`` header (without the query string) must either exactly match the + // ``path_separated_prefix`` or have it as a prefix, followed by ``/`` + // + // For example, ``/api/dev`` would match + // ``/api/dev``, ``/api/dev/``, ``/api/dev/v1``, and ``/api/dev?param=true`` + // but would not match ``/api/developer`` + // + // Expect the value to not contain ``?`` or ``#`` and not to end in ``/`` + string path_separated_prefix = 14 [(validate.rules).string = {pattern: "^[^?#]+[^?#/]$"}]; } // Indicates that prefix/path matching should be case sensitive. The default @@ -506,6 +578,14 @@ message RouteMatch { // against all the specified query parameters. If the number of specified // query parameters is nonzero, they all must match the *path* header's // query string for a match to occur. + // + // .. note:: + // + // If query parameters are used to pass request message fields when + // `grpc_json_transcoder `_ + // is used, the transcoded message fields maybe different. The query parameters are + // url encoded, but the message fields are not. For example, if a query + // parameter is "foo%20bar", the message field will be "foo bar". repeated QueryParameterMatcher query_parameters = 7; // If specified, only gRPC requests will be matched. The router will check @@ -518,6 +598,12 @@ message RouteMatch { // // [#next-major-version: unify with RBAC] TlsContextMatchOptions tls_context = 11; + + // Specifies a set of dynamic metadata matchers on which the route should match. + // The router will check the dynamic metadata against all the specified dynamic metadata matchers. + // If the number of specified dynamic metadata matchers is nonzero, they all must match the + // dynamic metadata for a match to occur. + repeated type.matcher.v3.MetadataMatcher dynamic_metadata = 13; } // [#next-free-field: 12] @@ -570,7 +656,7 @@ message CorsPolicy { core.v3.RuntimeFractionalPercent shadow_enabled = 10; } -// [#next-free-field: 38] +// [#next-free-field: 40] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction"; @@ -602,6 +688,7 @@ message RouteAction { // .. note:: // // Shadowing will not be triggered if the primary cluster does not exist. + // [#next-free-field: 6] message RequestMirrorPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction.RequestMirrorPolicy"; @@ -610,9 +697,30 @@ message RouteAction { reserved "runtime_key"; + // Only one of *cluster* and *cluster_header* can be specified. + // [#next-major-version: Need to add back the validation rule: (validate.rules).string = {min_len: 1}] // Specifies the cluster that requests will be mirrored to. The cluster must // exist in the cluster manager configuration. - string cluster = 1 [(validate.rules).string = {min_len: 1}]; + string cluster = 1 [(udpa.annotations.field_migrate).oneof_promotion = "cluster_specifier"]; + + // Only one of *cluster* and *cluster_header* can be specified. + // Envoy will determine the cluster to route to by reading the value of the + // HTTP header named by cluster_header from the request headers. Only the first value in header is used, + // and no shadow request will happen if the value is not found in headers. Envoy will not wait for + // the shadow cluster to respond before returning the response from the primary cluster. + // + // .. attention:: + // + // Internally, Envoy always uses the HTTP/2 *:authority* header to represent the HTTP/1 + // *Host* header. Thus, if attempting to match on *Host*, match on *:authority* instead. + // + // .. note:: + // + // If the header appears multiple times only the first value is used. + string cluster_header = 5 [ + (validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}, + (udpa.annotations.field_migrate).oneof_promotion = "cluster_specifier" + ]; // If not specified, all requests to the target cluster will be mirrored. // @@ -705,8 +813,8 @@ message RouteAction { "envoy.api.v2.route.RouteAction.HashPolicy.FilterState"; // The name of the Object in the per-request filterState, which is an - // Envoy::Http::Hashable object. If there is no data associated with the key, - // or the stored object is not Envoy::Http::Hashable, no hash will be produced. + // Envoy::Hashable object. If there is no data associated with the key, + // or the stored object is not Envoy::Hashable, no hash will be produced. string key = 1 [(validate.rules).string = {min_len: 1}]; } @@ -847,13 +955,15 @@ message RouteAction { // for additional documentation. WeightedCluster weighted_clusters = 3; - // [#not-implemented-hide:] - // Name of the cluster specifier plugin to use to determine the cluster for - // requests on this route. The plugin name must be defined in the associated - // :ref:`envoy_v3_api_field_config.route.v3.RouteConfiguration.cluster_specifier_plugins` - // in the - // :ref:`envoy_v3_api_field_config.core.v3.TypedExtensionConfig.name` field. + // Name of the cluster specifier plugin to use to determine the cluster for requests on this route. + // The cluster specifier plugin name must be defined in the associated + // :ref:`cluster specifier plugins ` + // in the :ref:`name ` field. string cluster_specifier_plugin = 37; + + // Custom cluster specifier plugin configuration to use to determine the cluster for requests + // on this route. + ClusterSpecifierPlugin inline_cluster_specifier_plugin = 39; } // The HTTP status code to use when configured cluster is not found. @@ -934,20 +1044,29 @@ message RouteAction { oneof host_rewrite_specifier { // Indicates that during forwarding, the host header will be swapped with - // this value. + // this value. Using this option will append the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header if + // :ref:`append_x_forwarded_host ` + // is set. string host_rewrite_literal = 6 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; // Indicates that during forwarding, the host header will be swapped with // the hostname of the upstream host chosen by the cluster manager. This // option is applicable only when the destination cluster for a route is of - // type *strict_dns* or *logical_dns*. Setting this to true with other cluster - // types has no effect. + // type *strict_dns* or *logical_dns*. Setting this to true with other cluster types + // has no effect. Using this option will append the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header if + // :ref:`append_x_forwarded_host ` + // is set. google.protobuf.BoolValue auto_host_rewrite = 7; // Indicates that during forwarding, the host header will be swapped with the content of given // downstream or :ref:`custom ` header. - // If header value is empty, host header is left intact. + // If header value is empty, host header is left intact. Using this option will append the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header if + // :ref:`append_x_forwarded_host ` + // is set. // // .. attention:: // @@ -963,6 +1082,10 @@ message RouteAction { // Indicates that during forwarding, the host header will be swapped with // the result of the regex substitution executed on path value with query and fragment removed. // This is useful for transitioning variable content between path segment and subdomain. + // Using this option will append the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header if + // :ref:`append_x_forwarded_host ` + // is set. // // For example with the following config: // @@ -978,6 +1101,15 @@ message RouteAction { type.matcher.v3.RegexMatchAndSubstitute host_rewrite_path_regex = 35; } + // If set, then a host rewrite action (one of + // :ref:`host_rewrite_literal `, + // :ref:`auto_host_rewrite `, + // :ref:`host_rewrite_header `, or + // :ref:`host_rewrite_path_regex `) + // causes the original value of the host header, if any, to be appended to the + // :ref:`config_http_conn_man_headers_x-forwarded-host` HTTP header. + bool append_x_forwarded_host = 38; + // Specifies the upstream timeout for the route. If not specified, the default is 15s. This // spans between the point at which the entire downstream request (i.e. end-of-stream) has been // processed and when the upstream response has been completely processed. A value of 0 will @@ -1027,7 +1159,9 @@ message RouteAction { // should not be set if this field is used. google.protobuf.Any retry_policy_typed_config = 33; - // Indicates that the route has request mirroring policies. + // Specify a set of route request mirroring policies. + // It takes precedence over the virtual host and route config mirror policy entirely. + // That is, policies are not merged, the most specific non-empty one becomes the mirror policies. repeated RequestMirrorPolicy request_mirror_policies = 30; // Optionally specifies the :ref:`routing priority `. @@ -1135,7 +1269,7 @@ message RouteAction { } // HTTP retry :ref:`architecture overview `. -// [#next-free-field: 12] +// [#next-free-field: 14] message RetryPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RetryPolicy"; @@ -1276,8 +1410,8 @@ message RetryPolicy { google.protobuf.UInt32Value num_retries = 2 [(udpa.annotations.field_migrate).rename = "max_retries"]; - // Specifies a non-zero upstream timeout per retry attempt. This parameter is optional. The - // same conditions documented for + // Specifies a non-zero upstream timeout per retry attempt (including the initial attempt). This + // parameter is optional. The same conditions documented for // :ref:`config_http_filters_router_x-envoy-upstream-rq-per-try-timeout-ms` apply. // // .. note:: @@ -1289,6 +1423,27 @@ message RetryPolicy { // would have been exhausted. google.protobuf.Duration per_try_timeout = 3; + // Specifies an upstream idle timeout per retry attempt (including the initial attempt). This + // parameter is optional and if absent there is no per try idle timeout. The semantics of the per + // try idle timeout are similar to the + // :ref:`route idle timeout ` and + // :ref:`stream idle timeout + // ` + // both enforced by the HTTP connection manager. The difference is that this idle timeout + // is enforced by the router for each individual attempt and thus after all previous filters have + // run, as opposed to *before* all previous filters run for the other idle timeouts. This timeout + // is useful in cases in which total request timeout is bounded by a number of retries and a + // :ref:`per_try_timeout `, but + // there is a desire to ensure each try is making incremental progress. Note also that similar + // to :ref:`per_try_timeout `, + // this idle timeout does not start until after both the entire request has been received by the + // router *and* a connection pool connection has been obtained. Unlike + // :ref:`per_try_timeout `, + // the idle timer continues once the response starts streaming back to the downstream client. + // This ensures that response data continues to make progress without using one of the HTTP + // connection manager idle timeouts. + google.protobuf.Duration per_try_idle_timeout = 13; + // Specifies an implementation of a RetryPriority which is used to determine the // distribution of load across priorities used for retries. Refer to // :ref:`retry plugin configuration ` for more details. @@ -1300,6 +1455,11 @@ message RetryPolicy { // details. repeated RetryHostPredicate retry_host_predicate = 5; + // Retry options predicates that will be applied prior to retrying a request. These predicates + // allow customizing request behavior between retries. + // [#comment: add [#extension-category: envoy.retry_options_predicates] when there are built-in extensions] + repeated core.v3.TypedExtensionConfig retry_options_predicates = 12; + // The maximum number of times host selection will be reattempted before giving up, at which // point the host that was last selected will be routed to. If unspecified, this will default to // retrying once. @@ -1477,7 +1637,7 @@ message DirectResponseAction { "envoy.api.v2.route.DirectResponseAction"; // Specifies the HTTP response status to be returned. - uint32 status = 1 [(validate.rules).uint32 = {lt: 600 gte: 100}]; + uint32 status = 1 [(validate.rules).uint32 = {lt: 600 gte: 200}]; // Specifies the content of the response body. If this setting is omitted, // no body is included in the generated response. @@ -1588,7 +1748,7 @@ message VirtualCluster { message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; - // [#next-free-field: 10] + // [#next-free-field: 11] message Action { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action"; @@ -1662,6 +1822,28 @@ message RateLimit { "envoy.api.v2.route.RateLimit.Action.RemoteAddress"; } + // The following descriptor entry is appended to the descriptor and is populated using the + // masked address from :ref:`x-forwarded-for `: + // + // .. code-block:: cpp + // + // ("masked_remote_address", "") + message MaskedRemoteAddress { + // Length of prefix mask len for IPv4 (e.g. 0, 32). + // Defaults to 32 when unset. + // For example, trusted address from x-forwarded-for is `192.168.1.1`, + // the descriptor entry is ("masked_remote_address", "192.168.1.1/32"); + // if mask len is 24, the descriptor entry is ("masked_remote_address", "192.168.1.0/24"). + google.protobuf.UInt32Value v4_prefix_mask_len = 1 [(validate.rules).uint32 = {lte: 32}]; + + // Length of prefix mask len for IPv6 (e.g. 0, 128). + // Defaults to 128 when unset. + // For example, trusted address from x-forwarded-for is `2001:abcd:ef01:2345:6789:abcd:ef01:234`, + // the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345:6789:abcd:ef01:234/128"); + // if mask len is 64, the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345::/64"). + google.protobuf.UInt32Value v6_prefix_mask_len = 2 [(validate.rules).uint32 = {lte: 128}]; + } + // The following descriptor entry is appended to the descriptor: // // .. code-block:: cpp @@ -1688,6 +1870,9 @@ message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action.HeaderValueMatch"; + // The key to use in the descriptor entry. Defaults to `header_match`. + string descriptor_key = 4; + // The value to use in the descriptor entry. string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; @@ -1791,8 +1976,17 @@ message RateLimit { MetaData metadata = 8; // Rate limit descriptor extension. See the rate limit descriptor extensions documentation. + // + // :ref:`HTTP matching input functions ` are + // permitted as descriptor extensions. The input functions are only + // looked up if there is no rate limit descriptor extension matching + // the type URL. + // // [#extension-category: envoy.rate_limit_descriptors] core.v3.TypedExtensionConfig extension = 9; + + // Rate limit on masked remote address. + MaskedRemoteAddress masked_remote_address = 10; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/scoped_route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/scoped_route.proto index eb47d7e1089..4ac0ca7c23d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/scoped_route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/scoped_route.proto @@ -2,6 +2,9 @@ syntax = "proto3"; package envoy.config.route.v3; +import "envoy/config/route/v3/route.proto"; + +import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -9,6 +12,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.route.v3"; option java_outer_classname = "ScopedRouteProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/route/v3;routev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP scoped routing configuration] @@ -16,7 +20,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Specifies a routing scope, which associates a // :ref:`Key` to a -// :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` (identified by its resource name). +// :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration`. +// The :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` can be obtained dynamically +// via RDS (:ref:`route_configuration_name`) +// or specified inline (:ref:`route_configuration`). // // The HTTP connection manager builds up a table consisting of these Key to // RouteConfiguration mappings, and looks up the RouteConfiguration to use per @@ -73,6 +80,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // would result in the routing table defined by the `route-config1` // RouteConfiguration being assigned to the HTTP request/stream. // +// [#next-free-field: 6] message ScopedRouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.ScopedRouteConfiguration"; @@ -113,7 +121,12 @@ message ScopedRouteConfiguration { // The resource name to use for a :ref:`envoy_v3_api_msg_service.discovery.v3.DiscoveryRequest` to an // RDS server to fetch the :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` associated // with this scope. - string route_configuration_name = 2 [(validate.rules).string = {min_len: 1}]; + string route_configuration_name = 2 + [(udpa.annotations.field_migrate).oneof_promotion = "route_config"]; + + // The :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` associated with the scope. + RouteConfiguration route_configuration = 5 + [(udpa.annotations.field_migrate).oneof_promotion = "route_config"]; // The key to match against. Key key = 3 [(validate.rules).message = {required: true}]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/datadog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/datadog.proto index 0992601a8ac..3034eecaf55 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/datadog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/datadog.proto @@ -8,6 +8,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "DatadogProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Datadog tracer] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/dynamic_ot.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/dynamic_ot.proto index 55c6d401b33..928b096bb0f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/dynamic_ot.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/dynamic_ot.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "DynamicOtProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Dynamically loadable OpenTracing tracer] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/http_tracer.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/http_tracer.proto index fba830b987b..778b9e718a7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/http_tracer.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/http_tracer.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "HttpTracerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Tracing] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/lightstep.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/lightstep.proto index 849749baaa0..db866c82557 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/lightstep.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/lightstep.proto @@ -8,6 +8,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "LightstepProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: LightStep tracer] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/opencensus.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/opencensus.proto index 1a9a879b21e..595f4fe2783 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/opencensus.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/opencensus.proto @@ -11,6 +11,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "OpencensusProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: OpenCensus tracer] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/service.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/service.proto index d102499b626..85477cccbf2 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/service.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "ServiceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Trace Service] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/trace.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/trace.proto index 6ed394147db..02d6fa28bd9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/trace.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/trace.proto @@ -13,3 +13,4 @@ import public "envoy/config/trace/v2/zipkin.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "TraceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/zipkin.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/zipkin.proto index a825d85bb7f..d052c7176b3 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/zipkin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v2/zipkin.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v2"; option java_outer_classname = "ZipkinProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v2;tracev2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Zipkin tracer] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto index c101ab2f03c..1a01f6a33c8 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "DatadogProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.tracers.datadog.v4alpha"; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto index c2810687154..954c4a422ab 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto @@ -12,6 +12,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "DynamicOtProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.tracers.dynamic_ot.v4alpha"; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/http_tracer.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/http_tracer.proto index d3c59a8cbb0..8bd5151f4b1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/http_tracer.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/http_tracer.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "HttpTracerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Tracing] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/lightstep.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/lightstep.proto index b5cff53fea9..0e2680832f0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/lightstep.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/lightstep.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "LightstepProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.tracers.lightstep.v4alpha"; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto index ee2241e729a..9b2d2361a49 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto @@ -14,6 +14,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "OpencensusProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.tracers.opencensus.v4alpha"; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/service.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/service.proto index 1e01ff61847..4cb8c44c424 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/service.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "ServiceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Trace Service] diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto index 472e38b5abb..5e5895e26bb 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto @@ -13,3 +13,4 @@ import public "envoy/config/trace/v3/zipkin.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "TraceProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto index 2c1026b8304..1d76b813768 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "ZipkinProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; option (udpa.annotations.file_migrate).move_to_package = "envoy.extensions.tracers.zipkin.v4alpha"; option (udpa.annotations.file_status).package_version_status = ACTIVE; @@ -50,8 +51,7 @@ message ZipkinConfig { string collector_cluster = 1 [(validate.rules).string = {min_len: 1}]; // The API endpoint of the Zipkin service where the spans will be sent. When - // using a standard Zipkin installation, the API endpoint is typically - // /api/v1/spans, which is the default value. + // using a standard Zipkin installation. string collector_endpoint = 2 [(validate.rules).string = {min_len: 1}]; // Determines whether a 128bit trace id will be used when creating a new @@ -62,8 +62,7 @@ message ZipkinConfig { // The default value is true. google.protobuf.BoolValue shared_span_context = 4; - // Determines the selected collector endpoint version. By default, the ``HTTP_JSON_V1`` will be - // used. + // Determines the selected collector endpoint version. CollectorEndpointVersion collector_endpoint_version = 5; // Optional hostname to use when sending spans to the collector_cluster. Useful for collectors diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto index aead1c45173..4f44ac9cd5c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.clusters.aggregate.v3"; option java_outer_classname = "ClusterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/clusters/aggregate/v3;aggregatev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Aggregate cluster configuration] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto index 62da059e264..ab24f5d2374 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.filters.common.fault.v3"; option java_outer_classname = "FaultProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3;faultv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common fault injection types] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/fault/v3/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/fault/v3/fault.proto index 0c7fbb4480c..64dbf89e435 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/fault/v3/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/fault/v3/fault.proto @@ -15,6 +15,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.filters.http.fault.v3"; option java_outer_classname = "FaultProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3;faultv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Fault Injection] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto index 7ad7ac5e6aa..008818456e2 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto @@ -10,6 +10,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.extensions.filters.http.rbac.v3"; option java_outer_classname = "RbacProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rbac/v3;rbacv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: RBAC] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto index ce595c057c0..7ce8b37dbb7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto @@ -13,6 +13,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.filters.http.router.v3"; option java_outer_classname = "RouterProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3;routerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Router] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto index 3fb4bfa09e2..d7e8e799d30 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.filters.network.http_connection_manager.v3; import "envoy/config/accesslog/v3/accesslog.proto"; +import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/config_source.proto"; import "envoy/config/core/v3/extension.proto"; @@ -28,13 +29,14 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3"; option java_outer_classname = "HttpConnectionManagerProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3;http_connection_managerv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP connection manager] // HTTP connection manager :ref:`configuration overview `. // [#extension: envoy.filters.network.http_connection_manager] -// [#next-free-field: 49] +// [#next-free-field: 50] message HttpConnectionManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager"; @@ -201,6 +203,10 @@ message HttpConnectionManager { // Whether unix socket addresses should be considered internal. bool unix_sockets = 1; + + // List of CIDR ranges that are treated as internal. If unset, then RFC1918 / RFC4193 + // IP addresses will be considered internal. + repeated config.core.v3.CidrRange cidr_ranges = 2; } // [#next-free-field: 7] @@ -301,6 +307,54 @@ message HttpConnectionManager { type.http.v3.PathTransformation http_filter_transformation = 2; } + // Configures the manner in which the Proxy-Status HTTP response header is + // populated. + // + // See the [Proxy-Status + // RFC](https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-proxy-status-08). + // [#comment:TODO: Update this with the non-draft URL when finalized.] + // + // The Proxy-Status header is a string of the form: + // + // "; error=; details=
    " + // [#next-free-field: 7] + message ProxyStatusConfig { + // If true, the details field of the Proxy-Status header is not populated with stream_info.response_code_details. + // This value defaults to `false`, i.e. the `details` field is populated by default. + bool remove_details = 1; + + // If true, the details field of the Proxy-Status header will not contain + // connection termination details. This value defaults to `false`, i.e. the + // `details` field will contain connection termination details by default. + bool remove_connection_termination_details = 2; + + // If true, the details field of the Proxy-Status header will not contain an + // enumeration of the Envoy ResponseFlags. This value defaults to `false`, + // i.e. the `details` field will contain a list of ResponseFlags by default. + bool remove_response_flags = 3; + + // If true, overwrites the existing Status header with the response code + // recommended by the Proxy-Status spec. + // This value defaults to `false`, i.e. the HTTP response code is not + // overwritten. + bool set_recommended_response_code = 4; + + // The name of the proxy as it appears at the start of the Proxy-Status + // header. + // + // If neither of these values are set, this value defaults to `server_name`, + // which itself defaults to "envoy". + oneof proxy_name { + // If `use_node_id` is set, Proxy-Status headers will use the Envoy's node + // ID as the name of the proxy. + bool use_node_id = 5; + + // If `literal_proxy_name` is set, Proxy-Status headers will use this + // value as the name of the proxy. + string literal_proxy_name = 6; + } + } + reserved 27, 11; reserved "idle_timeout"; @@ -706,6 +760,11 @@ message HttpConnectionManager { // setting this option will strip a trailing dot, if present, from the host section, // leaving the port as is (e.g. host value `example.com.:443` will be updated to `example.com:443`). bool strip_trailing_host_dot = 47; + + // Proxy-Status HTTP response header configuration. + // If this config is set, the Proxy-Status HTTP response header field is + // populated. By default, it is not. + ProxyStatusConfig proxy_status_config = 49; } // The configuration to customize local reply returned by Envoy. @@ -911,7 +970,7 @@ message ScopedRoutes { // Configuration source specifier for RDS. // This config source is used to subscribe to RouteConfiguration resources specified in // ScopedRouteConfiguration messages. - config.core.v3.ConfigSource rds_config_source = 3 [(validate.rules).message = {required: true}]; + config.core.v3.ConfigSource rds_config_source = 3; oneof config_specifier { option (validate.required) = true; @@ -954,9 +1013,7 @@ message HttpFilter { reserved "config"; - // The name of the filter configuration. The name is used as a fallback to - // select an extension if the type of the configuration proto is not - // sufficient. It also serves as a resource name in ExtensionConfigDS. + // The name of the filter configuration. It also serves as a resource name in ExtensionConfigDS. string name = 1 [(validate.rules).string = {min_len: 1}]; oneof config_type { diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto new file mode 100644 index 00000000000..97efd918325 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package envoy.extensions.load_balancing_policies.least_request.v3; + +import "envoy/config/cluster/v3/cluster.proto"; +import "envoy/config/core/v3/base.proto"; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3"; +option java_outer_classname = "LeastRequestProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/load_balancing_policies/least_request/v3;least_requestv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Least Request Load Balancing Policy] + +// This configuration allows the built-in LEAST_REQUEST LB policy to be configured via the LB policy +// extension point. See the :ref:`load balancing architecture overview +// ` for more information. +// [#extension: envoy.clusters.lb_policy] +message LeastRequest { + // The number of random healthy hosts from which the host with the fewest active requests will + // be chosen. Defaults to 2 so that we perform two-choice selection if the field is not set. + google.protobuf.UInt32Value choice_count = 1 [(validate.rules).uint32 = {gte: 2}]; + + // The following formula is used to calculate the dynamic weights when hosts have different load + // balancing weights: + // + // `weight = load_balancing_weight / (active_requests + 1)^active_request_bias` + // + // The larger the active request bias is, the more aggressively active requests will lower the + // effective weight when all host weights are not equal. + // + // `active_request_bias` must be greater than or equal to 0.0. + // + // When `active_request_bias == 0.0` the Least Request Load Balancer doesn't consider the number + // of active requests at the time it picks a host and behaves like the Round Robin Load + // Balancer. + // + // When `active_request_bias > 0.0` the Least Request Load Balancer scales the load balancing + // weight by the number of active requests at the time it does a pick. + // + // The value is cached for performance reasons and refreshed whenever one of the Load Balancer's + // host sets changes, e.g., whenever there is a host membership update or a host load balancing + // weight change. + // + // .. note:: + // This setting only takes effect if all host weights are not equal. + config.core.v3.RuntimeDouble active_request_bias = 2; + + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + config.cluster.v3.Cluster.SlowStartConfig slow_start_config = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto new file mode 100644 index 00000000000..9408734becb --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package envoy.extensions.load_balancing_policies.ring_hash.v3; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3"; +option java_outer_classname = "RingHashProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/load_balancing_policies/ring_hash/v3;ring_hashv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Ring Hash Load Balancing Policy] + +// This configuration allows the built-in RING_HASH LB policy to be configured via the LB policy +// extension point. See the :ref:`load balancing architecture overview +// ` for more information. +// [#extension: envoy.clusters.lb_policy] +// [#next-free-field: 6] +message RingHash { + // The hash function used to hash hosts onto the ketama ring. + enum HashFunction { + // Currently defaults to XX_HASH. + DEFAULT_HASH = 0; + + // Use `xxHash `_. + XX_HASH = 1; + + // Use `MurmurHash2 `_, this is compatible with + // std:hash in GNU libstdc++ 3.4.20 or above. This is typically the case when compiled + // on Linux and not macOS. + MURMUR_HASH_2 = 2; + } + + // The hash function used to hash hosts onto the ketama ring. The value defaults to + // :ref:`XX_HASH`. + HashFunction hash_function = 1 [(validate.rules).enum = {defined_only: true}]; + + // Minimum hash ring size. The larger the ring is (that is, the more hashes there are for each + // provided host) the better the request distribution will reflect the desired weights. Defaults + // to 1024 entries, and limited to 8M entries. See also + // :ref:`maximum_ring_size`. + google.protobuf.UInt64Value minimum_ring_size = 2 [(validate.rules).uint64 = {lte: 8388608}]; + + // Maximum hash ring size. Defaults to 8M entries, and limited to 8M entries, but can be lowered + // to further constrain resource use. See also + // :ref:`minimum_ring_size`. + google.protobuf.UInt64Value maximum_ring_size = 3 [(validate.rules).uint64 = {lte: 8388608}]; + + // If set to `true`, the cluster will use hostname instead of the resolved + // address as the key to consistently hash to an upstream host. Only valid for StrictDNS clusters with hostnames which resolve to a single IP address. + bool use_hostname_for_hashing = 4; + + // Configures percentage of average cluster load to bound per upstream host. For example, with a value of 150 + // no upstream host will get a load more than 1.5 times the average load of all the hosts in the cluster. + // If not specified, the load is not bounded for any upstream host. Typical value for this parameter is between 120 and 200. + // Minimum is 100. + // + // This is implemented based on the method described in the paper https://arxiv.org/abs/1608.01350. For the specified + // `hash_balance_factor`, requests to any upstream host are capped at `hash_balance_factor/100` times the average number of requests + // across the cluster. When a request arrives for an upstream host that is currently serving at its max capacity, linear probing + // is used to identify an eligible host. Further, the linear probe is implemented using a random jump in hosts ring/table to identify + // the eligible host (this technique is as described in the paper https://arxiv.org/abs/1908.08762 - the random jump avoids the + // cascading overflow effect when choosing the next host in the ring/table). + // + // If weights are specified on the hosts, they are respected. + // + // This is an O(N) algorithm, unlike other load balancers. Using a lower `hash_balance_factor` results in more hosts + // being probed, so use a higher value if you require better performance. + google.protobuf.UInt32Value hash_balance_factor = 5 [(validate.rules).uint32 = {gte: 100}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto new file mode 100644 index 00000000000..4875b632f9c --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package envoy.extensions.load_balancing_policies.round_robin.v3; + +import "envoy/config/cluster/v3/cluster.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3"; +option java_outer_classname = "RoundRobinProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/load_balancing_policies/round_robin/v3;round_robinv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Round Robin Load Balancing Policy] + +// This configuration allows the built-in ROUND_ROBIN LB policy to be configured via the LB policy +// extension point. See the :ref:`load balancing architecture overview +// ` for more information. +// [#extension: envoy.clusters.lb_policy] +message RoundRobin { + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + config.cluster.v3.Cluster.SlowStartConfig slow_start_config = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto new file mode 100644 index 00000000000..cad403dd35e --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package envoy.extensions.load_balancing_policies.wrr_locality.v3; + +import "envoy/config/cluster/v3/cluster.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3"; +option java_outer_classname = "WrrLocalityProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/load_balancing_policies/wrr_locality/v3;wrr_localityv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Weighted Round Robin Locality-Picking Load Balancing Policy] + +// Configuration for the wrr_locality LB policy. See the :ref:`load balancing architecture overview +// ` for more information. +// [#extension: envoy.clusters.lb_policy] +message WrrLocality { + // The child LB policy to create for endpoint-picking within the chosen locality. + config.cluster.v3.LoadBalancingPolicy endpoint_picking_policy = 1 + [(validate.rules).message = {required: true}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/cert.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/cert.proto index b451d45381c..8a5f8962bd1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/cert.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/cert.proto @@ -9,3 +9,4 @@ import public "envoy/extensions/transport_sockets/tls/v3/tls.proto"; option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.tls.v3"; option java_outer_classname = "CertProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3;tlsv3"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto index 82dcb37cd7c..d38d4edf911 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto @@ -9,6 +9,7 @@ import "envoy/type/matcher/v3/string.proto"; import "google/protobuf/any.proto"; import "google/protobuf/wrappers.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/sensitive.proto"; import "udpa/annotations/status.proto"; @@ -18,6 +19,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.tls.v3"; option java_outer_classname = "CommonProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3;tlsv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common TLS configuration] @@ -42,8 +44,7 @@ message TlsParameters { TLSv1_3 = 4; } - // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for clients and ``TLSv1_0`` for - // servers. + // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for both clients and servers. TlsProtocol tls_minimum_protocol_version = 1 [(validate.rules).enum = {defined_only: true}]; // Maximum TLS protocol version. By default, it's ``TLSv1_2`` for clients and ``TLSv1_3`` for @@ -56,6 +57,8 @@ message TlsParameters { // // If not specified, a default list will be used. Defaults are different for server (downstream) and // client (upstream) TLS configurations. + // Defaults will change over time in response to security considerations; If you care, configure + // it instead of using the default. // // In non-FIPS builds, the default server cipher list is: // @@ -63,16 +66,8 @@ message TlsParameters { // // [ECDHE-ECDSA-AES128-GCM-SHA256|ECDHE-ECDSA-CHACHA20-POLY1305] // [ECDHE-RSA-AES128-GCM-SHA256|ECDHE-RSA-CHACHA20-POLY1305] - // ECDHE-ECDSA-AES128-SHA - // ECDHE-RSA-AES128-SHA - // AES128-GCM-SHA256 - // AES128-SHA // ECDHE-ECDSA-AES256-GCM-SHA384 // ECDHE-RSA-AES256-GCM-SHA384 - // ECDHE-ECDSA-AES256-SHA - // ECDHE-RSA-AES256-SHA - // AES256-GCM-SHA384 - // AES256-SHA // // In builds using :ref:`BoringSSL FIPS `, the default server cipher list is: // @@ -80,16 +75,8 @@ message TlsParameters { // // ECDHE-ECDSA-AES128-GCM-SHA256 // ECDHE-RSA-AES128-GCM-SHA256 - // ECDHE-ECDSA-AES128-SHA - // ECDHE-RSA-AES128-SHA - // AES128-GCM-SHA256 - // AES128-SHA // ECDHE-ECDSA-AES256-GCM-SHA384 // ECDHE-RSA-AES256-GCM-SHA384 - // ECDHE-ECDSA-AES256-SHA - // ECDHE-RSA-AES256-SHA - // AES256-GCM-SHA384 - // AES256-SHA // // In non-FIPS builds, the default client cipher list is: // @@ -149,7 +136,7 @@ message PrivateKeyProvider { } } -// [#next-free-field: 8] +// [#next-free-field: 9] message TlsCertificate { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.TlsCertificate"; @@ -168,6 +155,21 @@ message TlsCertificate { // applies to dynamic secrets, when the *TlsCertificate* is delivered via SDS. config.core.v3.DataSource private_key = 2 [(udpa.annotations.sensitive) = true]; + // `Pkcs12` data containing TLS certificate, chain, and private key. + // + // If *pkcs12* is a filesystem path, the file will be read, but no watch will + // be added to the parent directory, since *pkcs12* isn't used by SDS. + // This field is mutually exclusive with *certificate_chain*, *private_key* and *private_key_provider*. + // This can't be marked as ``oneof`` due to API compatibility reasons. Setting + // both :ref:`private_key `, + // :ref:`certificate_chain `, + // or :ref:`private_key_provider ` + // and :ref:`pkcs12 ` + // fields will result in an error. Use :ref:`password + // ` + // to specify the password to unprotect the `PKCS12` data, if necessary. + config.core.v3.DataSource pkcs12 = 8 [(udpa.annotations.sensitive) = true]; + // If specified, updates of file-based *certificate_chain* and *private_key* // sources will be triggered by this watch. The certificate/key pair will be // read together and validated for atomic read consistency (i.e. no @@ -253,7 +255,26 @@ message CertificateProviderPluginInstance { string certificate_name = 2; } -// [#next-free-field: 14] +// Matcher for subject alternative names, to match both type and value of the SAN. +message SubjectAltNameMatcher { + // Indicates the choice of GeneralName as defined in section 4.2.1.5 of RFC 5280 to match + // against. + enum SanType { + SAN_TYPE_UNSPECIFIED = 0; + EMAIL = 1; + DNS = 2; + URI = 3; + IP_ADDRESS = 4; + } + + // Specification of type of SAN. Note that the default enum value is an invalid choice. + SanType san_type = 1 [(validate.rules).enum = {defined_only: true not_in: 0}]; + + // Matcher for SAN value. + type.matcher.v3.StringMatcher matcher = 2 [(validate.rules).message = {required: true}]; +} + +// [#next-free-field: 17] message CertificateValidationContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CertificateValidationContext"; @@ -283,8 +304,8 @@ message CertificateValidationContext { // `, // :ref:`verify_certificate_hash // `, or - // :ref:`match_subject_alt_names - // `) is also + // :ref:`match_typed_subject_alt_names + // `) is also // specified. // // It can optionally contain certificate revocation lists, in which case Envoy will verify @@ -292,6 +313,9 @@ message CertificateValidationContext { // that if a CRL is provided for any certificate authority in a trust chain, a CRL must be // provided for all certificate authorities in that chain. Failure to do so will result in // verification failure for both revoked and unrevoked certificates from that chain. + // The behavior of requiring all certificates to contain CRLs if any do can be altered by + // setting :ref:`only_verify_leaf_cert_crl ` + // true. If set to true, only the final certificate in the chain undergoes CRL verification. // // See :ref:`the TLS overview ` for a list of common // system CA locations. @@ -388,6 +412,8 @@ message CertificateValidationContext { // An optional list of Subject Alternative name matchers. If specified, Envoy will verify that the // Subject Alternative Name of the presented certificate matches one of the specified matchers. + // The matching uses "any" semantics, that is to say, the SAN is verified if at least one matcher is + // matched. // // When a certificate has wildcard DNS SAN entries, to match a specific client, it should be // configured with exact match type in the :ref:`string matcher `. @@ -396,15 +422,26 @@ message CertificateValidationContext { // // .. code-block:: yaml // - // match_subject_alt_names: - // exact: "api.example.com" + // match_typed_subject_alt_names: + // - san_type: DNS + // matcher: + // exact: "api.example.com" // // .. attention:: // // Subject Alternative Names are easily spoofable and verifying only them is insecure, // therefore this option must be used together with :ref:`trusted_ca // `. - repeated type.matcher.v3.StringMatcher match_subject_alt_names = 9; + repeated SubjectAltNameMatcher match_typed_subject_alt_names = 15; + + // This field is deprecated in favor of + // :ref:`match_typed_subject_alt_names + // `. + // Note that if both this field and :ref:`match_typed_subject_alt_names + // ` + // are specified, the former (deprecated field) is ignored. + repeated type.matcher.v3.StringMatcher match_subject_alt_names = 9 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // [#not-implemented-hide:] Must present signed certificate time-stamp. google.protobuf.BoolValue require_signed_certificate_timestamp = 6; @@ -417,7 +454,9 @@ message CertificateValidationContext { // for any certificate authority in a trust chain, a CRL must be provided // for all certificate authorities in that chain. Failure to do so will // result in verification failure for both revoked and unrevoked certificates - // from that chain. + // from that chain. This default behavior can be altered by setting + // :ref:`only_verify_leaf_cert_crl ` to + // true. config.core.v3.DataSource crl = 7; // If specified, Envoy will not reject expired certificates. @@ -433,4 +472,15 @@ message CertificateValidationContext { // Refer to the documentation for the specified validator. If you do not want a custom validation algorithm, do not set this field. // [#extension-category: envoy.tls.cert_validator] config.core.v3.TypedExtensionConfig custom_validator_config = 12; + + // If this option is set to true, only the certificate at the end of the + // certificate chain will be subject to validation by :ref:`CRL `. + bool only_verify_leaf_cert_crl = 14; + + // Config for the max number of intermediate certificates in chain that are parsed during verification. + // This does not include the leaf certificate. If configured, and the certificate chain is longer than allowed, the certificates + // above the limit are ignored, and certificate validation will fail. The default limit is 100, + // though this can be system-dependent. + // https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_verify_depth.html + google.protobuf.UInt32Value max_verify_depth = 16 [(validate.rules).uint32 = {lte: 100}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto index f7c849c0334..83ad364c4bf 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto @@ -14,6 +14,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.tls.v3"; option java_outer_classname = "SecretProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3;tlsv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Secrets configuration] diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto index f680207955a..03cf5be8e64 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package envoy.extensions.transport_sockets.tls.v3; +import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/extensions/transport_sockets/tls/v3/common.proto"; import "envoy/extensions/transport_sockets/tls/v3/secret.proto"; @@ -17,6 +18,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.tls.v3"; option java_outer_classname = "TlsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3;tlsv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: TLS transport socket] @@ -109,10 +111,9 @@ message DownstreamTlsContext { bool disable_stateless_session_resumption = 7; } - // If specified, session_timeout will change maximum lifetime (in seconds) of TLS session - // Currently this value is used as a hint to `TLS session ticket lifetime (for TLSv1.2) - // ` - // only seconds could be specified (fractional seconds are going to be ignored). + // If specified, ``session_timeout`` will change the maximum lifetime (in seconds) of the TLS session. + // Currently this value is used as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. + // Only seconds can be specified (fractional seconds are ignored). google.protobuf.Duration session_timeout = 6 [(validate.rules).duration = { lt {seconds: 4294967296} gte {} @@ -124,8 +125,23 @@ message DownstreamTlsContext { OcspStaplePolicy ocsp_staple_policy = 8 [(validate.rules).enum = {defined_only: true}]; } +// TLS key log configuration. +// The key log file format is "format used by NSS for its SSLKEYLOGFILE debugging output" (text taken from openssl man page) +message TlsKeyLog { + // The path to save the TLS key log. + string path = 1 [(validate.rules).string = {min_len: 1}]; + + // The local IP address that will be used to filter the connection which should save the TLS key log + // If it is not set, any local IP address will be matched. + repeated config.core.v3.CidrRange local_address_range = 2; + + // The remote IP address that will be used to filter the connection which should save the TLS key log + // If it is not set, any remote IP address will be matched. + repeated config.core.v3.CidrRange remote_address_range = 3; +} + // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 15] +// [#next-free-field: 16] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; @@ -299,4 +315,7 @@ message CommonTlsContext { // Custom TLS handshaker. If empty, defaults to native TLS handshaking // behavior. config.core.v3.TypedExtensionConfig custom_handshaker = 13; + + // TLS key log configuration + TlsKeyLog key_log = 15; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/ads.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/ads.proto index d70e0cdc8e1..1da1606bf64 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/ads.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/ads.proto @@ -9,17 +9,18 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.service.discovery.v2"; option java_outer_classname = "AdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2;discoveryv2"; option java_generic_services = true; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Aggregated Discovery Service (ADS)] -// [#not-implemented-hide:] Discovery services for endpoints, clusters, routes, +// Discovery services for endpoints, clusters, routes, // and listeners are retained in the package `envoy.api.v2` for backwards // compatibility with existing management servers. New development in discovery // services should proceed in the package `envoy.service.discovery.v2`. -// See https://github.com/lyft/envoy-api#apis for a description of the role of +// See https://github.com/envoyproxy/envoy-api#apis for a description of the role of // ADS and how it is intended to be used by a management server. ADS requests // have the same structure as their singleton xDS counterparts, but can // multiplex many resource types on a single stream. The type_url in the diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/sds.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/sds.proto index 4d01d475c59..d7a30dad40f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/sds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v2/sds.proto @@ -13,6 +13,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.service.discovery.v2"; option java_outer_classname = "SdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2;discoveryv2"; option java_generic_services = true; option (udpa.annotations.file_migrate).move_to_package = "envoy.service.secret.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/ads.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/ads.proto index 03021559ab6..2a07622714c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/ads.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/ads.proto @@ -10,17 +10,18 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.service.discovery.v3"; option java_outer_classname = "AdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3;discoveryv3"; option java_generic_services = true; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Aggregated Discovery Service (ADS)] -// [#not-implemented-hide:] Discovery services for endpoints, clusters, routes, +// Discovery services for endpoints, clusters, routes, // and listeners are retained in the package `envoy.api.v2` for backwards // compatibility with existing management servers. New development in discovery // services should proceed in the package `envoy.service.discovery.v2`. -// See https://github.com/lyft/envoy-api#apis for a description of the role of +// See https://github.com/envoyproxy/envoy-api#apis for a description of the role of // ADS and how it is intended to be used by a management server. ADS requests // have the same structure as their singleton xDS counterparts, but can // multiplex many resource types on a single stream. The type_url in the diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto index 4a474d0fe26..ab269b08769 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto @@ -10,17 +10,40 @@ import "google/rpc/status.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; +import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.service.discovery.v3"; option java_outer_classname = "DiscoveryProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3;discoveryv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common discovery API components] +// Specifies a resource to be subscribed to. +message ResourceLocator { + // The resource name to subscribe to. + string name = 1; + + // A set of dynamic parameters used to match against the dynamic parameter + // constraints on the resource. This allows clients to select between + // multiple variants of the same resource. + map dynamic_parameters = 2; +} + +// Specifies a concrete resource name. +message ResourceName { + // The name of the resource. + string name = 1; + + // Dynamic parameter constraints associated with this resource. To be used by client-side caches + // (including xDS proxies) when matching subscribed resource locators. + DynamicParameterConstraints dynamic_parameter_constraints = 2; +} + // A DiscoveryRequest requests a set of versioned resources of the same type for // a given Envoy node on some API. -// [#next-free-field: 7] +// [#next-free-field: 8] message DiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DiscoveryRequest"; @@ -44,6 +67,15 @@ message DiscoveryRequest { // which will be explicitly enumerated in resource_names. repeated string resource_names = 3; + // [#not-implemented-hide:] + // Alternative to *resource_names* field that allows specifying dynamic + // parameters along with each resource name. Clients that populate this + // field must be able to handle responses from the server where resources + // are wrapped in a Resource message. + // Note that it is legal for a request to have some resources listed + // in *resource_names* and others in *resource_locators*. + repeated ResourceLocator resource_locators = 7; + // Type of the resource that is being requested, e.g. // "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment". This is implicit // in requests made via singleton xDS APIs such as CDS, LDS, etc. but is @@ -140,7 +172,7 @@ message DiscoveryResponse { // In particular, initial_resource_versions being sent at the "start" of every // gRPC stream actually entails a message for each type_url, each with its own // initial_resource_versions. -// [#next-free-field: 8] +// [#next-free-field: 10] message DeltaDiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryRequest"; @@ -179,6 +211,20 @@ message DeltaDiscoveryRequest { // A list of Resource names to remove from the list of tracked resources. repeated string resource_names_unsubscribe = 4; + // [#not-implemented-hide:] + // Alternative to *resource_names_subscribe* field that allows specifying dynamic parameters + // along with each resource name. + // Note that it is legal for a request to have some resources listed + // in *resource_names_subscribe* and others in *resource_locators_subscribe*. + repeated ResourceLocator resource_locators_subscribe = 8; + + // [#not-implemented-hide:] + // Alternative to *resource_names_unsubscribe* field that allows specifying dynamic parameters + // along with each resource name. + // Note that it is legal for a request to have some resources listed + // in *resource_names_unsubscribe* and others in *resource_locators_unsubscribe*. + repeated ResourceLocator resource_locators_unsubscribe = 9; + // Informs the server of the versions of the resources the xDS client knows of, to enable the // client to continue the same logical xDS session even in the face of gRPC stream reconnection. // It will not be populated: [1] in the very first stream of a session, since the client will @@ -201,7 +247,7 @@ message DeltaDiscoveryRequest { google.rpc.Status error_detail = 7; } -// [#next-free-field: 8] +// [#next-free-field: 9] message DeltaDiscoveryResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryResponse"; @@ -223,6 +269,11 @@ message DeltaDiscoveryResponse { // Removed resources for missing resources can be ignored. repeated string removed_resources = 6; + // Alternative to removed_resources that allows specifying which variant of + // a resource is being removed. This variant must be used for any resource + // for which dynamic parameter constraints were sent to the client. + repeated ResourceName removed_resource_names = 8; + // The nonce provides a way for DeltaDiscoveryRequests to uniquely // reference a DeltaDiscoveryResponse when (N)ACKing. The nonce is required. string nonce = 5; @@ -232,7 +283,56 @@ message DeltaDiscoveryResponse { config.core.v3.ControlPlane control_plane = 7; } -// [#next-free-field: 8] +// A set of dynamic parameter constraints associated with a variant of an individual xDS resource. +// These constraints determine whether the resource matches a subscription based on the set of +// dynamic parameters in the subscription, as specified in the +// :ref:`ResourceLocator.dynamic_parameters` +// field. This allows xDS implementations (clients, servers, and caching proxies) to determine +// which variant of a resource is appropriate for a given client. +message DynamicParameterConstraints { + // A single constraint for a given key. + message SingleConstraint { + message Exists { + } + + // The key to match against. + string key = 1; + + oneof constraint_type { + option (validate.required) = true; + + // Matches this exact value. + string value = 2; + + // Key is present (matches any value except for the key being absent). + // This allows setting a default constraint for clients that do + // not send a key at all, while there may be other clients that need + // special configuration based on that key. + Exists exists = 3; + } + } + + message ConstraintList { + repeated DynamicParameterConstraints constraints = 1; + } + + oneof type { + // A single constraint to evaluate. + SingleConstraint constraint = 1; + + // A list of constraints that match if any one constraint in the list + // matches. + ConstraintList or_constraints = 2; + + // A list of constraints that must all match. + ConstraintList and_constraints = 3; + + // The inverse (NOT) of a set of constraints. + DynamicParameterConstraints not_constraints = 4; + } +} + +// [#next-free-field: 9] message Resource { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Resource"; @@ -246,8 +346,15 @@ message Resource { } // The resource's name, to distinguish it from others of the same type of resource. + // Only one of *name* or *resource_name* may be set. string name = 3; + // Alternative to the *name* field, to be used when the server supports + // multiple variants of the named resource that are differentiated by + // dynamic parameter constraints. + // Only one of *name* or *resource_name* may be set. + ResourceName resource_name = 8; + // The aliases are a list of other names that this resource can go by. repeated string aliases = 4; diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v2/lrs.proto b/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v2/lrs.proto index 7ab87c2dfb0..c39d74aacf6 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v2/lrs.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v2/lrs.proto @@ -12,6 +12,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.service.load_stats.v2"; option java_outer_classname = "LrsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v2;load_statsv2"; option java_generic_services = true; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v3/lrs.proto b/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v3/lrs.proto index 0b565ebe723..6f7545376da 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v3/lrs.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/load_stats/v3/lrs.proto @@ -13,6 +13,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.service.load_stats.v3"; option java_outer_classname = "LrsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v3;load_statsv3"; option java_generic_services = true; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto index 1d940d6a2df..89d92efd2d0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto @@ -17,6 +17,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.service.status.v3"; option java_outer_classname = "CsdsProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/status/v3;statusv3"; option java_generic_services = true; option (udpa.annotations.file_status).package_version_status = ACTIVE; diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/http.proto b/xds/third_party/envoy/src/main/proto/envoy/type/http.proto index c1c787411fa..51768f17368 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/http.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/http.proto @@ -7,6 +7,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.type"; option java_outer_classname = "HttpProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: HTTP] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/http/v3/path_transformation.proto b/xds/third_party/envoy/src/main/proto/envoy/type/http/v3/path_transformation.proto index 0b3d72009f5..50350c48f9d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/http/v3/path_transformation.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/http/v3/path_transformation.proto @@ -8,6 +8,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.http.v3"; option java_outer_classname = "PathTransformationProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/http/v3;httpv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Path Transformations API] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/metadata.proto index ed58d04adb0..20da230b4fd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/metadata.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "MetadataProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Metadata matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/number.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/number.proto index e488f16a4a0..4c5b4db38d0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/number.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/number.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "NumberProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Number matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/path.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/path.proto index 860a1c69f18..1a97bbc154a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/path.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/path.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "PathProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Path matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/regex.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/regex.proto index 6c499235bbe..6daa16e478f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/regex.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/regex.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "RegexProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Regex matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/string.proto index 499eaf21775..b4571ce727a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/string.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "StringProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: String matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto index 68710dc7185..d3316e88a88 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "MetadataProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata matcher] @@ -101,4 +102,7 @@ message MetadataMatcher { // The MetadataMatcher is matched if the value retrieved by path is matched to this value. ValueMatcher value = 3 [(validate.rules).message = {required: true}]; + + // If true, the match result will be inverted. + bool invert = 4; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/node.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/node.proto index fe507312135..baa92fb6035 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/node.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/node.proto @@ -11,6 +11,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "NodeProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Node matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/number.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/number.proto index 2379efdcbd2..99681c989ca 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/number.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/number.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "NumberProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Number matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/path.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/path.proto index 0ce89871c9d..d332a17d6b7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/path.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/path.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "PathProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Path matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/regex.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/regex.proto index 3e7bb477ecb..f18bd03e2ba 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/regex.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/regex.proto @@ -12,6 +12,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "RegexProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Regex matcher] @@ -44,6 +45,12 @@ message RegexMatcher { // // This field is deprecated; regexp validation should be performed on the management server // instead of being done by each individual client. + // + // .. note:: + // + // Although this field is deprecated, the program size will still be checked against the + // global ``re2.max_program_size.error_level`` runtime value. + // google.protobuf.UInt32Value max_program_size = 1 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } @@ -55,7 +62,8 @@ message RegexMatcher { GoogleRE2 google_re2 = 1 [(validate.rules).message = {required: true}]; } - // The regex match string. The string must be supported by the configured engine. + // The regex match string. The string must be supported by the configured engine. The regex is matched + // against the full string, not as a partial match. string regex = 2 [(validate.rules).string = {min_len: 1}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto index c64edde142f..efea6c0ab4b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "StringProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: String matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/struct.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/struct.proto index c753d07a5c0..ad72e2cc783 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/struct.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/struct.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "StructProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Struct matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto index 040332273ba..bd46acc0713 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto @@ -12,6 +12,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher.v3"; option java_outer_classname = "ValueProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Value matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/value.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/value.proto index aaecd14e8ec..89d341bbbaa 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/value.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/value.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.matcher"; option java_outer_classname = "ValueProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Value matcher] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v2/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v2/metadata.proto index 43a1a7ca927..75f025009da 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v2/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v2/metadata.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.metadata.v2"; option java_outer_classname = "MetadataProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/metadata/v2;metadatav2"; option (udpa.annotations.file_migrate).move_to_package = "envoy.type.metadata.v3"; option (udpa.annotations.file_status).package_version_status = FROZEN; diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto index 5dd58b23c62..0d535374b81 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.metadata.v3"; option java_outer_classname = "MetadataProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/metadata/v3;metadatav3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/percent.proto b/xds/third_party/envoy/src/main/proto/envoy/type/percent.proto index fc41a26662f..6457e2a035f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/percent.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/percent.proto @@ -8,6 +8,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type"; option java_outer_classname = "PercentProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Percent] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/range.proto b/xds/third_party/envoy/src/main/proto/envoy/type/range.proto index 79aaa81975c..9e66e6f2258 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/range.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/range.proto @@ -7,6 +7,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.type"; option java_outer_classname = "RangeProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Range] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/semantic_version.proto b/xds/third_party/envoy/src/main/proto/envoy/type/semantic_version.proto index 80fe016bfa1..f6a508cc958 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/semantic_version.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/semantic_version.proto @@ -7,6 +7,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.type"; option java_outer_classname = "SemanticVersionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Semantic Version] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v2/custom_tag.proto b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v2/custom_tag.proto index 7506ae88612..c37b662e51d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v2/custom_tag.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v2/custom_tag.proto @@ -10,6 +10,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.tracing.v2"; option java_outer_classname = "CustomTagProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/tracing/v2;tracingv2"; option (udpa.annotations.file_status).package_version_status = FROZEN; // [#protodoc-title: Custom Tag] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto index ad99cafb22b..feb57e8eb66 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto @@ -11,6 +11,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.tracing.v3"; option java_outer_classname = "CustomTagProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/tracing/v3;tracingv3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Custom Tag] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/http.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http.proto index fec15d11f87..a1a5a04fc87 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/v3/http.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http.proto @@ -7,6 +7,7 @@ import "udpa/annotations/status.proto"; option java_package = "io.envoyproxy.envoy.type.v3"; option java_outer_classname = "HttpProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: HTTP] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/percent.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/percent.proto index 3a89a3f44fd..e041ecddc78 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/v3/percent.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/percent.proto @@ -9,6 +9,7 @@ import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.type.v3"; option java_outer_classname = "PercentProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Percent] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/range.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/range.proto index de1d55b09a2..3b1af814858 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/v3/range.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/range.proto @@ -8,6 +8,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.type.v3"; option java_outer_classname = "RangeProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Range] diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/semantic_version.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/semantic_version.proto index a4126336f03..e1567612ab7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/v3/semantic_version.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/semantic_version.proto @@ -8,6 +8,7 @@ import "udpa/annotations/versioning.proto"; option java_package = "io.envoyproxy.envoy.type.v3"; option java_outer_classname = "SemanticVersionProto"; option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Semantic Version] diff --git a/xds/third_party/istio/src/main/proto/security/proto/providers/google/meshca.proto b/xds/third_party/istio/src/main/proto/security/proto/providers/google/meshca.proto deleted file mode 100644 index c02b7f58287..00000000000 --- a/xds/third_party/istio/src/main/proto/security/proto/providers/google/meshca.proto +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2019 Istio Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package google.security.meshca.v1; - -import "google/protobuf/duration.proto"; - -option java_multiple_files = true; -option java_outer_classname = "MeshCaProto"; -option java_package = "com.google.security.meshca.v1"; - -// Certificate request message. -message MeshCertificateRequest { - // The request ID must be a valid UUID with the exception that zero UUID is - // not supported (00000000-0000-0000-0000-000000000000). - string request_id = 1; - // PEM-encoded certificate request. - string csr = 2; - // Optional: requested certificate validity period. - google.protobuf.Duration validity = 3; - // Reserved 4 -} - -// Certificate response message. -message MeshCertificateResponse { - // PEM-encoded certificate chain. - // Leaf cert is element '0'. Root cert is element 'n'. - repeated string cert_chain = 1; -} - -// Service for managing certificates issued by the CSM CA. -service MeshCertificateService { - // Using provided CSR, returns a signed certificate that represents a GCP - // service account identity. - rpc CreateCertificate(MeshCertificateRequest) - returns (MeshCertificateResponse) { - } -} diff --git a/xds/third_party/protoc-gen-validate/import.sh b/xds/third_party/protoc-gen-validate/import.sh index 62b6158b600..4e30b0e1180 100755 --- a/xds/third_party/protoc-gen-validate/import.sh +++ b/xds/third_party/protoc-gen-validate/import.sh @@ -16,9 +16,9 @@ # Update GIT_ORIGIN_REV_ID then in this directory run ./import.sh set -e -BRANCH=master +BRANCH=main # import GIT_ORIGIN_REV_ID from one of the google internal CLs -GIT_ORIGIN_REV_ID=ab56c3dd1cf9b516b62c5087e1ec1471bd63631e +GIT_ORIGIN_REV_ID=dfcdc5ea103dda467963fb7079e4df28debcfd28 GIT_REPO="https://github.com/envoyproxy/protoc-gen-validate.git" GIT_BASE_DIR=protoc-gen-validate SOURCE_PROTO_BASE_DIR=protoc-gen-validate diff --git a/xds/third_party/protoc-gen-validate/src/main/proto/validate/validate.proto b/xds/third_party/protoc-gen-validate/src/main/proto/validate/validate.proto index 7767f0aab92..705d382aac4 100644 --- a/xds/third_party/protoc-gen-validate/src/main/proto/validate/validate.proto +++ b/xds/third_party/protoc-gen-validate/src/main/proto/validate/validate.proto @@ -13,6 +13,8 @@ extend google.protobuf.MessageOptions { // Disabled nullifies any validation rules for this message, including any // message fields associated with it that do support validation. optional bool disabled = 1071; + // Ignore skips generation of validation methods for this message. + optional bool ignored = 1072; } // Validation rules applied at the oneof level @@ -93,6 +95,10 @@ message FloatRules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated float not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // DoubleRules describes the constraints applied to `double` values @@ -125,6 +131,10 @@ message DoubleRules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated double not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // Int32Rules describes the constraints applied to `int32` values @@ -157,6 +167,10 @@ message Int32Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated int32 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // Int64Rules describes the constraints applied to `int64` values @@ -189,6 +203,10 @@ message Int64Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated int64 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // UInt32Rules describes the constraints applied to `uint32` values @@ -221,6 +239,10 @@ message UInt32Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated uint32 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // UInt64Rules describes the constraints applied to `uint64` values @@ -253,6 +275,10 @@ message UInt64Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated uint64 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // SInt32Rules describes the constraints applied to `sint32` values @@ -285,6 +311,10 @@ message SInt32Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated sint32 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // SInt64Rules describes the constraints applied to `sint64` values @@ -317,6 +347,10 @@ message SInt64Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated sint64 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // Fixed32Rules describes the constraints applied to `fixed32` values @@ -349,6 +383,10 @@ message Fixed32Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated fixed32 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // Fixed64Rules describes the constraints applied to `fixed64` values @@ -381,6 +419,10 @@ message Fixed64Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated fixed64 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // SFixed32Rules describes the constraints applied to `sfixed32` values @@ -413,6 +455,10 @@ message SFixed32Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated sfixed32 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // SFixed64Rules describes the constraints applied to `sfixed64` values @@ -445,6 +491,10 @@ message SFixed64Rules { // NotIn specifies that this field cannot be equal to one of the specified // values repeated sfixed64 not_in = 7; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 8; } // BoolRules describes the constraints applied to `bool` values @@ -474,7 +524,6 @@ message StringRules { optional uint64 max_len = 3; // LenBytes specifies that this field must be the specified number of bytes - // at a minimum optional uint64 len_bytes = 20; // MinBytes specifies that this field must be the specified number of bytes @@ -564,6 +613,10 @@ message StringRules { // Setting to false will enable a looser validations that only disallows // \r\n\0 characters, which can be used to bypass header matching rules. optional bool strict = 25 [default = true]; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 26; } // WellKnownRegex contain some well-known patterns. @@ -633,6 +686,10 @@ message BytesRules { // format bool ipv6 = 12; } + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 14; } // EnumRules describe the constraints applied to enum values @@ -683,6 +740,10 @@ message RepeatedRules { // Repeated message fields will still execute validation against each item // unless skip is specified here. optional FieldRules items = 4; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 5; } // MapRules describe the constraints applied to `map` values @@ -706,6 +767,10 @@ message MapRules { // in the field. Message values will still have their validations evaluated // unless skip is specified here. optional FieldRules values = 5; + + // IgnoreEmpty specifies that the validation rules of this field should be + // evaluated only if the field is not empty + optional bool ignore_empty = 6; } // AnyRules describe constraints applied exclusively to the diff --git a/xds/third_party/xds/import.sh b/xds/third_party/xds/import.sh index 36889a52bba..d7054e3b47c 100755 --- a/xds/third_party/xds/import.sh +++ b/xds/third_party/xds/import.sh @@ -18,7 +18,7 @@ set -e BRANCH=main # import VERSION from one of the google internal CLs -VERSION=cb28da3451f158a947dfc45090fe92b07b243bc1 +VERSION=d92e9ce0af512a73a3a126b32fa4920bee12e180 GIT_REPO="https://github.com/cncf/xds.git" GIT_BASE_DIR=xds SOURCE_PROTO_BASE_DIR=xds @@ -26,26 +26,30 @@ TARGET_PROTO_BASE_DIR=src/main/proto # Sorted alphabetically. FILES=( udpa/annotations/migrate.proto -xds/annotations/v3/migrate.proto udpa/annotations/security.proto -xds/annotations/v3/security.proto udpa/annotations/security.proto -xds/annotations/v3/security.proto udpa/annotations/sensitive.proto -xds/annotations/v3/sensitive.proto udpa/annotations/status.proto -xds/annotations/v3/status.proto udpa/annotations/versioning.proto -xds/annotations/v3/versioning.proto -xds/data/orca/v3/orca_load_report.proto -xds/service/orca/v3/orca.proto udpa/type/v1/typed_struct.proto -xds/type/v3/typed_struct.proto +xds/annotations/v3/migrate.proto +xds/annotations/v3/security.proto +xds/annotations/v3/security.proto +xds/annotations/v3/sensitive.proto +xds/annotations/v3/status.proto +xds/annotations/v3/versioning.proto xds/core/v3/authority.proto xds/core/v3/collection_entry.proto xds/core/v3/context_params.proto +xds/core/v3/extension.proto xds/core/v3/resource_locator.proto xds/core/v3/resource_name.proto +xds/data/orca/v3/orca_load_report.proto +xds/service/orca/v3/orca.proto +xds/type/matcher/v3/matcher.proto +xds/type/matcher/v3/regex.proto +xds/type/matcher/v3/string.proto +xds/type/v3/typed_struct.proto ) pushd `git rev-parse --show-toplevel`/xds/third_party/xds diff --git a/xds/third_party/xds/src/main/proto/xds/core/v3/extension.proto b/xds/third_party/xds/src/main/proto/xds/core/v3/extension.proto new file mode 100644 index 00000000000..dd489eb9912 --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/core/v3/extension.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package xds.core.v3; + +option java_outer_classname = "ExtensionProto"; +option java_multiple_files = true; +option java_package = "com.github.xds.core.v3"; +option go_package = "github.com/cncf/xds/go/xds/core/v3"; + +import "validate/validate.proto"; +import "google/protobuf/any.proto"; + +// Message type for extension configuration. +message TypedExtensionConfig { + // The name of an extension. This is not used to select the extension, instead + // it serves the role of an opaque identifier. + string name = 1 [(validate.rules).string = {min_len: 1}]; + + // The typed config for the extension. The type URL will be used to identify + // the extension. In the case that the type URL is *xds.type.v3.TypedStruct* + // (or, for historical reasons, *udpa.type.v1.TypedStruct*), the inner type + // URL of *TypedStruct* will be utilized. See the + // :ref:`extension configuration overview + // ` for further details. + google.protobuf.Any typed_config = 2 [(validate.rules).any = {required: true}]; +} diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto new file mode 100644 index 00000000000..4966b456bee --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto @@ -0,0 +1,139 @@ +syntax = "proto3"; + +package xds.type.matcher.v3; + +import "xds/annotations/v3/status.proto"; +import "xds/core/v3/extension.proto"; +import "xds/type/matcher/v3/string.proto"; + +import "validate/validate.proto"; + +option java_package = "com.github.xds.type.matcher.v3"; +option java_outer_classname = "MatcherProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; + +// [#protodoc-title: Unified Matcher API] + +// A matcher, which may traverse a matching tree in order to result in a match action. +// During matching, the tree will be traversed until a match is found, or if no match +// is found the action specified by the most specific on_no_match will be evaluated. +// As an on_no_match might result in another matching tree being evaluated, this process +// might repeat several times until the final OnMatch (or no match) is decided. +message Matcher { + option (xds.annotations.v3.message_status).work_in_progress = true; + + // What to do if a match is successful. + message OnMatch { + oneof on_match { + option (validate.required) = true; + + // Nested matcher to evaluate. + // If the nested matcher does not match and does not specify + // on_no_match, then this matcher is considered not to have + // matched, even if a predicate at this level or above returned + // true. + Matcher matcher = 1; + + // Protocol-specific action to take. + core.v3.TypedExtensionConfig action = 2; + } + } + + // A linear list of field matchers. + // The field matchers are evaluated in order, and the first match + // wins. + message MatcherList { + // Predicate to determine if a match is successful. + message Predicate { + // Predicate for a single input field. + message SinglePredicate { + // Protocol-specific specification of input field to match on. + // [#extension-category: envoy.matching.common_inputs] + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + oneof matcher { + option (validate.required) = true; + + // Built-in string matcher. + type.matcher.v3.StringMatcher value_match = 2; + + // Extension for custom matching logic. + // [#extension-category: envoy.matching.input_matchers] + core.v3.TypedExtensionConfig custom_match = 3; + } + } + + // A list of two or more matchers. Used to allow using a list within a oneof. + message PredicateList { + repeated Predicate predicate = 1 [(validate.rules).repeated = {min_items: 2}]; + } + + oneof match_type { + option (validate.required) = true; + + // A single predicate to evaluate. + SinglePredicate single_predicate = 1; + + // A list of predicates to be OR-ed together. + PredicateList or_matcher = 2; + + // A list of predicates to be AND-ed together. + PredicateList and_matcher = 3; + + // The invert of a predicate + Predicate not_matcher = 4; + } + } + + // An individual matcher. + message FieldMatcher { + // Determines if the match succeeds. + Predicate predicate = 1 [(validate.rules).message = {required: true}]; + + // What to do if the match succeeds. + OnMatch on_match = 2 [(validate.rules).message = {required: true}]; + } + + // A list of matchers. First match wins. + repeated FieldMatcher matchers = 1 [(validate.rules).repeated = {min_items: 1}]; + } + + message MatcherTree { + // A map of configured matchers. Used to allow using a map within a oneof. + message MatchMap { + map map = 1 [(validate.rules).map = {min_pairs: 1}]; + } + + // Protocol-specific specification of input field to match on. + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + // Exact or prefix match maps in which to look up the input value. + // If the lookup succeeds, the match is considered successful, and + // the corresponding OnMatch is used. + oneof tree_type { + option (validate.required) = true; + + MatchMap exact_match_map = 2; + + // Longest matching prefix wins. + MatchMap prefix_match_map = 3; + + // Extension for custom matching logic. + core.v3.TypedExtensionConfig custom_match = 4; + } + } + + oneof matcher_type { + // A linear list of matchers to evaluate. + MatcherList matcher_list = 1; + + // A match tree to evaluate. + MatcherTree matcher_tree = 2; + } + + // Optional OnMatch to use if no matcher above matched (e.g., if there are no matchers specified + // above, or if none of the matches specified above succeeded). + // If no matcher above matched and this field is not populated, the match will be considered unsuccessful. + OnMatch on_no_match = 3; +} diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/regex.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/regex.proto new file mode 100644 index 00000000000..3ff4ca95c2c --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/regex.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package xds.type.matcher.v3; + +import "validate/validate.proto"; + +option java_package = "com.github.xds.type.matcher.v3"; +option java_outer_classname = "RegexProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; + +// [#protodoc-title: Regex matcher] + +// A regex matcher designed for safety when used with untrusted input. +message RegexMatcher { + // Google's `RE2 `_ regex engine. The regex + // string must adhere to the documented `syntax + // `_. The engine is designed to + // complete execution in linear time as well as limit the amount of memory + // used. + // + // Envoy supports program size checking via runtime. The runtime keys + // `re2.max_program_size.error_level` and `re2.max_program_size.warn_level` + // can be set to integers as the maximum program size or complexity that a + // compiled regex can have before an exception is thrown or a warning is + // logged, respectively. `re2.max_program_size.error_level` defaults to 100, + // and `re2.max_program_size.warn_level` has no default if unset (will not + // check/log a warning). + // + // Envoy emits two stats for tracking the program size of regexes: the + // histogram `re2.program_size`, which records the program size, and the + // counter `re2.exceeded_warn_level`, which is incremented each time the + // program size exceeds the warn level threshold. + message GoogleRE2 {} + + oneof engine_type { + option (validate.required) = true; + + // Google's RE2 regex engine. + GoogleRE2 google_re2 = 1 [ (validate.rules).message = {required : true} ]; + } + + // The regex match string. The string must be supported by the configured + // engine. + string regex = 2 [ (validate.rules).string = {min_len : 1} ]; +} diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto new file mode 100644 index 00000000000..fdc598e174a --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto @@ -0,0 +1,66 @@ +syntax = "proto3"; + +package xds.type.matcher.v3; + +import "xds/type/matcher/v3/regex.proto"; + +import "validate/validate.proto"; + +option java_package = "com.github.xds.type.matcher.v3"; +option java_outer_classname = "StringProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; + +// [#protodoc-title: String matcher] + +// Specifies the way to match a string. +// [#next-free-field: 8] +message StringMatcher { + oneof match_pattern { + option (validate.required) = true; + + // The input string must match exactly the string specified here. + // + // Examples: + // + // * *abc* only matches the value *abc*. + string exact = 1; + + // The input string must have the prefix specified here. + // Note: empty prefix is not allowed, please use regex instead. + // + // Examples: + // + // * *abc* matches the value *abc.xyz* + string prefix = 2 [(validate.rules).string = {min_len: 1}]; + + // The input string must have the suffix specified here. + // Note: empty prefix is not allowed, please use regex instead. + // + // Examples: + // + // * *abc* matches the value *xyz.abc* + string suffix = 3 [(validate.rules).string = {min_len: 1}]; + + // The input string must match the regular expression specified here. + RegexMatcher safe_regex = 5 [(validate.rules).message = {required: true}]; + + // The input string must have the substring specified here. + // Note: empty contains match is not allowed, please use regex instead. + // + // Examples: + // + // * *abc* matches the value *xyz.abc.def* + string contains = 7 [(validate.rules).string = {min_len: 1}]; + } + + // If true, indicates the exact/prefix/suffix matching should be case insensitive. This has no + // effect for the safe_regex match. + // For example, the matcher *data* will match both input string *Data* and *data* if set to true. + bool ignore_case = 6; +} + +// Specifies a list of ways to match a string. +message ListStringMatcher { + repeated StringMatcher patterns = 1 [(validate.rules).repeated = {min_items: 1}]; +}

Homepage: