From 188158f59ad5d1c6836e934b6d7248a55cdafbcf Mon Sep 17 00:00:00 2001 From: Blake Li Date: Fri, 15 May 2026 21:49:07 +0000 Subject: [PATCH] feat(gax): implement dynamic channel refreshing on 401 retries --- .../com/google/api/gax/grpc/ChannelPool.java | 8 +++ .../google/api/gax/grpc/GrpcCallContext.java | 58 +++++++++++++------ .../api/gax/grpc/GrpcTransportChannel.java | 8 +++ .../api/gax/retrying/BasicRetryingFuture.java | 5 ++ .../api/gax/retrying/RetryingFuture.java | 8 +++ .../retrying/ScheduledRetryingExecutor.java | 27 +++++++++ .../google/api/gax/rpc/ApiCallContext.java | 8 +++ .../api/gax/rpc/ApiResultRetryAlgorithm.java | 8 +++ .../google/api/gax/rpc/TransportChannel.java | 8 +++ 9 files changed, 121 insertions(+), 17 deletions(-) diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d611c96ff4c8..d35dbc8d12ca 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -82,6 +82,9 @@ class ChannelPool extends ManagedChannel { private ScheduledFuture resizeFuture = null; private final Object entryWriteLock = new Object(); + // nanoTime() has an arbitrary (possibly negative) origin, so 0 is not a safe initial value. + // Start one throttle window in the past so the first refresh is never skipped. + private long lastRefreshTimeNanos = System.nanoTime() - TimeUnit.SECONDS.toNanos(5); @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -441,6 +444,14 @@ void refresh() { // - then thread2 will shut down channel that thread1 will put back into circulation (after it // replaces the list) synchronized (entryWriteLock) { + // Skip this refresh if the previous one ran less than 5 seconds ago. + long now = System.nanoTime(); + if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) { + LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh"); + return; + } + lastRefreshTimeNanos = now; + LOG.fine("Refreshing all 
channels"); ArrayList newEntries = new ArrayList<>(entries.get()); diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 7ff7c54de6f0..fb5e2edb0d07 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -97,6 +97,7 @@ public final class GrpcCallContext implements ApiCallContext { private final ApiCallContextOptions options; private final EndpointContext endpointContext; private final boolean isDirectPath; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -113,7 +114,8 @@ public static GrpcCallContext createDefault() { null, null, null, - false); + false, + null); } /** Returns an instance with the given channel and {@link CallOptions}. */ @@ -131,7 +133,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { null, null, null, - false); + false, + null); } private GrpcCallContext( @@ -147,7 +150,8 @@ private GrpcCallContext( @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes, @Nullable EndpointContext endpointContext, - boolean isDirectPath) { + boolean isDirectPath, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.credentials = credentials; Preconditions.checkNotNull(callOptions); @@ -167,6 +171,7 @@ private GrpcCallContext( this.endpointContext = endpointContext == null ? 
EndpointContext.getDefaultInstance() : endpointContext; this.isDirectPath = isDirectPath; + this.transportChannel = transportChannel; } /** @@ -208,7 +213,13 @@ public GrpcCallContext withCredentials(Credentials newCredentials) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } @Override @@ -232,7 +243,8 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { retrySettings, retryableCodes, endpointContext, - transportChannel.isDirectPath()); + transportChannel.isDirectPath(), + inputChannel); } @Override @@ -251,7 +263,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -286,7 +299,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. 
*/ @@ -335,7 +349,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** @@ -370,7 +385,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") @@ -388,7 +404,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -410,7 +427,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -433,7 +451,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -456,7 +475,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -558,7 +578,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newRetrySettings, newRetryableCodes, endpointContext, - newIsDirectPath); + newIsDirectPath, + transportChannel); } /** The {@link Channel} set on this context. */ @@ -641,7 +662,8 @@ public GrpcCallContext withChannel(Channel newChannel) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** Returns a new instance with the call options set to the given call options. 
*/ @@ -659,7 +681,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -704,7 +727,8 @@ public GrpcCallContext withOption(Key key, T value) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** {@inheritDoc} */ diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 2fa0908f17bc..80d471701d5a 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -66,6 +66,14 @@ public Channel getChannel() { return getManagedChannel(); } + @Override + public void refresh() { + Channel channel = getChannel(); + if (channel instanceof ChannelPool) { + ((ChannelPool) channel).refresh(); + } + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/BasicRetryingFuture.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/BasicRetryingFuture.java index ccf1bfe11c17..8a466bf2f041 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/BasicRetryingFuture.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/BasicRetryingFuture.java @@ -116,6 +116,11 @@ public TimedAttemptSettings getAttemptSettings() { } } + @Override + public RetryingContext getRetryingContext() { + return retryingContext; + } + @Override public ApiFuture peekAttemptResult() { synchronized (lock) { diff --git 
a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/RetryingFuture.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/RetryingFuture.java index 86b16eac6ee0..c677abcdf943 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/RetryingFuture.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/RetryingFuture.java @@ -128,4 +128,12 @@ public interface RetryingFuture extends ApiFuture { * */ ApiFuture getAttemptResult(); + + /** + * Returns the retrying context associated with this future, or {@code null} if none is set. + */ + @com.google.api.core.BetaApi("The surface for passing per operation state is not yet stable") + default RetryingContext getRetryingContext() { + return null; + } } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/ScheduledRetryingExecutor.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/ScheduledRetryingExecutor.java index c796ebd0900e..8ea99c01cb9a 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/ScheduledRetryingExecutor.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/retrying/ScheduledRetryingExecutor.java @@ -111,6 +111,10 @@ public RetryingFuture createFuture( */ @Override public ApiFuture submit(RetryingFuture retryingFuture) { + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + checkForFailedChannelRefresh(retryingFuture); + } + try { ListenableFuture attemptFuture = scheduler.schedule( @@ -122,4 +126,36 @@ public ApiFuture submit(RetryingFuture retryingFuture) { return ApiFutures.immediateFailedFuture(e); } } + + /** + * Refreshes the transport channel when the last attempt failed with UnauthenticatedException. + * Does nothing when no attempt has completed yet, the retrying context is not an + * ApiCallContext, or the context has no transport channel. + */ + private void checkForFailedChannelRefresh(RetryingFuture retryingFuture) { + ApiFuture lastAttemptResult = retryingFuture.peekAttemptResult(); + if (lastAttemptResult != null && lastAttemptResult.isDone()) { + try { + lastAttemptResult.get(); + } catch 
(java.util.concurrent.ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof com.google.api.gax.rpc.UnauthenticatedException) { + RetryingContext context = retryingFuture.getRetryingContext(); + if (context instanceof com.google.api.gax.rpc.ApiCallContext) { + com.google.api.gax.rpc.TransportChannel transportChannel = + ((com.google.api.gax.rpc.ApiCallContext) context).getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } + } catch (InterruptedException e) { + // Future.get() may throw InterruptedException even for a completed future; restore the + // interrupt flag so the caller still observes the interrupt. + Thread.currentThread().interrupt(); + } catch (java.util.concurrent.CancellationException ignored) { + // A cancelled attempt carries no authentication failure, so there is nothing to refresh. + } + } + } } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index 09af475e4833..fc7fb5e989fe 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -63,6 +63,14 @@ public interface ApiCallContext extends RetryingContext { /** Returns a new ApiCallContext with the given channel set. */ ApiCallContext withTransportChannel(TransportChannel channel); + /** + * Returns the {@link TransportChannel} associated with this call context, or {@code null} if none + * is set. + */ + default TransportChannel getTransportChannel() { + return null; + } + /** Returns a new ApiCallContext with the given Endpoint Context. 
*/ ApiCallContext withEndpointContext(EndpointContext endpointContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java index 688fc32cd14b..7c8fad8497e9 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java @@ -38,6 +38,10 @@ class ApiResultRetryAlgorithm extends BasicResultRetryAlgorithmBy default, this is a no-op for transports that do not require stateful connection lifecycle + * management. + */ + default void refresh() {} }