From 8d54d90398e2ffa7288d1465e6583cc81c3fc6fd Mon Sep 17 00:00:00 2001 From: Ingo Bauersachs Date: Sun, 22 Jun 2025 22:19:29 +0200 Subject: [PATCH] Split DoH Resolver into Java versions, fix initial request and tests --- pom.xml | 137 ++-- .../java/org/xbill/DNS/AsyncSemaphore.java | 27 +- src/main/java/org/xbill/DNS/DohResolver.java | 636 +++--------------- .../java/org/xbill/DNS/DohResolverCommon.java | 232 +++++++ .../xbill/DNS/TimeoutCompletableFuture.java | 47 +- .../java11/org/xbill/DNS/AsyncSemaphore.java | 66 ++ .../java11/org/xbill/DNS/DohResolver.java | 311 +++++++++ .../java/org/xbill/DNS/DohResolverTest.java | 103 +-- 8 files changed, 902 insertions(+), 657 deletions(-) create mode 100644 src/main/java/org/xbill/DNS/DohResolverCommon.java create mode 100644 src/main/java11/org/xbill/DNS/AsyncSemaphore.java create mode 100644 src/main/java11/org/xbill/DNS/DohResolver.java diff --git a/pom.xml b/pom.xml index 1f941117..bf65a05e 100644 --- a/pom.xml +++ b/pom.xml @@ -371,6 +371,12 @@ true PATCH + + SUPERCLASS_ADDED + true + true + PATCH + ANNOTATION_DEPRECATED_ADDED PATCH @@ -497,41 +503,6 @@ - - org.codehaus.mojo - animal-sniffer-maven-plugin - 1.24 - - - net.sf.androidscents.signature - android-api-level-26 - 8.0.0_r2 - - - javax.naming.NamingException - javax.naming.directory.* - sun.net.spi.nameservice.* - java.net.spi.* - - - - - org.ow2.asm - asm - 9.7.1 - - - - - animal-sniffer - test - - check - - - - - org.apache.maven.plugins maven-enforcer-plugin @@ -719,6 +690,60 @@ ${target.jdk} + + + org.codehaus.mojo + animal-sniffer-maven-plugin + 1.24 + + + com.toasttab.android + gummy-bears-api-26 + 0.12.0 + + + javax.naming.NamingException + javax.naming.directory.* + sun.net.spi.nameservice.* + java.net.spi.* + + + + + org.ow2.asm + asm + 9.8 + + + + + animal-sniffer + test + + check + + + + + + + org.jacoco + jacoco-maven-plugin + + + report + verify + + report + + + + META-INF/** + + + + + @@ -815,18 +840,38 @@ @{argLine} --add-opens java.base/sun.net.dns=ALL-UNNAMED + + + false + ${project.build.outputDirectory}/META-INF/versions/11 - ${project.build.outputDirectory}/META-INF/versions/11 + ${project.build.outputDirectory} + + + org.jacoco + jacoco-maven-plugin + + + org/xbill/DNS/AsyncSemaphore* + org/xbill/DNS/DohResolver* + + + - - + java11-not-idea false @@ -910,10 +955,18 @@ @{argLine} --add-opens java.base/sun.net.dns=ALL-UNNAMED -javaagent:${net.bytebuddy:byte-buddy-agent:jar} + -javaagent:${org.mockito:mockito-core:jar} + + + false + ${project.build.outputDirectory}/META-INF/versions/18 ${project.build.outputDirectory}/META-INF/versions/11 - ${project.build.outputDirectory}/META-INF/versions/18 + ${project.build.outputDirectory} @@ -922,8 +975,10 @@ - - + java18-not-idea false diff --git a/src/main/java/org/xbill/DNS/AsyncSemaphore.java b/src/main/java/org/xbill/DNS/AsyncSemaphore.java index 456b8ad3..0d704913 100644 --- a/src/main/java/org/xbill/DNS/AsyncSemaphore.java +++ b/src/main/java/org/xbill/DNS/AsyncSemaphore.java @@ -6,6 +6,7 @@ import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import lombok.extern.slf4j.Slf4j; @@ -13,34 +14,50 @@ final class AsyncSemaphore { private final Queue> queue = new ArrayDeque<>(); private final Permit singletonPermit = new Permit(); + private final String name; private volatile int permits; final class Permit { - public void release() { + public void release(int id, Executor executor) { synchronized (queue) { CompletableFuture next = queue.poll(); if (next == null) { permits++; + log.trace("{} permit released id={}, available={}", name, id, permits); } else { - next.complete(this); + log.trace("{} permit released id={}, available={}, immediate next", name, id, permits); + executor.execute(() -> next.complete(this)); } } } } - AsyncSemaphore(int permits) { + AsyncSemaphore(int permits, String name) { this.permits = permits; + this.name = name; + log.debug("Using Java 8 implementation for {}", name); } - CompletionStage acquire(Duration timeout) { + CompletionStage acquire(Duration timeout, int id, Executor executor) { synchronized (queue) { if (permits > 0) { permits--; + log.trace("{} permit acquired id={}, available={}", name, id, permits); return CompletableFuture.completedFuture(singletonPermit); } else { TimeoutCompletableFuture f = new TimeoutCompletableFuture<>(); f.compatTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS) - .whenComplete((result, ex) -> queue.remove(f)); + .whenCompleteAsync( + (result, ex) -> { + synchronized (queue) { + if (ex != null) { + log.trace("{} permit timed out id={}, available={}", name, id, permits); + } + queue.remove(f); + } + }, + executor); + log.trace("{} permit queued id={}, available={}", name, id, permits); queue.add(f); return f; } diff --git a/src/main/java/org/xbill/DNS/DohResolver.java b/src/main/java/org/xbill/DNS/DohResolver.java index b6707bbd..fa150cc0 100644 --- a/src/main/java/org/xbill/DNS/DohResolver.java +++ b/src/main/java/org/xbill/DNS/DohResolver.java @@ -5,8 +5,6 @@ import java.io.EOFException; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.net.HttpURLConnection; import java.net.SocketTimeoutException; import java.net.URI; @@ -14,25 +12,19 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.time.temporal.ChronoUnit; -import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.WeakHashMap; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; -import lombok.SneakyThrows; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.xbill.DNS.AsyncSemaphore.Permit; -import org.xbill.DNS.utils.base64; /** * Proof-of-concept DNS over HTTP (DoH) @@ -51,141 +43,16 @@ * @since 3.0 */ @Slf4j -public final class DohResolver implements Resolver { - private static final boolean USE_HTTP_CLIENT; - private static final Map httpClients = - Collections.synchronizedMap(new WeakHashMap<>()); +public final class DohResolver extends DohResolverCommon { private final SSLSocketFactory sslSocketFactory; - private static Object defaultHttpRequestBuilder; - private static Method publisherOfByteArrayMethod; - private static Method requestBuilderTimeoutMethod; - private static Method requestBuilderCopyMethod; - private static Method requestBuilderUriMethod; - private static Method requestBuilderBuildMethod; - private static Method requestBuilderPostMethod; - - private static Method httpClientNewBuilderMethod; - private static Method httpClientBuilderTimeoutMethod; - private static Method httpClientBuilderExecutorMethod; - private static Method httpClientBuilderBuildMethod; - private static Method httpClientSendAsyncMethod; - - private static Method byteArrayBodyPublisherMethod; - private static Method httpResponseBodyMethod; - private static Method httpResponseStatusCodeMethod; - - private boolean usePost = false; - private Duration timeout = Duration.ofSeconds(5); - private String uriTemplate; - private final Duration idleConnectionTimeout; - private OPTRecord queryOPT = new OPTRecord(0, 0, 0); - private TSIG tsig; - private Executor defaultExecutor = ForkJoinPool.commonPool(); - - /** - * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections. - * - *

rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2. - */ - private final AsyncSemaphore maxConcurrentRequests; - - private final AtomicLong lastRequest = new AtomicLong(0); - private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1); - - private static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; - - static { - boolean initSuccess = false; - if (!System.getProperty("java.version").startsWith("1.")) { - try { - Class httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder"); - Class httpClientClass = Class.forName("java.net.http.HttpClient"); - Class httpVersionEnum = Class.forName("java.net.http.HttpClient$Version"); - Class httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder"); - Class httpRequestClass = Class.forName("java.net.http.HttpRequest"); - Class bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers"); - Class bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher"); - Class httpResponseClass = Class.forName("java.net.http.HttpResponse"); - Class bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers"); - Class bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler"); - - // HttpClient.Builder - httpClientBuilderTimeoutMethod = - httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class); - httpClientBuilderExecutorMethod = - httpClientBuilderClass.getDeclaredMethod("executor", Executor.class); - httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build"); - - // HttpClient - httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder"); - httpClientSendAsyncMethod = - httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass); - - // HttpRequestBuilder - Method requestBuilderHeaderMethod = - httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class); - Method requestBuilderVersionMethod = - httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum); - requestBuilderTimeoutMethod = - httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class); - requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class); - requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy"); - requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build"); - requestBuilderPostMethod = - httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass); - - // HttpRequest - Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder"); - - // BodyPublishers - publisherOfByteArrayMethod = - bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class); - - // BodyPublisher - byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray"); - - // HttpResponse - httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body"); - httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode"); - - // defaultHttpRequestBuilder = HttpRequest.newBuilder(); - // defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2); - // defaultHttpRequestBuilder.header("Content-Type", "application/dns-message"); - // defaultHttpRequestBuilder.header("Accept", "application/dns-message"); - defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null); - @SuppressWarnings({"unchecked", "rawtypes"}) - Enum http2Version = Enum.valueOf((Class) httpVersionEnum, "HTTP_2"); - requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version); - requestBuilderHeaderMethod.invoke( - defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE); - requestBuilderHeaderMethod.invoke( - defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE); - initSuccess = true; - } catch (ClassNotFoundException - | NoSuchMethodException - | IllegalAccessException - | InvocationTargetException e) { - // fallback to Java 8 - log.warn("Java >= 11 detected, but HttpRequest not available"); - } - } - - USE_HTTP_CLIENT = initSuccess; - } - - // package-visible for testing - long getNanoTime() { - return System.nanoTime(); - } - /** * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). * * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ public DohResolver(String uriTemplate) { - this(uriTemplate, 100, Duration.ofMinutes(2)); + this(uriTemplate, 100, Duration.ZERO); } /** @@ -201,22 +68,9 @@ public DohResolver(String uriTemplate) { */ public DohResolver( String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) { - this.uriTemplate = uriTemplate; - this.idleConnectionTimeout = idleConnectionTimeout; - if (maxConcurrentRequests <= 0) { - throw new IllegalArgumentException("maxConcurrentRequests must be > 0"); - } - if (!USE_HTTP_CLIENT) { - try { - int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5")); - if (maxConcurrentRequests > javaMaxConn) { - maxConcurrentRequests = javaMaxConn; - } - } catch (NumberFormatException nfe) { - // well, use what we got - } - } - this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests); + super(uriTemplate, maxConcurrentRequests); + + log.debug("Using Java 8 implementation"); try { sslSocketFactory = SSLContext.getDefault().getSocketFactory(); } catch (NoSuchAlgorithmException e) { @@ -224,45 +78,6 @@ public DohResolver( } } - @SneakyThrows - private Object getHttpClient(Executor executor) { - return httpClients.computeIfAbsent( - executor, - key -> { - try { - // return HttpClient.newBuilder() - // .connectTimeout(timeout). - // .executor(executor) - // .build(); - Object httpClientBuilder = httpClientNewBuilderMethod.invoke(null); - httpClientBuilderTimeoutMethod.invoke(httpClientBuilder, timeout); - httpClientBuilderExecutorMethod.invoke(httpClientBuilder, key); - return httpClientBuilderBuildMethod.invoke(httpClientBuilder); - } catch (IllegalAccessException | InvocationTargetException e) { - log.warn("Could not create a HttpClient with for Executor {}", key, e); - return null; - } - }); - } - - /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */ - @Override - public void setPort(int port) { - // Not implemented, port is part of the URI - } - - /** Not implemented. */ - @Override - public void setTCP(boolean flag) { - // Not implemented, HTTP is always TCP - } - - /** Not implemented. */ - @Override - public void setIgnoreTruncation(boolean flag) { - // Not implemented, protocol uses TCP and doesn't have truncation - } - /** * Sets the EDNS information on outgoing messages. * @@ -272,87 +87,85 @@ public void setIgnoreTruncation(boolean flag) { * @param options EDNS options to be set in the OPT record */ @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility public void setEDNS(int version, int payloadSize, int flags, List options) { - switch (version) { - case -1: - queryOPT = null; - break; - - case 0: - queryOPT = new OPTRecord(0, 0, version, flags, options); - break; - - default: - throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable"); - } - } - - @Override - public void setTSIGKey(TSIG key) { - this.tsig = key; - } - - @Override - public void setTimeout(Duration timeout) { - this.timeout = timeout; - httpClients.clear(); - } - - @Override - public Duration getTimeout() { - return timeout; + // required for source- and binary compatibility + super.setEDNS(version, payloadSize, flags, options); } @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility public CompletionStage sendAsync(Message query) { - return sendAsync(query, defaultExecutor); + // required for source- and binary compatibility + return this.sendAsync(query, defaultExecutor); } @Override public CompletionStage sendAsync(Message query, Executor executor) { - if (USE_HTTP_CLIENT) { - return sendAsync11(query, executor); - } - - return sendAsync8(query, executor); - } - - private CompletionStage sendAsync8(final Message query, Executor executor) { byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); long startTime = getNanoTime(); - return maxConcurrentRequests - .acquire(timeout) - .handleAsync( - (permit, ex) -> { - if (ex != null) { - return this.timeoutFailedFuture(query, ex); - } else { - try { - SendAndGetMessageBytesResponse result = - sendAndGetMessageBytes(url, queryBytes, startTime); - Message response; - if (result.rc == Rcode.NOERROR) { - response = new Message(result.responseBytes); - verifyTSIG(query, response, result.responseBytes, tsig); + int queryId = query.getHeader().getID(); + + CompletableFuture f = + maxConcurrentRequests + .acquire(timeout, queryId, executor) + .handleAsync( + (permit, ex) -> { + if (ex != null) { + return this.timeoutFailedFuture( + query, "could not acquire lock to send request", ex); } else { - response = new Message(0); - response.getHeader().setRcode(result.rc); + try { + SendAndGetMessageBytesResponse result = + sendAndGetMessageBytes(url, queryBytes, startTime); + Message response; + if (result.rc == Rcode.NOERROR) { + response = new Message(result.responseBytes); + verifyTSIG(query, response, result.responseBytes, tsig); + } else { + response = new Message(0); + response.getHeader().setRcode(result.rc); + } + + response.setResolver(this); + return CompletableFuture.completedFuture(response); + } catch (SocketTimeoutException e) { + return this.timeoutFailedFuture(query, e); + } catch (IOException | URISyntaxException e) { + return this.failedFuture(e); + } finally { + permit.release(queryId, executor); + } } + }, + executor) + .thenCompose(Function.identity()) + .toCompletableFuture(); - response.setResolver(this); - return CompletableFuture.completedFuture(response); - } catch (SocketTimeoutException e) { - return this.timeoutFailedFuture(query, e); - } catch (IOException | URISyntaxException e) { - return this.failedFuture(e); - } finally { - permit.release(); - } + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + return TimeoutCompletableFuture.compatTimeout( + f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally( + ex -> { + if (ex instanceof TimeoutException) { + throw new CompletionException( + new TimeoutException( + "Query " + + queryId + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out in remaining " + + remainingTimeout.toMillis() + + "ms")); + } else if (ex instanceof CompletionException) { + throw (CompletionException) ex; } - }, - executor) - .thenCompose(Function.identity()); + + throw new CompletionException(ex); + }); } @Value @@ -367,15 +180,28 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( if (conn instanceof HttpsURLConnection) { ((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory); } - - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - conn.setConnectTimeout((int) remainingTimeout.toMillis()); - conn.setReadTimeout((int) remainingTimeout.toMillis()); conn.setRequestMethod(usePost ? "POST" : "GET"); conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE); conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE); + + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + throw new SocketTimeoutException("No time left to connect"); + } + + conn.setConnectTimeout((int) remainingTimeout.toMillis()); if (usePost) { conn.setDoOutput(true); + } + + conn.connect(); + remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + throw new SocketTimeoutException("No time left to request data"); + } + + conn.setReadTimeout((int) remainingTimeout.toMillis()); + if (usePost) { conn.getOutputStream().write(queryBytes); } @@ -395,13 +221,23 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) { offset += r; remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - throw new SocketTimeoutException(); + + // Don't throw if we just received all data + if (offset != responseBytes.length + && (remainingTimeout.isNegative() || remainingTimeout.isZero())) { + throw new SocketTimeoutException( + "Timed out waiting for response data, got " + + offset + + " of " + + responseBytes.length + + " expected bytes"); } } + if (offset < responseBytes.length) { throw new EOFException("Could not read expected content length"); } + return new SendAndGetMessageBytesResponse(Rcode.NOERROR, responseBytes); } else { try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { @@ -409,8 +245,9 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( int r; while ((r = is.read(buffer, 0, buffer.length)) > 0) { remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - throw new SocketTimeoutException(); + if (remainingTimeout.isNegative() || remainingTimeout.isZero()) { + throw new SocketTimeoutException( + "Timed out waiting for response data, got " + bos.size() + " bytes so far"); } bos.write(buffer, 0, r); } @@ -436,275 +273,10 @@ private void discardStream(InputStream es) throws IOException { } } - private CompletionStage sendAsync11(final Message query, Executor executor) { - long startTime = getNanoTime(); - byte[] queryBytes = prepareQuery(query).toWire(); - String url = getUrl(queryBytes); - - // var requestBuilder = defaultHttpRequestBuilder.copy(); - // requestBuilder.uri(URI.create(url)); - Object requestBuilder; - try { - requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder); - requestBuilderUriMethod.invoke(requestBuilder, URI.create(url)); - if (usePost) { - // requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes)); - requestBuilderPostMethod.invoke( - requestBuilder, publisherOfByteArrayMethod.invoke(null, queryBytes)); - } - } catch (IllegalAccessException | InvocationTargetException e) { - return failedFuture(e); - } - - // check if this request needs to be done synchronously because of HttpClient's stupidity to - // not use the connection pool for HTTP/2 until one connection is successfully established, - // which could lead to hundreds of connections (and threads with the default executor) - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - return initialRequestLock - .acquire(remainingTimeout) - .handle( - (initialRequestPermit, initialRequestEx) -> { - if (initialRequestEx != null) { - return this.timeoutFailedFuture(query, initialRequestEx); - } else { - return sendAsync11WithInitialRequestPermit( - query, executor, startTime, requestBuilder, initialRequestPermit); - } - }) - .thenCompose(Function.identity()); - } - - private CompletionStage sendAsync11WithInitialRequestPermit( - Message query, - Executor executor, - long startTime, - Object requestBuilder, - Permit initialRequestPermit) { - long lastRequestTime = lastRequest.get(); - boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime; - if (!isInitialRequest) { - initialRequestPermit.release(); - } - - // check if we already exceeded the query timeout while checking the initial connection - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - return timeoutFailedFuture(query, null); - } - - // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the - // request, but fail with an IOException which also CLOSES the connection... *facepalm* - return maxConcurrentRequests - .acquire(remainingTimeout) - .handle( - (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> { - if (maxConcurrentRequestEx != null) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - return this.timeoutFailedFuture(query, maxConcurrentRequestEx); - } else { - return sendAsync11WithConcurrentRequestPermit( - query, - executor, - startTime, - requestBuilder, - initialRequestPermit, - isInitialRequest, - maxConcurrentRequestPermit); - } - }) - .thenCompose(Function.identity()); - } - - private CompletionStage sendAsync11WithConcurrentRequestPermit( - Message query, - Executor executor, - long startTime, - Object requestBuilder, - Permit initialRequestPermit, - boolean isInitialRequest, - Permit maxConcurrentRequestPermit) { - // check if the stream lock acquisition took too long - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - maxConcurrentRequestPermit.release(); - return timeoutFailedFuture(query, null); - } - - // var httpRequest = requestBuilder.timeout(remainingTimeout).build(); - // var bodyHandler = HttpResponse.BodyHandlers.ofByteArray(); - // return getHttpClient(executor).sendAsync(httpRequest, bodyHandler) - try { - Object httpClient = getHttpClient(executor); - requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout); - Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder); - Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null); - CompletableFuture f = - ((CompletableFuture) - httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler)) - .whenComplete( - (result, ex) -> { - if (ex == null) { - lastRequest.set(startTime); - } - maxConcurrentRequestPermit.release(); - if (isInitialRequest) { - initialRequestPermit.release(); - } - }) - .handleAsync( - (response, ex) -> { - if (ex != null) { - if (ex.getCause().getClass().getSimpleName().equals("HttpTimeoutException")) { - return this.timeoutFailedFuture(query, ex.getCause()); - } else { - return this.failedFuture(ex); - } - } else { - try { - Message responseMessage; - // int rc = response.statusCode(); - int rc = (int) httpResponseStatusCodeMethod.invoke(response); - if (rc >= 200 && rc < 300) { - // byte[] responseBytes = response.body(); - byte[] responseBytes = (byte[]) httpResponseBodyMethod.invoke(response); - responseMessage = new Message(responseBytes); - verifyTSIG(query, responseMessage, responseBytes, tsig); - } else { - responseMessage = new Message(); - responseMessage.getHeader().setRcode(Rcode.SERVFAIL); - } - - responseMessage.setResolver(this); - return CompletableFuture.completedFuture(responseMessage); - } catch (IOException | IllegalAccessException | InvocationTargetException e) { - return this.failedFuture(e); - } - } - }, - executor) - .thenCompose(Function.identity()); - return TimeoutCompletableFuture.compatTimeout( - f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS); - } catch (IllegalAccessException | InvocationTargetException e) { - return failedFuture(e); - } - } - - private CompletableFuture failedFuture(Throwable e) { + @Override + protected CompletableFuture failedFuture(Throwable e) { CompletableFuture f = new CompletableFuture<>(); f.completeExceptionally(e); return f; } - - private CompletableFuture timeoutFailedFuture(Message query, Throwable inner) { - return failedFuture( - new IOException( - "Query " - + query.getHeader().getID() - + " for " - + query.getQuestion().getName() - + "/" - + Type.string(query.getQuestion().getType()) - + " timed out", - inner)); - } - - private String getUrl(byte[] queryBytes) { - String url = uriTemplate; - if (!usePost) { - url += "?dns=" + base64.toString(queryBytes, true); - } - return url; - } - - private Message prepareQuery(Message query) { - Message preparedQuery = query.clone(); - preparedQuery.getHeader().setID(0); - if (queryOPT != null && preparedQuery.getOPT() == null) { - preparedQuery.addRecord(queryOPT, Section.ADDITIONAL); - } - - if (tsig != null) { - tsig.apply(preparedQuery, null); - } - - return preparedQuery; - } - - private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { - if (tsig == null) { - return; - } - - int error = tsig.verify(response, b, query.getGeneratedTSIG()); - log.debug( - "TSIG verify for query {}, {}/{}: {}", - query.getHeader().getID(), - query.getQuestion().getName(), - Type.string(query.getQuestion().getType()), - Rcode.TSIGstring(error)); - } - - /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */ - public boolean isUsePost() { - return usePost; - } - - /** - * Sets the HTTP method to use for resolving. - * - * @param usePost {@code true} to use POST, {@code false} to use GET (the default). - */ - public void setUsePost(boolean usePost) { - this.usePost = usePost; - } - - /** Gets the current URI used for resolving. */ - public String getUriTemplate() { - return uriTemplate; - } - - /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ - public void setUriTemplate(String uriTemplate) { - this.uriTemplate = uriTemplate; - } - - /** - * Gets the default {@link Executor} for request handling, defaults to {@link - * ForkJoinPool#commonPool()}. - * - * @since 3.3 - * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used. - */ - @Deprecated - public Executor getExecutor() { - return defaultExecutor; - } - - /** - * Sets the default {@link Executor} for request handling. - * - * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link - * ForkJoinPool#commonPool()}). - * @since 3.3 - * @deprecated Use {@link #sendAsync(Message, Executor)}. - */ - @Deprecated - public void setExecutor(Executor executor) { - this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor; - httpClients.clear(); - } - - @Override - public String toString() { - return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}"; - } } diff --git a/src/main/java/org/xbill/DNS/DohResolverCommon.java b/src/main/java/org/xbill/DNS/DohResolverCommon.java new file mode 100644 index 00000000..0fb8ac07 --- /dev/null +++ b/src/main/java/org/xbill/DNS/DohResolverCommon.java @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLong; +import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.utils.base64; + +@Slf4j +abstract class DohResolverCommon implements Resolver { + /** + * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections. + * + *

rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2. + */ + protected final AsyncSemaphore maxConcurrentRequests; + + protected final AtomicLong lastRequest = new AtomicLong(0); + + protected static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; + + protected boolean usePost = false; + protected Duration timeout = Duration.ofSeconds(5); + protected String uriTemplate; + protected OPTRecord queryOPT = new OPTRecord(0, 0, 0); + protected TSIG tsig; + protected Executor defaultExecutor = ForkJoinPool.commonPool(); + + // package-visible for testing + long getNanoTime() { + return System.nanoTime(); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1 + * connections for Java 8. On Java 8 this cannot exceed the system property {@code + * http.maxConnections}. + */ + protected DohResolverCommon(String uriTemplate, int maxConcurrentRequests) { + this.uriTemplate = uriTemplate; + if (maxConcurrentRequests <= 0) { + throw new IllegalArgumentException("maxConcurrentRequests must be > 0"); + } + + try { + int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5")); + if (maxConcurrentRequests > javaMaxConn) { + maxConcurrentRequests = javaMaxConn; + } + } catch (NumberFormatException nfe) { + // well, use what we got + } + + this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests, "concurrent"); + } + + /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */ + @Override + public void setPort(int port) { + // Not implemented, port is part of the URI + } + + /** Not implemented. */ + @Override + public void setTCP(boolean flag) { + // Not implemented, HTTP is always TCP + } + + /** Not implemented. */ + @Override + public void setIgnoreTruncation(boolean flag) { + // Not implemented, protocol uses TCP and doesn't have truncation + } + + /** + * Sets the EDNS information on outgoing messages. + * + * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS. + * @param payloadSize ignored + * @param flags EDNS extended flags to be set in the OPT record. + * @param options EDNS options to be set in the OPT record + */ + @Override + public void setEDNS(int version, int payloadSize, int flags, List options) { + switch (version) { + case -1: + queryOPT = null; + break; + + case 0: + queryOPT = new OPTRecord(0, 0, version, flags, options); + break; + + default: + throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable"); + } + } + + @Override + public void setTSIGKey(TSIG key) { + this.tsig = key; + } + + @Override + public void setTimeout(Duration timeout) { + this.timeout = timeout; + } + + @Override + public Duration getTimeout() { + return timeout; + } + + protected String getUrl(byte[] queryBytes) { + String url = uriTemplate; + if (!usePost) { + url += "?dns=" + base64.toString(queryBytes, true); + } + return url; + } + + protected Message prepareQuery(Message query) { + Message preparedQuery = query.clone(); + preparedQuery.getHeader().setID(0); + if (queryOPT != null && preparedQuery.getOPT() == null) { + preparedQuery.addRecord(queryOPT, Section.ADDITIONAL); + } + + if (tsig != null) { + tsig.apply(preparedQuery, null); + } + + return preparedQuery; + } + + protected void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { + if (tsig == null) { + return; + } + + int error = tsig.verify(response, b, query.getGeneratedTSIG()); + log.debug( + "TSIG verify for query {}, {}/{}: {}", + query.getHeader().getID(), + query.getQuestion().getName(), + Type.string(query.getQuestion().getType()), + Rcode.TSIGstring(error)); + } + + /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */ + public boolean isUsePost() { + return usePost; + } + + /** + * Sets the HTTP method to use for resolving. + * + * @param usePost {@code true} to use POST, {@code false} to use GET (the default). + */ + public void setUsePost(boolean usePost) { + this.usePost = usePost; + } + + /** Gets the current URI used for resolving. */ + public String getUriTemplate() { + return uriTemplate; + } + + /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ + public void setUriTemplate(String uriTemplate) { + this.uriTemplate = uriTemplate; + } + + /** + * Gets the default {@link Executor} for request handling, defaults to {@link + * ForkJoinPool#commonPool()}. + * + * @since 3.3 + * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used. + */ + @Deprecated + public Executor getExecutor() { + return defaultExecutor; + } + + /** + * Sets the default {@link Executor} for request handling. + * + * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link + * ForkJoinPool#commonPool()}). + * @since 3.3 + * @deprecated Use {@link #sendAsync(Message, Executor)}. + */ + @Deprecated + public void setExecutor(Executor executor) { + this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor; + } + + @Override + public String toString() { + return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}"; + } + + protected abstract CompletableFuture failedFuture(Throwable e); + + protected final CompletableFuture timeoutFailedFuture(Message query, Throwable inner) { + return timeoutFailedFuture(query, null, inner); + } + + protected final CompletableFuture timeoutFailedFuture( + Message query, String message, Throwable inner) { + return failedFuture( + new TimeoutException( + "Query " + + query.getHeader().getID() + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out" + + (message != null ? ": " + message : "") + + (inner != null && inner.getMessage() != null ? ", " + inner.getMessage() : ""))); + } +} diff --git a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java index 24796df6..e2ae60f6 100644 --- a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java +++ b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java @@ -1,8 +1,6 @@ // SPDX-License-Identifier: BSD-3-Clause package org.xbill.DNS; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -10,57 +8,28 @@ import java.util.concurrent.TimeoutException; import lombok.extern.slf4j.Slf4j; -/** - * Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. On Java 9+ - * the built-in method is called. - */ +/** Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. */ @Slf4j class TimeoutCompletableFuture extends CompletableFuture { - private static final Method orTimeoutMethod; - - static { - Method localOrTimeoutMethod; - if (!System.getProperty("java.version").startsWith("1.")) { - try { - localOrTimeoutMethod = - CompletableFuture.class.getMethod("orTimeout", long.class, TimeUnit.class); - } catch (NoSuchMethodException e) { - localOrTimeoutMethod = null; - log.warn( - "CompletableFuture.orTimeout method not found in Java 9+, using custom implementation", - e); - } - } else { - localOrTimeoutMethod = null; - } - orTimeoutMethod = localOrTimeoutMethod; - } - public CompletableFuture compatTimeout(long timeout, TimeUnit unit) { return compatTimeout(this, timeout, unit); } - @SuppressWarnings("unchecked") public static CompletableFuture compatTimeout( CompletableFuture f, long timeout, TimeUnit unit) { - if (orTimeoutMethod == null) { - return orTimeout(f, timeout, unit); - } else { - try { - return (CompletableFuture) orTimeoutMethod.invoke(f, timeout, unit); - } catch (IllegalAccessException | InvocationTargetException e) { - return orTimeout(f, timeout, unit); - } + if (timeout <= 0) { + f.completeExceptionally(new TimeoutException("timeout is " + timeout + ", but must be > 0")); } - } - private static CompletableFuture orTimeout( - CompletableFuture f, long timeout, TimeUnit unit) { ScheduledFuture sf = TimeoutScheduler.executor.schedule( () -> { if (!f.isDone()) { - f.completeExceptionally(new TimeoutException()); + f.completeExceptionally( + new TimeoutException( + "Timeout of " + + unit.toMillis(timeout) + + "ms has elapsed before the task completed")); } }, timeout, diff --git a/src/main/java11/org/xbill/DNS/AsyncSemaphore.java b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java new file mode 100644 index 00000000..d2dcf6dd --- /dev/null +++ b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +final class AsyncSemaphore { + private final Queue> queue = new ArrayDeque<>(); + private final Permit singletonPermit = new Permit(); + private final String name; + private volatile int permits; + + final class Permit { + public void release(int id, Executor executor) { + synchronized (queue) { + CompletableFuture next = queue.poll(); + if (next == null) { + permits++; + log.trace("{} permit released id={}, available={}", name, id, permits); + } else { + log.trace("{} permit released id={}, available={}, immediate next", name, id, permits); + next.completeAsync(() -> this, executor); + } + } + } + } + + AsyncSemaphore(int permits, String name) { + this.permits = permits; + this.name = name; + log.debug("Using Java 11+ implementation for {}", name); + } + + CompletionStage acquire(Duration timeout, int id, Executor executor) { + synchronized (queue) { + if (permits > 0) { + permits--; + log.trace("{} permit acquired id={}, available={}", name, id, permits); + return CompletableFuture.completedFuture(singletonPermit); + } else { + CompletableFuture f = new CompletableFuture<>(); + f.orTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS) + .whenCompleteAsync( + (result, ex) -> { + synchronized (queue) { + if (ex != null) { + log.trace("{} permit timed out id={}, available={}", name, id, permits); + } + queue.remove(f); + } + }, + executor); + log.trace("{} permit queued id={}, available={}", name, id, permits); + queue.add(f); + return f; + } + } + } +} diff --git a/src/main/java11/org/xbill/DNS/DohResolver.java b/src/main/java11/org/xbill/DNS/DohResolver.java new file mode 100644 index 00000000..8e73eca5 --- /dev/null +++ b/src/main/java11/org/xbill/DNS/DohResolver.java @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpTimeoutException; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.WeakHashMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.AsyncSemaphore.Permit; + +/** + * Proof-of-concept DNS over HTTP (DoH) + * resolver. This class is not suitable for high load scenarios because of the shortcomings of + * Java's built-in HTTP clients. For more control, implement your own {@link Resolver} using e.g. OkHttp. + * + *

On Java 8, it uses HTTP/1.1, which is against the recommendation of RFC 8484 to use HTTP/2 and + * thus slower. On Java 11 or newer, HTTP/2 is always used, but the built-in HttpClient has its own + * issues with connection handling. + * + *

As of 2020-09-13, the following limits of public resolvers for HTTP/2 were observed: + *

  • https://cloudflare-dns.com/dns-query: max streams=250, idle timeout=400s + *
  • https://dns.google/dns-query: max streams=100, idle timeout=240s + * + * @since 3.0 + */ +@Slf4j +public final class DohResolver extends DohResolverCommon { + private static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; + private static final Map httpClients = + Collections.synchronizedMap(new WeakHashMap<>()); + private static final HttpRequest.Builder defaultHttpRequestBuilder; + + private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1, "initial request"); + + private final Duration idleConnectionTimeout; + + static { + defaultHttpRequestBuilder = HttpRequest.newBuilder(); + defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2); + defaultHttpRequestBuilder.header("Content-Type", APPLICATION_DNS_MESSAGE); + defaultHttpRequestBuilder.header("Accept", APPLICATION_DNS_MESSAGE); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + */ + public DohResolver(String uriTemplate) { + this(uriTemplate, 100, Duration.ofMinutes(2)); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1 + * connections for Java 8. On Java 8 this cannot exceed the system property {@code + * http.maxConnections}. + * @param idleConnectionTimeout Max. idle time for HTTP/2 connections until a request is + * serialized. Applies to Java 11+ only. + * @since 3.3 + */ + public DohResolver( + String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) { + super(uriTemplate, maxConcurrentRequests); + log.debug("Using Java 11+ implementation"); + this.idleConnectionTimeout = idleConnectionTimeout; + } + + @SneakyThrows + private HttpClient getHttpClient(Executor executor) { + return httpClients.computeIfAbsent( + executor, + key -> { + try { + return HttpClient.newBuilder().connectTimeout(timeout).executor(executor).build(); + } catch (IllegalArgumentException e) { + log.warn("Could not create a HttpClient for Executor {}", key, e); + return null; + } + }); + } + + @Override + public void setTimeout(Duration timeout) { + this.timeout = timeout; + httpClients.clear(); + } + + /** + * Sets the EDNS information on outgoing messages. + * + * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS. + * @param payloadSize ignored + * @param flags EDNS extended flags to be set in the OPT record. + * @param options EDNS options to be set in the OPT record + */ + @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility + public void setEDNS(int version, int payloadSize, int flags, List options) { + // required for source- and binary compatibility + super.setEDNS(version, payloadSize, flags, options); + } + + @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility + public CompletionStage sendAsync(Message query) { + return this.sendAsync(query, defaultExecutor); + } + + @Override + public CompletionStage sendAsync(Message query, Executor executor) { + long startTime = getNanoTime(); + byte[] queryBytes = prepareQuery(query).toWire(); + String url = getUrl(queryBytes); + + var requestBuilder = defaultHttpRequestBuilder.copy(); + requestBuilder.uri(URI.create(url)); + if (usePost) { + requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes)); + } + + // check if this request needs to be done synchronously because of HttpClient's stupidity to + // not use the connection pool for HTTP/2 until one connection is successfully established, + // which could lead to hundreds of connections (and threads with the default executor) + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + return timeoutFailedFuture(query, "no time left to acquire lock for first request", null); + } + + return initialRequestLock + .acquire(remainingTimeout, query.getHeader().getID(), executor) + .handle( + (initialRequestPermit, initialRequestEx) -> { + if (initialRequestEx != null) { + return this.timeoutFailedFuture(query, initialRequestEx); + } else { + return sendAsyncWithInitialRequestPermit( + query, executor, startTime, requestBuilder, initialRequestPermit); + } + }) + .thenCompose(Function.identity()); + } + + private CompletionStage sendAsyncWithInitialRequestPermit( + Message query, + Executor executor, + long startTime, + HttpRequest.Builder requestBuilder, + Permit initialRequestPermit) { + int queryId = query.getHeader().getID(); + long lastRequestTime = lastRequest.get(); + long requestDeltaNanos = getNanoTime() - lastRequestTime; + boolean isInitialRequest = + lastRequestTime == 0 || idleConnectionTimeout.toNanos() < requestDeltaNanos; + if (!isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + // check if we already exceeded the query timeout while checking the initial connection + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + return timeoutFailedFuture( + query, "no time left to acquire lock for concurrent request", null); + } + + // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the + // request, but fail with an IOException which also CLOSES the connection... *facepalm* + return maxConcurrentRequests + .acquire(remainingTimeout, queryId, executor) + .handle( + (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> { + if (maxConcurrentRequestEx != null) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + return this.timeoutFailedFuture( + query, + "timed out waiting for a concurrent request lease", + maxConcurrentRequestEx); + } else { + return sendAsyncWithConcurrentRequestPermit( + query, + executor, + startTime, + requestBuilder, + initialRequestPermit, + isInitialRequest, + maxConcurrentRequestPermit); + } + }) + .thenCompose(Function.identity()); + } + + private CompletionStage sendAsyncWithConcurrentRequestPermit( + Message query, + Executor executor, + long startTime, + HttpRequest.Builder requestBuilder, + Permit initialRequestPermit, + boolean isInitialRequest, + Permit maxConcurrentRequestPermit) { + int queryId = query.getHeader().getID(); + + // check if the stream lock acquisition took too long + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + maxConcurrentRequestPermit.release(queryId, executor); + return timeoutFailedFuture( + query, "no time left to acquire lock for concurrent request", null); + } + + var httpRequest = requestBuilder.timeout(remainingTimeout).build(); + var bodyHandler = HttpResponse.BodyHandlers.ofByteArray(); + return getHttpClient(executor) + .sendAsync(httpRequest, bodyHandler) + .whenComplete( + (result, ex) -> { + if (ex == null) { + lastRequest.set(startTime); + } + maxConcurrentRequestPermit.release(queryId, executor); + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + }) + .handleAsync( + (response, ex) -> { + if (ex != null) { + if (ex instanceof HttpTimeoutException) { + return this.timeoutFailedFuture( + query, "http request did not complete", ex.getCause()); + } else { + return CompletableFuture.failedFuture(ex); + } + } else { + try { + Message responseMessage; + int rc = response.statusCode(); + if (rc >= 200 && rc < 300) { + byte[] responseBytes = response.body(); + responseMessage = new Message(responseBytes); + verifyTSIG(query, responseMessage, responseBytes, tsig); + } else { + responseMessage = new Message(); + responseMessage.getHeader().setRcode(Rcode.SERVFAIL); + } + + responseMessage.setResolver(this); + return CompletableFuture.completedFuture(responseMessage); + } catch (IOException e) { + return CompletableFuture.failedFuture(e); + } + } + }, + executor) + .thenCompose(Function.identity()) + .orTimeout(remainingTimeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally( + ex -> { + if (ex instanceof TimeoutException) { + throw new CompletionException( + new TimeoutException( + "Query " + + query.getHeader().getID() + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out in remaining " + + remainingTimeout.toMillis() + + "ms")); + } else if (ex instanceof CompletionException) { + throw (CompletionException) ex; + } + + throw new CompletionException(ex); + }); + } + + @Override + protected CompletableFuture failedFuture(Throwable e) { + return CompletableFuture.failedFuture(e); + } +} diff --git a/src/test/java/org/xbill/DNS/DohResolverTest.java b/src/test/java/org/xbill/DNS/DohResolverTest.java index 26daf19c..6d1500d6 100644 --- a/src/test/java/org/xbill/DNS/DohResolverTest.java +++ b/src/test/java/org/xbill/DNS/DohResolverTest.java @@ -28,8 +28,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledForJreRange; import org.junit.jupiter.api.condition.JRE; import org.junit.jupiter.api.extension.ExtendWith; @@ -38,8 +38,8 @@ import org.mockito.stubbing.Answer; @ExtendWith(VertxExtension.class) +@Slf4j class DohResolverTest { - private DohResolver resolver; private final Name queryName = Name.fromConstantString("example.com."); private final Record qr = Record.newRecord(queryName, Type.A, DClass.IN); private final Message qm = Message.newQuery(qr); @@ -48,7 +48,6 @@ class DohResolverTest { @BeforeEach void beforeEach() throws UnknownHostException { - resolver = new DohResolver("http://localhost"); Record ar = new ARecord( Name.fromConstantString("example.com."), @@ -59,11 +58,16 @@ void beforeEach() throws UnknownHostException { a.addRecord(ar, Section.ANSWER); } + private DohResolver getResolver() { + return new DohResolver("http://localhost"); + } + @ParameterizedTest @ValueSource(booleans = {false, true}) void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); resolver.setUsePost(usePost); - setupResolverWithServer(Duration.ZERO, 200, 1, vertx, context) + setupResolverWithServer(resolver, Duration.ZERO, 200, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -79,10 +83,13 @@ void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) { })))); } - @Test - void timeoutResolve(Vertx vertx, VertxTestContext context) { + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void timeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); resolver.setTimeout(Duration.ofSeconds(1)); - setupResolverWithServer(Duration.ofSeconds(5), 200, 1, vertx, context) + resolver.setUsePost(usePost); + setupResolverWithServer(resolver, Duration.ofSeconds(5), 200, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -98,9 +105,12 @@ void timeoutResolve(Vertx vertx, VertxTestContext context) { })))); } - @Test - void servfailResolve(Vertx vertx, VertxTestContext context) { - setupResolverWithServer(Duration.ZERO, 301, 1, vertx, context) + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void servfailResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); + resolver.setUsePost(usePost); + setupResolverWithServer(resolver, Duration.ZERO, 301, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -114,12 +124,14 @@ void servfailResolve(Vertx vertx, VertxTestContext context) { })))); } - @Test - void limitRequestsResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void limitRequestsResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); int requests = 100; Checkpoint cpPass = context.checkpoint(requests); - setupResolverWithServer(Duration.ofMillis(100), 200, 5, vertx, context) + setupResolverWithServer(resolver, Duration.ofMillis(100), 200, 5, vertx, context) .onSuccess( server -> { for (int i = 0; i < requests; i++) { @@ -137,13 +149,15 @@ void limitRequestsResolve(Vertx vertx, VertxTestContext context) { }); } - @Test - void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void initialRequestSlowResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); int requests = 20; allRequestsUseTimeout = false; Checkpoint cpPass = context.checkpoint(requests); - setupResolverWithServer(Duration.ofSeconds(1), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofSeconds(1), 200, 2, vertx, context) .onSuccess( server -> { for (int i = 0; i < requests; i++) { @@ -161,19 +175,23 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) { }); } - @Test - void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void initialRequestTimeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); resolver.setTimeout(Duration.ofSeconds(1)); int requests = 20; allRequestsUseTimeout = false; Checkpoint cpPass = context.checkpoint(requests - 1); Checkpoint cpFail = context.checkpoint(); - setupResolverWithServer(Duration.ofSeconds(2), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofSeconds(2), 200, 2, vertx, context) .onSuccess( server -> { + Message q = qm.clone(); + q.getHeader().setID(0); resolver - .sendAsync(qm) + .sendAsync(q) .whenComplete( (result, ex) -> { if (ex == null) { @@ -185,9 +203,11 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { vertx.setTimer( 1000, timer -> { - for (int i = 0; i < requests - 1; i++) { + for (int i = 1; i < requests; i++) { + Message qq = qm.clone(); + qq.getHeader().setID(i); resolver - .sendAsync(qm) + .sendAsync(qq) .whenComplete( (result, ex) -> { if (ex == null) { @@ -202,6 +222,7 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { } private Future setupResolverWithServer( + DohResolver resolver, Duration responseDelay, int statusCode, int maxConcurrentRequests, @@ -214,12 +235,14 @@ private Future setupResolverWithServer( @EnabledForJreRange( min = JRE.JAVA_9, disabledReason = "Java 8 implementation doesn't have the initial request guard") - @Test + @ParameterizedTest + @ValueSource(booleans = {false, true}) void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( - Vertx vertx, VertxTestContext context) { + boolean usePost, Vertx vertx, VertxTestContext context) { AtomicLong startNanos = new AtomicLong(System.nanoTime()); - resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2))); + DohResolver resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2))); resolver.setTimeout(Duration.ofSeconds(1)); + resolver.setUsePost(usePost); // Simulate a nanoTime value that is lower than the idle timeout doAnswer((Answer) invocationOnMock -> System.nanoTime() - startNanos.get()) .when(resolver) @@ -243,11 +266,13 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( AtomicBoolean firstCallCompleted = new AtomicBoolean(false); - setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofMillis(100L), 200, 2, vertx, context) .onSuccess( server -> { // First call - CompletionStage firstCall = resolver.sendAsync(qm); + CompletionStage firstCall = + resolver.sendAsync(qm).whenComplete((msg, ex) -> firstCallCompleted.set(true)); + // Ensure second call was made after first call and uses a different query startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20)); CompletionStage secondCall = resolver.sendAsync(Message.newQuery(qr)); @@ -261,7 +286,6 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( assertEquals(Rcode.NOERROR, result.getHeader().getRcode()); assertEquals(0, result.getHeader().getID()); assertEquals(queryName, result.getQuestion().getName()); - firstCallCompleted.set(true); }))); Future.fromCompletionStage(secondCall) @@ -302,10 +326,12 @@ private Future setupServer( int thisRequestNum = requestCount.incrementAndGet(); int count = concurrentRequests.incrementAndGet(); if (count > maxConcurrentRequests) { - context.failNow("Concurrent requests exceeded"); + context.failNow( + "Concurrent requests exceeded: " + count + " > " + maxConcurrentRequests); return; } + httpRequest.endHandler(v -> concurrentRequests.decrementAndGet()); httpRequest.bodyHandler( body -> { context.verify( @@ -332,15 +358,12 @@ private Future setupServer( && (thisRequestNum == 1 || allRequestsUseTimeout)) { vertx.setTimer( serverProcessingTime.toMillis(), - timer -> { - concurrentRequests.decrementAndGet(); - httpRequest - .response() - .setStatusCode(statusCode) - .end(Buffer.buffer(dnsResponseCopy.toWire())); - }); + timer -> + httpRequest + .response() + .setStatusCode(statusCode) + .end(Buffer.buffer(dnsResponseCopy.toWire()))); } else { - concurrentRequests.decrementAndGet(); httpRequest .response() .setStatusCode(statusCode)