diff --git a/src/main/java/org/xbill/DNS/DohResolver.java b/src/main/java/org/xbill/DNS/DohResolver.java index 0933ef18..23dc871a 100644 --- a/src/main/java/org/xbill/DNS/DohResolver.java +++ b/src/main/java/org/xbill/DNS/DohResolver.java @@ -174,6 +174,11 @@ public final class DohResolver implements Resolver { USE_HTTP_CLIENT = initSuccess; } + // package-visible for testing + long getNanoTime() { + return System.nanoTime(); + } + /** * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). * @@ -315,7 +320,7 @@ public CompletionStage sendAsync(Message query, Executor executor) { private CompletionStage sendAsync8(final Message query, Executor executor) { byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); - long startTime = System.nanoTime(); + long startTime = getNanoTime(); return maxConcurrentRequests .acquire(timeout) .handleAsync( @@ -363,7 +368,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( ((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory); } - Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); conn.setConnectTimeout((int) remainingTimeout.toMillis()); conn.setReadTimeout((int) remainingTimeout.toMillis()); conn.setRequestMethod(usePost ? "POST" : "GET"); @@ -389,7 +394,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( int offset = 0; while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) { offset += r; - remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { throw new SocketTimeoutException(); } @@ -403,7 +408,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( byte[] buffer = new byte[4096]; int r; while ((r = is.read(buffer, 0, buffer.length)) > 0) { - remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { throw new SocketTimeoutException(); } @@ -432,7 +437,7 @@ private void discardStream(InputStream es) throws IOException { } private CompletionStage sendAsync11(final Message query, Executor executor) { - long startTime = System.nanoTime(); + long startTime = getNanoTime(); byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); @@ -454,7 +459,7 @@ private CompletionStage sendAsync11(final Message query, Executor execu // check if this request needs to be done synchronously because of HttpClient's stupidity to // not use the connection pool for HTTP/2 until one connection is successfully established, // which could lead to hundreds of connections (and threads with the default executor) - Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); return initialRequestLock .acquire(remainingTimeout) .handle( @@ -476,14 +481,13 @@ private CompletionStage sendAsync11WithInitialRequestPermit( Object requestBuilder, Permit initialRequestPermit) { long lastRequestTime = lastRequest.get(); - boolean isInitialRequest = - (lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos()); + boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime; if (!isInitialRequest) { initialRequestPermit.release(); } // check if we already exceeded the query timeout while checking the initial connection - Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { if (isInitialRequest) { initialRequestPermit.release(); @@ -525,7 +529,7 @@ private CompletionStage sendAsync11WithConcurrentRequestPermit( boolean isInitialRequest, Permit maxConcurrentRequestPermit) { // check if the stream lock acquisition took too long - Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { if (isInitialRequest) { initialRequestPermit.release(); diff --git a/src/test/java/org/xbill/DNS/DohResolverTest.java b/src/test/java/org/xbill/DNS/DohResolverTest.java index c3931d63..26daf19c 100644 --- a/src/test/java/org/xbill/DNS/DohResolverTest.java +++ b/src/test/java/org/xbill/DNS/DohResolverTest.java @@ -3,6 +3,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; import io.netty.handler.codec.http.HttpHeaderNames; import io.vertx.core.Future; @@ -20,13 +22,20 @@ import java.time.Duration; import java.util.Base64; import java.util.Collections; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.stubbing.Answer; @ExtendWith(VertxExtension.class) class DohResolverTest { @@ -202,6 +211,75 @@ private Future setupResolverWithServer( .onSuccess(server -> resolver.setUriTemplate("http://localhost:" + server.actualPort())); } + @EnabledForJreRange( + min = JRE.JAVA_9, + disabledReason = "Java 8 implementation doesn't have the initial request guard") + @Test + void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( + Vertx vertx, VertxTestContext context) { + AtomicLong startNanos = new AtomicLong(System.nanoTime()); + resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2))); + resolver.setTimeout(Duration.ofSeconds(1)); + // Simulate a nanoTime value that is lower than the idle timeout + doAnswer((Answer) invocationOnMock -> System.nanoTime() - startNanos.get()) + .when(resolver) + .getNanoTime(); + + // Just add a 100ms delay before responding to the 1st call + // to simulate a 'concurrent doh request' for the 2nd call, + // then let the fake dns server respond to the 2nd call ASAP. + allRequestsUseTimeout = false; + + // idleConnectionTimeout = 2s, lastRequest = 0L + // Ensure idleConnectionTimeout < System.nanoTime() - lastRequest (3s) + + // Timeline: + // |<-------- 100ms -------->| + // ↑ ↑ + // 1st call sent response of 1st call + // |20ms|<------ 80ms ------>|<------ few millis ------->| + // ↑ wait until 1st call ↑ ↑ + // 2nd call begin 2nd call sent response of 2nd call + + AtomicBoolean firstCallCompleted = new AtomicBoolean(false); + + setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context) + .onSuccess( + server -> { + // First call + CompletionStage firstCall = resolver.sendAsync(qm); + // Ensure second call was made after first call and uses a different query + startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20)); + CompletionStage secondCall = resolver.sendAsync(Message.newQuery(qr)); + + Future.fromCompletionStage(firstCall) + .onComplete( + context.succeeding( + result -> + context.verify( + () -> { + assertEquals(Rcode.NOERROR, result.getHeader().getRcode()); + assertEquals(0, result.getHeader().getID()); + assertEquals(queryName, result.getQuestion().getName()); + firstCallCompleted.set(true); + }))); + + Future.fromCompletionStage(secondCall) + .onComplete( + context.succeeding( + result -> + context.verify( + () -> { + assertTrue(firstCallCompleted.get()); + assertEquals(Rcode.NOERROR, result.getHeader().getRcode()); + assertEquals(0, result.getHeader().getID()); + assertEquals(queryName, result.getQuestion().getName()); + // Complete context after the 2nd call was completed. + context.completeNow(); + }))); + }); + } + private Future setupServer( Message expectedDnsRequest, Message dnsResponse,