Skip to content

Commit 1775ab3

Browse files
committed
core: call newStream() and applyRequestMetadata() under context.
`ClientTransport.newStream()` and `CallCredentials.applyRequestMetadata()` is now called under the context of the call. This can be used to pass any call-specific information to `CallCredentials`.
1 parent aa33c59 commit 1775ab3

File tree

6 files changed

+145
-8
lines changed

6 files changed

+145
-8
lines changed

core/src/main/java/io/grpc/CallCredentials.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ public interface CallCredentials {
6060
* Pass the credential data to the given {@link MetadataApplier}, which will propagate it to
6161
* the request metadata.
6262
*
63-
* <p>It is called for each individual RPC, before the stream is about to be created on a
64-
* transport. Implementations should not block in this method. If metadata is not immediately
65-
* available, e.g., needs to be fetched from network, the implementation may give the {@code
66-
* applier} to an asynchronous task which will eventually call the {@code applier}. The RPC
67-
* proceeds only after the {@code applier} is called.
63+
* <p>It is called for each individual RPC, within the {@link Context} of the call, before the
64+
* stream is about to be created on a transport. Implementations should not block in this
65+
* method. If metadata is not immediately available, e.g., needs to be fetched from network, the
66+
* implementation may give the {@code applier} to an asynchronous task which will eventually call
67+
* the {@code applier}. The RPC proceeds only after the {@code applier} is called.
6868
*
6969
* @param method The method descriptor of this RPC
7070
* @param attrs Additional attributes from the transport, along with the keys defined in this

core/src/main/java/io/grpc/internal/ClientCallImpl.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,12 @@ public void runInContext() {
211211
updateTimeoutHeaders(effectiveDeadline, callOptions.getDeadline(),
212212
context.getDeadline(), headers);
213213
ClientTransport transport = clientTransportProvider.get(callOptions);
214-
stream = transport.newStream(method, headers, callOptions);
214+
Context origContext = context.attach();
215+
try {
216+
stream = transport.newStream(method, headers, callOptions);
217+
} finally {
218+
context.detach(origContext);
219+
}
215220
} else {
216221
stream = new FailingClientStream(DEADLINE_EXCEEDED);
217222
}

core/src/main/java/io/grpc/internal/ClientTransport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ public interface ClientTransport {
5555
* the error information. Any sent messages for this stream will be buffered until creation has
5656
* completed (either successfully or unsuccessfully).
5757
*
58+
* <p>This method is called under the {@link io.grpc.Context} of the {@link io.grpc.ClientCall}.
59+
*
5860
* @param method the descriptor of the remote method to be called for this stream.
5961
* @param headers to send at the beginning of the call
6062
* @param callOptions runtime options of the call

core/src/main/java/io/grpc/internal/DelayedClientTransport.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.common.base.Suppliers;
3838

3939
import io.grpc.CallOptions;
40+
import io.grpc.Context;
4041
import io.grpc.Metadata;
4142
import io.grpc.MethodDescriptor;
4243
import io.grpc.Status;
@@ -370,16 +371,25 @@ private class PendingStream extends DelayedStream {
370371
private final MethodDescriptor<?, ?> method;
371372
private final Metadata headers;
372373
private final CallOptions callOptions;
374+
private final Context context;
373375

374376
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers,
375377
CallOptions callOptions) {
376378
this.method = method;
377379
this.headers = headers;
378380
this.callOptions = callOptions;
381+
this.context = Context.current();
379382
}
380383

381384
private void createRealStream(ClientTransport transport) {
382-
setStream(transport.newStream(method, headers, callOptions));
385+
ClientStream realStream;
386+
Context origContext = context.attach();
387+
try {
388+
realStream = transport.newStream(method, headers, callOptions);
389+
} finally {
390+
context.detach(origContext);
391+
}
392+
setStream(realStream);
383393
}
384394

385395
@Override

core/src/main/java/io/grpc/internal/MetadataApplierImpl.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import io.grpc.CallCredentials.MetadataApplier;
3939
import io.grpc.CallOptions;
40+
import io.grpc.Context;
4041
import io.grpc.Metadata;
4142
import io.grpc.MethodDescriptor;
4243
import io.grpc.Status;
@@ -49,6 +50,7 @@ final class MetadataApplierImpl implements MetadataApplier {
4950
private final MethodDescriptor<?, ?> method;
5051
private final Metadata origHeaders;
5152
private final CallOptions callOptions;
53+
private final Context ctx;
5254

5355
private final Object lock = new Object();
5456

@@ -69,14 +71,22 @@ final class MetadataApplierImpl implements MetadataApplier {
6971
this.method = method;
7072
this.origHeaders = origHeaders;
7173
this.callOptions = callOptions;
74+
this.ctx = Context.current();
7275
}
7376

7477
@Override
7578
public void apply(Metadata headers) {
7679
checkState(!finalized, "apply() or fail() already called");
7780
checkNotNull(headers, "headers");
7881
origHeaders.merge(headers);
79-
finalizeWith(transport.newStream(method, origHeaders, callOptions));
82+
ClientStream realStream;
83+
Context origCtx = ctx.attach();
84+
try {
85+
realStream = transport.newStream(method, origHeaders, callOptions);
86+
} finally {
87+
ctx.detach(origCtx);
88+
}
89+
finalizeWith(realStream);
8090
}
8191

8292
@Override

core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import static org.junit.Assert.assertEquals;
3535
import static org.junit.Assert.assertFalse;
3636
import static org.junit.Assert.assertNotNull;
37+
import static org.junit.Assert.assertNull;
3738
import static org.junit.Assert.assertSame;
3839
import static org.junit.Assert.assertTrue;
3940
import static org.mockito.Matchers.any;
@@ -53,12 +54,15 @@
5354
import static org.mockito.Mockito.when;
5455

5556
import io.grpc.Attributes;
57+
import io.grpc.CallCredentials.MetadataApplier;
58+
import io.grpc.CallCredentials;
5659
import io.grpc.CallOptions;
5760
import io.grpc.Channel;
5861
import io.grpc.ClientCall;
5962
import io.grpc.ClientInterceptor;
6063
import io.grpc.Compressor;
6164
import io.grpc.CompressorRegistry;
65+
import io.grpc.Context;
6266
import io.grpc.DecompressorRegistry;
6367
import io.grpc.DummyLoadBalancerFactory;
6468
import io.grpc.IntegerMarshaller;
@@ -67,6 +71,7 @@
6771
import io.grpc.MethodDescriptor;
6872
import io.grpc.NameResolver;
6973
import io.grpc.ResolvedServerInfo;
74+
import io.grpc.SecurityLevel;
7075
import io.grpc.Status;
7176
import io.grpc.StringMarshaller;
7277
import io.grpc.TransportManager;
@@ -91,8 +96,10 @@
9196
import java.util.ArrayList;
9297
import java.util.Arrays;
9398
import java.util.Collections;
99+
import java.util.LinkedList;
94100
import java.util.List;
95101
import java.util.concurrent.CyclicBarrier;
102+
import java.util.concurrent.Executor;
96103
import java.util.concurrent.ScheduledExecutorService;
97104
import java.util.concurrent.TimeUnit;
98105
import java.util.concurrent.atomic.AtomicLong;
@@ -136,6 +143,8 @@ public class ManagedChannelImplTest {
136143
private ClientCall.Listener<Integer> mockCallListener3;
137144
@Mock
138145
private SharedResourceHolder.Resource<ScheduledExecutorService> timerService;
146+
@Mock
147+
private CallCredentials creds;
139148

140149
private ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor =
141150
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
@@ -813,6 +822,107 @@ public void uriPattern() {
813822
assertFalse(ManagedChannelImpl.URI_PATTERN.matcher(" a:/").matches()); // space not matched
814823
}
815824

825+
/**
826+
* Test that information such as the Call's context, MethodDescriptor, authority, executor are
827+
* propagated to newStream() and applyRequestMetadata().
828+
*/
829+
@Test
830+
public void informationPropagatedToNewStreamAndCallCredentials() {
831+
createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR);
832+
Metadata headers = new Metadata();
833+
CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(creds);
834+
final Context.Key<String> testKey = Context.key("testing");
835+
Context ctx = Context.current().withValue(testKey, "testValue");
836+
final LinkedList<Context> credsApplyContexts = new LinkedList<Context>();
837+
final LinkedList<Context> newStreamContexts = new LinkedList<Context>();
838+
doAnswer(new Answer<Void>() {
839+
@Override
840+
public Void answer(InvocationOnMock in) throws Throwable {
841+
credsApplyContexts.add(Context.current());
842+
return null;
843+
}
844+
}).when(creds).applyRequestMetadata(
845+
any(MethodDescriptor.class), any(Attributes.class), any(Executor.class),
846+
any(MetadataApplier.class));
847+
848+
final ConnectionClientTransport transport = mock(ConnectionClientTransport.class);
849+
when(transport.getAttrs()).thenReturn(Attributes.EMPTY);
850+
when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class),
851+
any(String.class))).thenReturn(transport);
852+
doAnswer(new Answer<ClientStream>() {
853+
@Override
854+
public ClientStream answer(InvocationOnMock in) throws Throwable {
855+
newStreamContexts.add(Context.current());
856+
return mock(ClientStream.class);
857+
}
858+
}).when(transport).newStream(
859+
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
860+
861+
// First call will be on delayed transport. Only newCall() is run within the expected context,
862+
// so that we can verify that the context is explicitly attached before calling newStream() and
863+
// applyRequestMetadata(), which happens after we detach the context from the thread.
864+
Context origCtx = ctx.attach();
865+
assertEquals("testValue", testKey.get());
866+
ClientCall<String, Integer> call = channel.newCall(method, callOptions);
867+
ctx.detach(origCtx);
868+
assertNull(testKey.get());
869+
call.start(mockCallListener, new Metadata());
870+
871+
ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor =
872+
ArgumentCaptor.forClass(ManagedClientTransport.Listener.class);
873+
verify(mockTransportFactory).newClientTransport(
874+
same(socketAddress), eq(authority), eq(userAgent));
875+
verify(transport).start(transportListenerCaptor.capture());
876+
verify(creds, never()).applyRequestMetadata(
877+
any(MethodDescriptor.class), any(Attributes.class), any(Executor.class),
878+
any(MetadataApplier.class));
879+
880+
// applyRequestMetadata() is called after the transport becomes ready.
881+
transportListenerCaptor.getValue().transportReady();
882+
executor.runDueTasks();
883+
ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(Attributes.class);
884+
ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(MetadataApplier.class);
885+
verify(creds).applyRequestMetadata(same(method), attrsCaptor.capture(),
886+
same(executor.scheduledExecutorService), applierCaptor.capture());
887+
assertEquals("testValue", testKey.get(credsApplyContexts.poll()));
888+
assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY));
889+
assertEquals(SecurityLevel.NONE,
890+
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
891+
verify(transport, never()).newStream(
892+
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
893+
894+
// newStream() is called after apply() is called
895+
applierCaptor.getValue().apply(new Metadata());
896+
verify(transport).newStream(same(method), any(Metadata.class), same(callOptions));
897+
assertEquals("testValue", testKey.get(newStreamContexts.poll()));
898+
// The context should not live beyond the scope of newStream() and applyRequestMetadata()
899+
assertNull(testKey.get());
900+
901+
902+
// Second call will not be on delayed transport
903+
origCtx = ctx.attach();
904+
call = channel.newCall(method, callOptions);
905+
ctx.detach(origCtx);
906+
call.start(mockCallListener, new Metadata());
907+
908+
verify(creds, times(2)).applyRequestMetadata(same(method), attrsCaptor.capture(),
909+
same(executor.scheduledExecutorService), applierCaptor.capture());
910+
assertEquals("testValue", testKey.get(credsApplyContexts.poll()));
911+
assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY));
912+
assertEquals(SecurityLevel.NONE,
913+
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
914+
// This is from the first call
915+
verify(transport).newStream(
916+
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
917+
918+
// Still, newStream() is called after apply() is called
919+
applierCaptor.getValue().apply(new Metadata());
920+
verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions));
921+
assertEquals("testValue", testKey.get(newStreamContexts.poll()));
922+
923+
assertNull(testKey.get());
924+
}
925+
816926
private static class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
817927
@Override
818928
public BackoffPolicy get() {

0 commit comments

Comments
 (0)