Skip to content

Commit 5324509

Browse files
authored
core: detect invalid states on server side (eg zero responses for unary) (grpc#3068)
The current check in ServerCallImpl is theoretically unsafe (grpc#3059). Move that check into the stub, and expand the unit tests to cover other interesting edge cases on the server side: client sends one, but zero requests received at onHalfClose client sends one, but > 1 requests received at onHalfClose server sends one, but zero responses sent at onComplete server sends one, but > 1 responses sent via onNext fixes grpc#2243 fixes grpc#3059
1 parent d42a4b2 commit 5324509

6 files changed

Lines changed: 297 additions & 105 deletions

File tree

core/src/main/java/io/grpc/internal/ServerCallImpl.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@
3434
import io.grpc.InternalDecompressorRegistry;
3535
import io.grpc.Metadata;
3636
import io.grpc.MethodDescriptor;
37-
import io.grpc.MethodDescriptor.MethodType;
3837
import io.grpc.ServerCall;
3938
import io.grpc.Status;
4039
import java.io.IOException;
4140
import java.io.InputStream;
4241
import java.util.List;
4342

4443
final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
44+
45+
@VisibleForTesting
46+
static String TOO_MANY_RESPONSES = "Too many responses";
47+
@VisibleForTesting
48+
static String MISSING_RESPONSE = "Completed without a response";
49+
4550
private final ServerStream stream;
4651
private final MethodDescriptor<ReqT, RespT> method;
4752
private final Context.CancellableContext context;
@@ -54,6 +59,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
5459
private boolean sendHeadersCalled;
5560
private boolean closeCalled;
5661
private Compressor compressor;
62+
private boolean messageSent;
5763

5864
ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
5965
Metadata inboundHeaders, Context.CancellableContext context,
@@ -115,6 +121,13 @@ public void sendHeaders(Metadata headers) {
115121
public void sendMessage(RespT message) {
116122
checkState(sendHeadersCalled, "sendHeaders has not been called");
117123
checkState(!closeCalled, "call is closed");
124+
125+
if (method.getType().serverSendsOneMessage() && messageSent) {
126+
internalClose(Status.INTERNAL.withDescription(TOO_MANY_RESPONSES));
127+
return;
128+
}
129+
130+
messageSent = true;
118131
try {
119132
InputStream resp = method.streamResponse(message);
120133
stream.writeMessage(resp);
@@ -151,6 +164,12 @@ public boolean isReady() {
151164
public void close(Status status, Metadata trailers) {
152165
checkState(!closeCalled, "call already closed");
153166
closeCalled = true;
167+
168+
if (status.isOk() && method.getType().serverSendsOneMessage() && !messageSent) {
169+
internalClose(Status.INTERNAL.withDescription(MISSING_RESPONSE));
170+
return;
171+
}
172+
154173
stream.close(status, trailers);
155174
}
156175

@@ -178,6 +197,15 @@ public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
178197
return method;
179198
}
180199

200+
/**
201+
* Close the {@link ServerStream} because an internal error occurred. Allow the application to
202+
* run until completion, but silently ignore interactions with the {@link ServerStream} from now
203+
* on.
204+
*/
205+
private void internalClose(Status internalError) {
206+
stream.close(internalError, new Metadata());
207+
}
208+
181209
/**
182210
* All of these callbacks are assumed to called on an application thread, and the caller is
183211
* responsible for handling thrown exceptions.
@@ -187,7 +215,6 @@ static final class ServerStreamListenerImpl<ReqT> implements ServerStreamListene
187215
private final ServerCallImpl<ReqT, ?> call;
188216
private final ServerCall.Listener<ReqT> listener;
189217
private final Context.CancellableContext context;
190-
private boolean messageReceived;
191218

192219
public ServerStreamListenerImpl(
193220
ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener,
@@ -216,15 +243,6 @@ public void messageRead(final InputStream message) {
216243
if (call.cancelled) {
217244
return;
218245
}
219-
// Special case for unary calls.
220-
if (messageReceived && call.method.getType() == MethodType.UNARY) {
221-
call.stream.close(Status.INTERNAL.withDescription(
222-
"More than one request messages for unary call or server streaming call"),
223-
new Metadata());
224-
return;
225-
}
226-
messageReceived = true;
227-
228246
listener.onMessage(call.method.parseRequest(message));
229247
} catch (Throwable e) {
230248
t = e;

core/src/main/java/io/grpc/internal/ServerStream.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ public interface ServerStream extends Stream {
4242
* {@link io.grpc.Status.Code#OK} implies normal termination of the
4343
* stream. Any other value implies abnormal termination.
4444
*
45+
* <p>Attempts to read from or write to the stream after closing
46+
* should be ignored by implementations, and should not throw
47+
* exceptions.
48+
*
4549
* @param status details of the closure
4650
* @param trailers an additional block of metadata to pass to the client on stream closure.
4751
*/

core/src/test/java/io/grpc/internal/ServerCallImplTest.java

Lines changed: 131 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import static org.junit.Assert.fail;
2424
import static org.mockito.Matchers.any;
2525
import static org.mockito.Matchers.isA;
26+
import static org.mockito.Matchers.same;
2627
import static org.mockito.Mockito.doThrow;
28+
import static org.mockito.Mockito.never;
29+
import static org.mockito.Mockito.times;
2730
import static org.mockito.Mockito.verify;
2831
import static org.mockito.Mockito.when;
2932

@@ -48,35 +51,41 @@
4851
import org.junit.runner.RunWith;
4952
import org.junit.runners.JUnit4;
5053
import org.mockito.ArgumentCaptor;
51-
import org.mockito.Captor;
5254
import org.mockito.Mock;
53-
import org.mockito.Mockito;
5455
import org.mockito.MockitoAnnotations;
5556

5657
@RunWith(JUnit4.class)
5758
public class ServerCallImplTest {
5859
@Rule public final ExpectedException thrown = ExpectedException.none();
5960
@Mock private ServerStream stream;
6061
@Mock private ServerCall.Listener<Long> callListener;
61-
@Captor private ArgumentCaptor<Status> statusCaptor;
6262

6363
private ServerCallImpl<Long, Long> call;
6464
private Context.CancellableContext context;
6565

66-
private final MethodDescriptor<Long, Long> method = MethodDescriptor.<Long, Long>newBuilder()
67-
.setType(MethodType.UNARY)
68-
.setFullMethodName("/service/method")
69-
.setRequestMarshaller(new LongMarshaller())
70-
.setResponseMarshaller(new LongMarshaller())
71-
.build();
66+
private static final MethodDescriptor<Long, Long> UNARY_METHOD =
67+
MethodDescriptor.<Long, Long>newBuilder()
68+
.setType(MethodType.UNARY)
69+
.setFullMethodName("/service/method")
70+
.setRequestMarshaller(new LongMarshaller())
71+
.setResponseMarshaller(new LongMarshaller())
72+
.build();
73+
74+
private static final MethodDescriptor<Long, Long> CLIENT_STREAMING_METHOD =
75+
MethodDescriptor.<Long, Long>newBuilder()
76+
.setType(MethodType.UNARY)
77+
.setFullMethodName("/service/method")
78+
.setRequestMarshaller(new LongMarshaller())
79+
.setResponseMarshaller(new LongMarshaller())
80+
.build();
7281

7382
private final Metadata requestHeaders = new Metadata();
7483

7584
@Before
7685
public void setUp() {
7786
MockitoAnnotations.initMocks(this);
7887
context = Context.ROOT.withCancellation();
79-
call = new ServerCallImpl<Long, Long>(stream, method, requestHeaders, context,
88+
call = new ServerCallImpl<Long, Long>(stream, UNARY_METHOD, requestHeaders, context,
8089
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance());
8190
}
8291

@@ -158,6 +167,114 @@ public void sendMessage_closesOnFailure() {
158167
verify(stream).close(isA(Status.class), isA(Metadata.class));
159168
}
160169

170+
@Test
171+
public void sendMessage_serverSendsOne_closeOnSecondCall_unary() {
172+
sendMessage_serverSendsOne_closeOnSecondCall(UNARY_METHOD);
173+
}
174+
175+
@Test
176+
public void sendMessage_serverSendsOne_closeOnSecondCall_clientStreaming() {
177+
sendMessage_serverSendsOne_closeOnSecondCall(CLIENT_STREAMING_METHOD);
178+
}
179+
180+
private void sendMessage_serverSendsOne_closeOnSecondCall(
181+
MethodDescriptor<Long, Long> method) {
182+
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
183+
stream,
184+
method,
185+
requestHeaders,
186+
context,
187+
DecompressorRegistry.getDefaultInstance(),
188+
CompressorRegistry.getDefaultInstance());
189+
serverCall.sendHeaders(new Metadata());
190+
serverCall.sendMessage(1L);
191+
verify(stream, times(1)).writeMessage(any(InputStream.class));
192+
verify(stream, never()).close(any(Status.class), any(Metadata.class));
193+
194+
// trying to send a second message causes gRPC to close the underlying stream
195+
serverCall.sendMessage(1L);
196+
verify(stream, times(1)).writeMessage(any(InputStream.class));
197+
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
198+
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
199+
verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture());
200+
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
201+
assertEquals(ServerCallImpl.TOO_MANY_RESPONSES, statusCaptor.getValue().getDescription());
202+
assertTrue(metadataCaptor.getValue().keys().isEmpty());
203+
}
204+
205+
@Test
206+
public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_unary() {
207+
sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(UNARY_METHOD);
208+
}
209+
210+
@Test
211+
public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_clientStreaming() {
212+
sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(CLIENT_STREAMING_METHOD);
213+
}
214+
215+
private void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(
216+
MethodDescriptor<Long, Long> method) {
217+
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
218+
stream,
219+
method,
220+
requestHeaders,
221+
context,
222+
DecompressorRegistry.getDefaultInstance(),
223+
CompressorRegistry.getDefaultInstance());
224+
serverCall.sendHeaders(new Metadata());
225+
serverCall.sendMessage(1L);
226+
serverCall.sendMessage(1L);
227+
verify(stream, times(1)).writeMessage(any(InputStream.class));
228+
verify(stream, times(1)).close(any(Status.class), any(Metadata.class));
229+
230+
// App runs to completion but everything is ignored
231+
serverCall.sendMessage(1L);
232+
serverCall.close(Status.OK, new Metadata());
233+
try {
234+
serverCall.close(Status.OK, new Metadata());
235+
fail("calling a second time should still cause an error");
236+
} catch (IllegalStateException expected) {
237+
// noop
238+
}
239+
}
240+
241+
@Test
242+
public void serverSendsOne_okFailsOnMissingResponse_unary() {
243+
serverSendsOne_okFailsOnMissingResponse(UNARY_METHOD);
244+
}
245+
246+
@Test
247+
public void serverSendsOne_okFailsOnMissingResponse_clientStreaming() {
248+
serverSendsOne_okFailsOnMissingResponse(CLIENT_STREAMING_METHOD);
249+
}
250+
251+
private void serverSendsOne_okFailsOnMissingResponse(
252+
MethodDescriptor<Long, Long> method) {
253+
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
254+
stream,
255+
method,
256+
requestHeaders,
257+
context,
258+
DecompressorRegistry.getDefaultInstance(),
259+
CompressorRegistry.getDefaultInstance());
260+
serverCall.close(Status.OK, new Metadata());
261+
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
262+
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
263+
verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture());
264+
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
265+
assertEquals(ServerCallImpl.MISSING_RESPONSE, statusCaptor.getValue().getDescription());
266+
assertTrue(metadataCaptor.getValue().keys().isEmpty());
267+
}
268+
269+
@Test
270+
public void serverSendsOne_canErrorWithoutResponse() {
271+
final String description = "test description";
272+
final Status status = Status.RESOURCE_EXHAUSTED.withDescription(description);
273+
final Metadata metadata = new Metadata();
274+
call.close(status, metadata);
275+
verify(stream, times(1)).close(same(status), same(metadata));
276+
}
277+
161278
@Test
162279
public void isReady() {
163280
when(stream.isReady()).thenReturn(true);
@@ -260,34 +377,20 @@ public void streamListener_onReady_onlyOnce() {
260377
public void streamListener_messageRead() {
261378
ServerStreamListenerImpl<Long> streamListener =
262379
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
263-
streamListener.messageRead(method.streamRequest(1234L));
264-
265-
verify(callListener).onMessage(1234L);
266-
}
267-
268-
@Test
269-
public void streamListener_messageRead_unaryFailsOnMultiple() {
270-
ServerStreamListenerImpl<Long> streamListener =
271-
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
272-
streamListener.messageRead(method.streamRequest(1234L));
273-
streamListener.messageRead(method.streamRequest(1234L));
380+
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));
274381

275-
// Makes sure this was only called once.
276382
verify(callListener).onMessage(1234L);
277-
278-
verify(stream).close(statusCaptor.capture(), Mockito.isA(Metadata.class));
279-
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
280383
}
281384

282385
@Test
283386
public void streamListener_messageRead_onlyOnce() {
284387
ServerStreamListenerImpl<Long> streamListener =
285388
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
286-
streamListener.messageRead(method.streamRequest(1234L));
389+
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));
287390
// canceling the call should short circuit future halfClosed() calls.
288391
streamListener.closed(Status.CANCELLED);
289392

290-
streamListener.messageRead(method.streamRequest(1234L));
393+
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));
291394

292395
verify(callListener).onMessage(1234L);
293396
}
@@ -300,7 +403,7 @@ public void streamListener_unexpectedRuntimeException() {
300403
.when(callListener)
301404
.onMessage(any(Long.class));
302405

303-
InputStream inputStream = method.streamRequest(1234L);
406+
InputStream inputStream = UNARY_METHOD.streamRequest(1234L);
304407

305408
thrown.expect(RuntimeException.class);
306409
thrown.expectMessage("unexpected exception");

examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -300,52 +300,6 @@ public void onCompleted() {
300300
verify(testHelper, never()).onRpcError(any(Throwable.class));
301301
}
302302

303-
/**
304-
* Example for testing async client-streaming.
305-
*/
306-
@Test
307-
public void recordRoute_wrongResponse() throws Exception {
308-
client.setRandom(noRandomness);
309-
Point point1 = Point.newBuilder().setLatitude(1).setLongitude(1).build();
310-
final Feature requestFeature1 =
311-
Feature.newBuilder().setLocation(point1).build();
312-
final List<Feature> features = Arrays.asList(requestFeature1);
313-
314-
// implement the fake service
315-
RouteGuideImplBase recordRouteImpl =
316-
new RouteGuideImplBase() {
317-
@Override
318-
public StreamObserver<Point> recordRoute(StreamObserver<RouteSummary> responseObserver) {
319-
RouteSummary response = RouteSummary.getDefaultInstance();
320-
// sending more than one responses is not right for client-streaming call.
321-
responseObserver.onNext(response);
322-
responseObserver.onNext(response);
323-
responseObserver.onCompleted();
324-
325-
return new StreamObserver<Point>() {
326-
@Override
327-
public void onNext(Point value) {
328-
}
329-
330-
@Override
331-
public void onError(Throwable t) {
332-
}
333-
334-
@Override
335-
public void onCompleted() {
336-
}
337-
};
338-
}
339-
};
340-
serviceRegistry.addService(recordRouteImpl);
341-
342-
client.recordRoute(features, 4);
343-
344-
ArgumentCaptor<Throwable> errorCaptor = ArgumentCaptor.forClass(Throwable.class);
345-
verify(testHelper).onRpcError(errorCaptor.capture());
346-
assertEquals(Status.Code.CANCELLED, Status.fromThrowable(errorCaptor.getValue()).getCode());
347-
}
348-
349303
/**
350304
* Example for testing async client-streaming.
351305
*/

0 commit comments

Comments
 (0)