|
34 | 34 | import static org.junit.Assert.assertEquals; |
35 | 35 | import static org.junit.Assert.assertFalse; |
36 | 36 | import static org.junit.Assert.assertNotNull; |
| 37 | +import static org.junit.Assert.assertNull; |
37 | 38 | import static org.junit.Assert.assertSame; |
38 | 39 | import static org.junit.Assert.assertTrue; |
39 | 40 | import static org.mockito.Matchers.any; |
|
53 | 54 | import static org.mockito.Mockito.when; |
54 | 55 |
|
55 | 56 | import io.grpc.Attributes; |
| 57 | +import io.grpc.CallCredentials.MetadataApplier; |
| 58 | +import io.grpc.CallCredentials; |
56 | 59 | import io.grpc.CallOptions; |
57 | 60 | import io.grpc.Channel; |
58 | 61 | import io.grpc.ClientCall; |
59 | 62 | import io.grpc.ClientInterceptor; |
60 | 63 | import io.grpc.Compressor; |
61 | 64 | import io.grpc.CompressorRegistry; |
| 65 | +import io.grpc.Context; |
62 | 66 | import io.grpc.DecompressorRegistry; |
63 | 67 | import io.grpc.DummyLoadBalancerFactory; |
64 | 68 | import io.grpc.IntegerMarshaller; |
|
67 | 71 | import io.grpc.MethodDescriptor; |
68 | 72 | import io.grpc.NameResolver; |
69 | 73 | import io.grpc.ResolvedServerInfo; |
| 74 | +import io.grpc.SecurityLevel; |
70 | 75 | import io.grpc.Status; |
71 | 76 | import io.grpc.StringMarshaller; |
72 | 77 | import io.grpc.TransportManager; |
|
91 | 96 | import java.util.ArrayList; |
92 | 97 | import java.util.Arrays; |
93 | 98 | import java.util.Collections; |
| 99 | +import java.util.LinkedList; |
94 | 100 | import java.util.List; |
95 | 101 | import java.util.concurrent.CyclicBarrier; |
| 102 | +import java.util.concurrent.Executor; |
96 | 103 | import java.util.concurrent.ScheduledExecutorService; |
97 | 104 | import java.util.concurrent.TimeUnit; |
98 | 105 | import java.util.concurrent.atomic.AtomicLong; |
@@ -136,6 +143,8 @@ public class ManagedChannelImplTest { |
136 | 143 | private ClientCall.Listener<Integer> mockCallListener3; |
137 | 144 | @Mock |
138 | 145 | private SharedResourceHolder.Resource<ScheduledExecutorService> timerService; |
| 146 | + @Mock |
| 147 | + private CallCredentials creds; |
139 | 148 |
|
140 | 149 | private ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor = |
141 | 150 | ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); |
@@ -813,6 +822,107 @@ public void uriPattern() { |
813 | 822 | assertFalse(ManagedChannelImpl.URI_PATTERN.matcher(" a:/").matches()); // space not matched |
814 | 823 | } |
815 | 824 |
|
| 825 | + /** |
| 826 | + * Test that information such as the Call's context, MethodDescriptor, authority, executor are |
| 827 | + * propagated to newStream() and applyRequestMetadata(). |
| 828 | + */ |
| 829 | + @Test |
| 830 | + public void informationPropagatedToNewStreamAndCallCredentials() { |
| 831 | + createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR); |
| 832 | + Metadata headers = new Metadata(); |
| 833 | + CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(creds); |
| 834 | + final Context.Key<String> testKey = Context.key("testing"); |
| 835 | + Context ctx = Context.current().withValue(testKey, "testValue"); |
| 836 | + final LinkedList<Context> credsApplyContexts = new LinkedList<Context>(); |
| 837 | + final LinkedList<Context> newStreamContexts = new LinkedList<Context>(); |
| 838 | + doAnswer(new Answer<Void>() { |
| 839 | + @Override |
| 840 | + public Void answer(InvocationOnMock in) throws Throwable { |
| 841 | + credsApplyContexts.add(Context.current()); |
| 842 | + return null; |
| 843 | + } |
| 844 | + }).when(creds).applyRequestMetadata( |
| 845 | + any(MethodDescriptor.class), any(Attributes.class), any(Executor.class), |
| 846 | + any(MetadataApplier.class)); |
| 847 | + |
| 848 | + final ConnectionClientTransport transport = mock(ConnectionClientTransport.class); |
| 849 | + when(transport.getAttrs()).thenReturn(Attributes.EMPTY); |
| 850 | + when(mockTransportFactory.newClientTransport(any(SocketAddress.class), any(String.class), |
| 851 | + any(String.class))).thenReturn(transport); |
| 852 | + doAnswer(new Answer<ClientStream>() { |
| 853 | + @Override |
| 854 | + public ClientStream answer(InvocationOnMock in) throws Throwable { |
| 855 | + newStreamContexts.add(Context.current()); |
| 856 | + return mock(ClientStream.class); |
| 857 | + } |
| 858 | + }).when(transport).newStream( |
| 859 | + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); |
| 860 | + |
| 861 | + // First call will be on delayed transport. Only newCall() is run within the expected context, |
| 862 | + // so that we can verify that the context is explicitly attached before calling newStream() and |
| 863 | + // applyRequestMetadata(), which happens after we detach the context from the thread. |
| 864 | + Context origCtx = ctx.attach(); |
| 865 | + assertEquals("testValue", testKey.get()); |
| 866 | + ClientCall<String, Integer> call = channel.newCall(method, callOptions); |
| 867 | + ctx.detach(origCtx); |
| 868 | + assertNull(testKey.get()); |
| 869 | + call.start(mockCallListener, new Metadata()); |
| 870 | + |
| 871 | + ArgumentCaptor<ManagedClientTransport.Listener> transportListenerCaptor = |
| 872 | + ArgumentCaptor.forClass(ManagedClientTransport.Listener.class); |
| 873 | + verify(mockTransportFactory).newClientTransport( |
| 874 | + same(socketAddress), eq(authority), eq(userAgent)); |
| 875 | + verify(transport).start(transportListenerCaptor.capture()); |
| 876 | + verify(creds, never()).applyRequestMetadata( |
| 877 | + any(MethodDescriptor.class), any(Attributes.class), any(Executor.class), |
| 878 | + any(MetadataApplier.class)); |
| 879 | + |
| 880 | + // applyRequestMetadata() is called after the transport becomes ready. |
| 881 | + transportListenerCaptor.getValue().transportReady(); |
| 882 | + executor.runDueTasks(); |
| 883 | + ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(Attributes.class); |
| 884 | + ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(MetadataApplier.class); |
| 885 | + verify(creds).applyRequestMetadata(same(method), attrsCaptor.capture(), |
| 886 | + same(executor.scheduledExecutorService), applierCaptor.capture()); |
| 887 | + assertEquals("testValue", testKey.get(credsApplyContexts.poll())); |
| 888 | + assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY)); |
| 889 | + assertEquals(SecurityLevel.NONE, |
| 890 | + attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); |
| 891 | + verify(transport, never()).newStream( |
| 892 | + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); |
| 893 | + |
| 894 | + // newStream() is called after apply() is called |
| 895 | + applierCaptor.getValue().apply(new Metadata()); |
| 896 | + verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); |
| 897 | + assertEquals("testValue", testKey.get(newStreamContexts.poll())); |
| 898 | + // The context should not live beyond the scope of newStream() and applyRequestMetadata() |
| 899 | + assertNull(testKey.get()); |
| 900 | + |
| 901 | + |
| 902 | + // Second call will not be on delayed transport |
| 903 | + origCtx = ctx.attach(); |
| 904 | + call = channel.newCall(method, callOptions); |
| 905 | + ctx.detach(origCtx); |
| 906 | + call.start(mockCallListener, new Metadata()); |
| 907 | + |
| 908 | + verify(creds, times(2)).applyRequestMetadata(same(method), attrsCaptor.capture(), |
| 909 | + same(executor.scheduledExecutorService), applierCaptor.capture()); |
| 910 | + assertEquals("testValue", testKey.get(credsApplyContexts.poll())); |
| 911 | + assertEquals(authority, attrsCaptor.getValue().get(CallCredentials.ATTR_AUTHORITY)); |
| 912 | + assertEquals(SecurityLevel.NONE, |
| 913 | + attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); |
| 914 | + // This is from the first call |
| 915 | + verify(transport).newStream( |
| 916 | + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); |
| 917 | + |
| 918 | + // Still, newStream() is called after apply() is called |
| 919 | + applierCaptor.getValue().apply(new Metadata()); |
| 920 | + verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); |
| 921 | + assertEquals("testValue", testKey.get(newStreamContexts.poll())); |
| 922 | + |
| 923 | + assertNull(testKey.get()); |
| 924 | + } |
| 925 | + |
816 | 926 | private static class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { |
817 | 927 | @Override |
818 | 928 | public BackoffPolicy get() { |
|
0 commit comments