httpClients =
- Collections.synchronizedMap(new WeakHashMap<>());
+public final class DohResolver extends DohResolverCommon {
private final SSLSocketFactory sslSocketFactory;
- private static Object defaultHttpRequestBuilder;
- private static Method publisherOfByteArrayMethod;
- private static Method requestBuilderTimeoutMethod;
- private static Method requestBuilderCopyMethod;
- private static Method requestBuilderUriMethod;
- private static Method requestBuilderBuildMethod;
- private static Method requestBuilderPostMethod;
-
- private static Method httpClientNewBuilderMethod;
- private static Method httpClientBuilderTimeoutMethod;
- private static Method httpClientBuilderExecutorMethod;
- private static Method httpClientBuilderBuildMethod;
- private static Method httpClientSendAsyncMethod;
-
- private static Method byteArrayBodyPublisherMethod;
- private static Method httpResponseBodyMethod;
- private static Method httpResponseStatusCodeMethod;
-
- private boolean usePost = false;
- private Duration timeout = Duration.ofSeconds(5);
- private String uriTemplate;
- private final Duration idleConnectionTimeout;
- private OPTRecord queryOPT = new OPTRecord(0, 0, 0);
- private TSIG tsig;
- private Executor defaultExecutor = ForkJoinPool.commonPool();
-
- /**
- * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections.
- *
- * rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2.
- */
- private final AsyncSemaphore maxConcurrentRequests;
-
- private final AtomicLong lastRequest = new AtomicLong(0);
- private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1);
-
- private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
-
- static {
- boolean initSuccess = false;
- if (!System.getProperty("java.version").startsWith("1.")) {
- try {
- Class> httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder");
- Class> httpClientClass = Class.forName("java.net.http.HttpClient");
- Class> httpVersionEnum = Class.forName("java.net.http.HttpClient$Version");
- Class> httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder");
- Class> httpRequestClass = Class.forName("java.net.http.HttpRequest");
- Class> bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers");
- Class> bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher");
- Class> httpResponseClass = Class.forName("java.net.http.HttpResponse");
- Class> bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers");
- Class> bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler");
-
- // HttpClient.Builder
- httpClientBuilderTimeoutMethod =
- httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class);
- httpClientBuilderExecutorMethod =
- httpClientBuilderClass.getDeclaredMethod("executor", Executor.class);
- httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build");
-
- // HttpClient
- httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder");
- httpClientSendAsyncMethod =
- httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass);
-
- // HttpRequestBuilder
- Method requestBuilderHeaderMethod =
- httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class);
- Method requestBuilderVersionMethod =
- httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum);
- requestBuilderTimeoutMethod =
- httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class);
- requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class);
- requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy");
- requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build");
- requestBuilderPostMethod =
- httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass);
-
- // HttpRequest
- Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder");
-
- // BodyPublishers
- publisherOfByteArrayMethod =
- bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class);
-
- // BodyPublisher
- byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray");
-
- // HttpResponse
- httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body");
- httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode");
-
- // defaultHttpRequestBuilder = HttpRequest.newBuilder();
- // defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2);
- // defaultHttpRequestBuilder.header("Content-Type", "application/dns-message");
- // defaultHttpRequestBuilder.header("Accept", "application/dns-message");
- defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null);
- @SuppressWarnings({"unchecked", "rawtypes"})
- Enum> http2Version = Enum.valueOf((Class) httpVersionEnum, "HTTP_2");
- requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version);
- requestBuilderHeaderMethod.invoke(
- defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE);
- requestBuilderHeaderMethod.invoke(
- defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE);
- initSuccess = true;
- } catch (ClassNotFoundException
- | NoSuchMethodException
- | IllegalAccessException
- | InvocationTargetException e) {
- // fallback to Java 8
- log.warn("Java >= 11 detected, but HttpRequest not available");
- }
- }
-
- USE_HTTP_CLIENT = initSuccess;
- }
-
- // package-visible for testing
- long getNanoTime() {
- return System.nanoTime();
- }
-
/**
* Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
*
* @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
*/
public DohResolver(String uriTemplate) {
- this(uriTemplate, 100, Duration.ofMinutes(2));
+ this(uriTemplate, 100, Duration.ZERO);
}
/**
@@ -201,22 +68,9 @@ public DohResolver(String uriTemplate) {
*/
public DohResolver(
String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) {
- this.uriTemplate = uriTemplate;
- this.idleConnectionTimeout = idleConnectionTimeout;
- if (maxConcurrentRequests <= 0) {
- throw new IllegalArgumentException("maxConcurrentRequests must be > 0");
- }
- if (!USE_HTTP_CLIENT) {
- try {
- int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5"));
- if (maxConcurrentRequests > javaMaxConn) {
- maxConcurrentRequests = javaMaxConn;
- }
- } catch (NumberFormatException nfe) {
- // well, use what we got
- }
- }
- this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests);
+ super(uriTemplate, maxConcurrentRequests);
+
+ log.debug("Using Java 8 implementation");
try {
sslSocketFactory = SSLContext.getDefault().getSocketFactory();
} catch (NoSuchAlgorithmException e) {
@@ -224,45 +78,6 @@ public DohResolver(
}
}
- @SneakyThrows
- private Object getHttpClient(Executor executor) {
- return httpClients.computeIfAbsent(
- executor,
- key -> {
- try {
- // return HttpClient.newBuilder()
- // .connectTimeout(timeout).
- // .executor(executor)
- // .build();
- Object httpClientBuilder = httpClientNewBuilderMethod.invoke(null);
- httpClientBuilderTimeoutMethod.invoke(httpClientBuilder, timeout);
- httpClientBuilderExecutorMethod.invoke(httpClientBuilder, key);
- return httpClientBuilderBuildMethod.invoke(httpClientBuilder);
- } catch (IllegalAccessException | InvocationTargetException e) {
- log.warn("Could not create a HttpClient with for Executor {}", key, e);
- return null;
- }
- });
- }
-
- /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */
- @Override
- public void setPort(int port) {
- // Not implemented, port is part of the URI
- }
-
- /** Not implemented. */
- @Override
- public void setTCP(boolean flag) {
- // Not implemented, HTTP is always TCP
- }
-
- /** Not implemented. */
- @Override
- public void setIgnoreTruncation(boolean flag) {
- // Not implemented, protocol uses TCP and doesn't have truncation
- }
-
/**
* Sets the EDNS information on outgoing messages.
*
@@ -272,87 +87,85 @@ public void setIgnoreTruncation(boolean flag) {
* @param options EDNS options to be set in the OPT record
*/
@Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
public void setEDNS(int version, int payloadSize, int flags, List options) {
- switch (version) {
- case -1:
- queryOPT = null;
- break;
-
- case 0:
- queryOPT = new OPTRecord(0, 0, version, flags, options);
- break;
-
- default:
- throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
- }
- }
-
- @Override
- public void setTSIGKey(TSIG key) {
- this.tsig = key;
- }
-
- @Override
- public void setTimeout(Duration timeout) {
- this.timeout = timeout;
- httpClients.clear();
- }
-
- @Override
- public Duration getTimeout() {
- return timeout;
+ // required for source- and binary compatibility
+ super.setEDNS(version, payloadSize, flags, options);
}
@Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
public CompletionStage sendAsync(Message query) {
- return sendAsync(query, defaultExecutor);
+ // required for source- and binary compatibility
+ return this.sendAsync(query, defaultExecutor);
}
@Override
public CompletionStage sendAsync(Message query, Executor executor) {
- if (USE_HTTP_CLIENT) {
- return sendAsync11(query, executor);
- }
-
- return sendAsync8(query, executor);
- }
-
- private CompletionStage sendAsync8(final Message query, Executor executor) {
byte[] queryBytes = prepareQuery(query).toWire();
String url = getUrl(queryBytes);
long startTime = getNanoTime();
- return maxConcurrentRequests
- .acquire(timeout)
- .handleAsync(
- (permit, ex) -> {
- if (ex != null) {
- return this.timeoutFailedFuture(query, ex);
- } else {
- try {
- SendAndGetMessageBytesResponse result =
- sendAndGetMessageBytes(url, queryBytes, startTime);
- Message response;
- if (result.rc == Rcode.NOERROR) {
- response = new Message(result.responseBytes);
- verifyTSIG(query, response, result.responseBytes, tsig);
+ int queryId = query.getHeader().getID();
+
+ CompletableFuture f =
+ maxConcurrentRequests
+ .acquire(timeout, queryId, executor)
+ .handleAsync(
+ (permit, ex) -> {
+ if (ex != null) {
+ return this.timeoutFailedFuture(
+ query, "could not acquire lock to send request", ex);
} else {
- response = new Message(0);
- response.getHeader().setRcode(result.rc);
+ try {
+ SendAndGetMessageBytesResponse result =
+ sendAndGetMessageBytes(url, queryBytes, startTime);
+ Message response;
+ if (result.rc == Rcode.NOERROR) {
+ response = new Message(result.responseBytes);
+ verifyTSIG(query, response, result.responseBytes, tsig);
+ } else {
+ response = new Message(0);
+ response.getHeader().setRcode(result.rc);
+ }
+
+ response.setResolver(this);
+ return CompletableFuture.completedFuture(response);
+ } catch (SocketTimeoutException e) {
+ return this.timeoutFailedFuture(query, e);
+ } catch (IOException | URISyntaxException e) {
+ return this.failedFuture(e);
+ } finally {
+ permit.release(queryId, executor);
+ }
}
+ },
+ executor)
+ .thenCompose(Function.identity())
+ .toCompletableFuture();
- response.setResolver(this);
- return CompletableFuture.completedFuture(response);
- } catch (SocketTimeoutException e) {
- return this.timeoutFailedFuture(query, e);
- } catch (IOException | URISyntaxException e) {
- return this.failedFuture(e);
- } finally {
- permit.release();
- }
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ return TimeoutCompletableFuture.compatTimeout(
+ f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS)
+ .exceptionally(
+ ex -> {
+ if (ex instanceof TimeoutException) {
+ throw new CompletionException(
+ new TimeoutException(
+ "Query "
+ + queryId
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out in remaining "
+ + remainingTimeout.toMillis()
+ + "ms"));
+ } else if (ex instanceof CompletionException) {
+ throw (CompletionException) ex;
}
- },
- executor)
- .thenCompose(Function.identity());
+
+ throw new CompletionException(ex);
+ });
}
@Value
@@ -367,15 +180,28 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
if (conn instanceof HttpsURLConnection) {
((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory);
}
-
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- conn.setConnectTimeout((int) remainingTimeout.toMillis());
- conn.setReadTimeout((int) remainingTimeout.toMillis());
conn.setRequestMethod(usePost ? "POST" : "GET");
conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE);
conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE);
+
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ throw new SocketTimeoutException("No time left to connect");
+ }
+
+ conn.setConnectTimeout((int) remainingTimeout.toMillis());
if (usePost) {
conn.setDoOutput(true);
+ }
+
+ conn.connect();
+ remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ throw new SocketTimeoutException("No time left to request data");
+ }
+
+ conn.setReadTimeout((int) remainingTimeout.toMillis());
+ if (usePost) {
conn.getOutputStream().write(queryBytes);
}
@@ -395,13 +221,23 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) {
offset += r;
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- throw new SocketTimeoutException();
+
+ // Don't throw if we just received all data
+ if (offset != responseBytes.length
+ && (remainingTimeout.isNegative() || remainingTimeout.isZero())) {
+ throw new SocketTimeoutException(
+ "Timed out waiting for response data, got "
+ + offset
+ + " of "
+ + responseBytes.length
+ + " expected bytes");
}
}
+
if (offset < responseBytes.length) {
throw new EOFException("Could not read expected content length");
}
+
return new SendAndGetMessageBytesResponse(Rcode.NOERROR, responseBytes);
} else {
try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
@@ -409,8 +245,9 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
int r;
while ((r = is.read(buffer, 0, buffer.length)) > 0) {
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- throw new SocketTimeoutException();
+ if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
+ throw new SocketTimeoutException(
+ "Timed out waiting for response data, got " + bos.size() + " bytes so far");
}
bos.write(buffer, 0, r);
}
@@ -436,275 +273,10 @@ private void discardStream(InputStream es) throws IOException {
}
}
- private CompletionStage sendAsync11(final Message query, Executor executor) {
- long startTime = getNanoTime();
- byte[] queryBytes = prepareQuery(query).toWire();
- String url = getUrl(queryBytes);
-
- // var requestBuilder = defaultHttpRequestBuilder.copy();
- // requestBuilder.uri(URI.create(url));
- Object requestBuilder;
- try {
- requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder);
- requestBuilderUriMethod.invoke(requestBuilder, URI.create(url));
- if (usePost) {
- // requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes));
- requestBuilderPostMethod.invoke(
- requestBuilder, publisherOfByteArrayMethod.invoke(null, queryBytes));
- }
- } catch (IllegalAccessException | InvocationTargetException e) {
- return failedFuture(e);
- }
-
- // check if this request needs to be done synchronously because of HttpClient's stupidity to
- // not use the connection pool for HTTP/2 until one connection is successfully established,
- // which could lead to hundreds of connections (and threads with the default executor)
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- return initialRequestLock
- .acquire(remainingTimeout)
- .handle(
- (initialRequestPermit, initialRequestEx) -> {
- if (initialRequestEx != null) {
- return this.timeoutFailedFuture(query, initialRequestEx);
- } else {
- return sendAsync11WithInitialRequestPermit(
- query, executor, startTime, requestBuilder, initialRequestPermit);
- }
- })
- .thenCompose(Function.identity());
- }
-
- private CompletionStage sendAsync11WithInitialRequestPermit(
- Message query,
- Executor executor,
- long startTime,
- Object requestBuilder,
- Permit initialRequestPermit) {
- long lastRequestTime = lastRequest.get();
- boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime;
- if (!isInitialRequest) {
- initialRequestPermit.release();
- }
-
- // check if we already exceeded the query timeout while checking the initial connection
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- return timeoutFailedFuture(query, null);
- }
-
- // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the
- // request, but fail with an IOException which also CLOSES the connection... *facepalm*
- return maxConcurrentRequests
- .acquire(remainingTimeout)
- .handle(
- (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> {
- if (maxConcurrentRequestEx != null) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- return this.timeoutFailedFuture(query, maxConcurrentRequestEx);
- } else {
- return sendAsync11WithConcurrentRequestPermit(
- query,
- executor,
- startTime,
- requestBuilder,
- initialRequestPermit,
- isInitialRequest,
- maxConcurrentRequestPermit);
- }
- })
- .thenCompose(Function.identity());
- }
-
- private CompletionStage sendAsync11WithConcurrentRequestPermit(
- Message query,
- Executor executor,
- long startTime,
- Object requestBuilder,
- Permit initialRequestPermit,
- boolean isInitialRequest,
- Permit maxConcurrentRequestPermit) {
- // check if the stream lock acquisition took too long
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- maxConcurrentRequestPermit.release();
- return timeoutFailedFuture(query, null);
- }
-
- // var httpRequest = requestBuilder.timeout(remainingTimeout).build();
- // var bodyHandler = HttpResponse.BodyHandlers.ofByteArray();
- // return getHttpClient(executor).sendAsync(httpRequest, bodyHandler)
- try {
- Object httpClient = getHttpClient(executor);
- requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout);
- Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder);
- Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null);
- CompletableFuture f =
- ((CompletableFuture>)
- httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler))
- .whenComplete(
- (result, ex) -> {
- if (ex == null) {
- lastRequest.set(startTime);
- }
- maxConcurrentRequestPermit.release();
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- })
- .handleAsync(
- (response, ex) -> {
- if (ex != null) {
- if (ex.getCause().getClass().getSimpleName().equals("HttpTimeoutException")) {
- return this.timeoutFailedFuture(query, ex.getCause());
- } else {
- return this.failedFuture(ex);
- }
- } else {
- try {
- Message responseMessage;
- // int rc = response.statusCode();
- int rc = (int) httpResponseStatusCodeMethod.invoke(response);
- if (rc >= 200 && rc < 300) {
- // byte[] responseBytes = response.body();
- byte[] responseBytes = (byte[]) httpResponseBodyMethod.invoke(response);
- responseMessage = new Message(responseBytes);
- verifyTSIG(query, responseMessage, responseBytes, tsig);
- } else {
- responseMessage = new Message();
- responseMessage.getHeader().setRcode(Rcode.SERVFAIL);
- }
-
- responseMessage.setResolver(this);
- return CompletableFuture.completedFuture(responseMessage);
- } catch (IOException | IllegalAccessException | InvocationTargetException e) {
- return this.failedFuture(e);
- }
- }
- },
- executor)
- .thenCompose(Function.identity());
- return TimeoutCompletableFuture.compatTimeout(
- f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS);
- } catch (IllegalAccessException | InvocationTargetException e) {
- return failedFuture(e);
- }
- }
-
- private CompletableFuture failedFuture(Throwable e) {
+ @Override
+ protected CompletableFuture failedFuture(Throwable e) {
CompletableFuture f = new CompletableFuture<>();
f.completeExceptionally(e);
return f;
}
-
- private CompletableFuture timeoutFailedFuture(Message query, Throwable inner) {
- return failedFuture(
- new IOException(
- "Query "
- + query.getHeader().getID()
- + " for "
- + query.getQuestion().getName()
- + "/"
- + Type.string(query.getQuestion().getType())
- + " timed out",
- inner));
- }
-
- private String getUrl(byte[] queryBytes) {
- String url = uriTemplate;
- if (!usePost) {
- url += "?dns=" + base64.toString(queryBytes, true);
- }
- return url;
- }
-
- private Message prepareQuery(Message query) {
- Message preparedQuery = query.clone();
- preparedQuery.getHeader().setID(0);
- if (queryOPT != null && preparedQuery.getOPT() == null) {
- preparedQuery.addRecord(queryOPT, Section.ADDITIONAL);
- }
-
- if (tsig != null) {
- tsig.apply(preparedQuery, null);
- }
-
- return preparedQuery;
- }
-
- private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
- if (tsig == null) {
- return;
- }
-
- int error = tsig.verify(response, b, query.getGeneratedTSIG());
- log.debug(
- "TSIG verify for query {}, {}/{}: {}",
- query.getHeader().getID(),
- query.getQuestion().getName(),
- Type.string(query.getQuestion().getType()),
- Rcode.TSIGstring(error));
- }
-
- /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */
- public boolean isUsePost() {
- return usePost;
- }
-
- /**
- * Sets the HTTP method to use for resolving.
- *
- * @param usePost {@code true} to use POST, {@code false} to use GET (the default).
- */
- public void setUsePost(boolean usePost) {
- this.usePost = usePost;
- }
-
- /** Gets the current URI used for resolving. */
- public String getUriTemplate() {
- return uriTemplate;
- }
-
- /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */
- public void setUriTemplate(String uriTemplate) {
- this.uriTemplate = uriTemplate;
- }
-
- /**
- * Gets the default {@link Executor} for request handling, defaults to {@link
- * ForkJoinPool#commonPool()}.
- *
- * @since 3.3
- * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used.
- */
- @Deprecated
- public Executor getExecutor() {
- return defaultExecutor;
- }
-
- /**
- * Sets the default {@link Executor} for request handling.
- *
- * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link
- * ForkJoinPool#commonPool()}).
- * @since 3.3
- * @deprecated Use {@link #sendAsync(Message, Executor)}.
- */
- @Deprecated
- public void setExecutor(Executor executor) {
- this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor;
- httpClients.clear();
- }
-
- @Override
- public String toString() {
- return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}";
- }
}
diff --git a/src/main/java/org/xbill/DNS/DohResolverCommon.java b/src/main/java/org/xbill/DNS/DohResolverCommon.java
new file mode 100644
index 00000000..0fb8ac07
--- /dev/null
+++ b/src/main/java/org/xbill/DNS/DohResolverCommon.java
@@ -0,0 +1,232 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicLong;
+import lombok.extern.slf4j.Slf4j;
+import org.xbill.DNS.utils.base64;
+
+@Slf4j
+abstract class DohResolverCommon implements Resolver {
+ /**
+ * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections.
+ *
+ * rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2.
+ */
+ protected final AsyncSemaphore maxConcurrentRequests;
+
+ protected final AtomicLong lastRequest = new AtomicLong(0);
+
+ protected static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
+
+ protected boolean usePost = false;
+ protected Duration timeout = Duration.ofSeconds(5);
+ protected String uriTemplate;
+ protected OPTRecord queryOPT = new OPTRecord(0, 0, 0);
+ protected TSIG tsig;
+ protected Executor defaultExecutor = ForkJoinPool.commonPool();
+
+ // package-visible for testing
+ long getNanoTime() {
+ return System.nanoTime();
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1
+ * connections for Java 8. On Java 8 this cannot exceed the system property {@code
+ * http.maxConnections}.
+ */
+ protected DohResolverCommon(String uriTemplate, int maxConcurrentRequests) {
+ this.uriTemplate = uriTemplate;
+ if (maxConcurrentRequests <= 0) {
+ throw new IllegalArgumentException("maxConcurrentRequests must be > 0");
+ }
+
+ try {
+ int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5"));
+ if (maxConcurrentRequests > javaMaxConn) {
+ maxConcurrentRequests = javaMaxConn;
+ }
+ } catch (NumberFormatException nfe) {
+ // well, use what we got
+ }
+
+ this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests, "concurrent");
+ }
+
+ /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */
+ @Override
+ public void setPort(int port) {
+ // Not implemented, port is part of the URI
+ }
+
+ /** Not implemented. */
+ @Override
+ public void setTCP(boolean flag) {
+ // Not implemented, HTTP is always TCP
+ }
+
+ /** Not implemented. */
+ @Override
+ public void setIgnoreTruncation(boolean flag) {
+ // Not implemented, protocol uses TCP and doesn't have truncation
+ }
+
+ /**
+ * Sets the EDNS information on outgoing messages.
+ *
+ * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS.
+ * @param payloadSize ignored
+ * @param flags EDNS extended flags to be set in the OPT record.
+ * @param options EDNS options to be set in the OPT record
+ */
+ @Override
+ public void setEDNS(int version, int payloadSize, int flags, List options) {
+ switch (version) {
+ case -1:
+ queryOPT = null;
+ break;
+
+ case 0:
+ queryOPT = new OPTRecord(0, 0, version, flags, options);
+ break;
+
+ default:
+ throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
+ }
+ }
+
+ @Override
+ public void setTSIGKey(TSIG key) {
+ this.tsig = key;
+ }
+
+ @Override
+ public void setTimeout(Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ @Override
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ protected String getUrl(byte[] queryBytes) {
+ String url = uriTemplate;
+ if (!usePost) {
+ url += "?dns=" + base64.toString(queryBytes, true);
+ }
+ return url;
+ }
+
+ protected Message prepareQuery(Message query) {
+ Message preparedQuery = query.clone();
+ preparedQuery.getHeader().setID(0);
+ if (queryOPT != null && preparedQuery.getOPT() == null) {
+ preparedQuery.addRecord(queryOPT, Section.ADDITIONAL);
+ }
+
+ if (tsig != null) {
+ tsig.apply(preparedQuery, null);
+ }
+
+ return preparedQuery;
+ }
+
+ protected void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
+ if (tsig == null) {
+ return;
+ }
+
+ int error = tsig.verify(response, b, query.getGeneratedTSIG());
+ log.debug(
+ "TSIG verify for query {}, {}/{}: {}",
+ query.getHeader().getID(),
+ query.getQuestion().getName(),
+ Type.string(query.getQuestion().getType()),
+ Rcode.TSIGstring(error));
+ }
+
+ /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */
+ public boolean isUsePost() {
+ return usePost;
+ }
+
+ /**
+ * Sets the HTTP method to use for resolving.
+ *
+ * @param usePost {@code true} to use POST, {@code false} to use GET (the default).
+ */
+ public void setUsePost(boolean usePost) {
+ this.usePost = usePost;
+ }
+
+ /** Gets the current URI used for resolving. */
+ public String getUriTemplate() {
+ return uriTemplate;
+ }
+
+ /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */
+ public void setUriTemplate(String uriTemplate) {
+ this.uriTemplate = uriTemplate;
+ }
+
+ /**
+ * Gets the default {@link Executor} for request handling, defaults to {@link
+ * ForkJoinPool#commonPool()}.
+ *
+ * @since 3.3
+ * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used.
+ */
+ @Deprecated
+ public Executor getExecutor() {
+ return defaultExecutor;
+ }
+
+ /**
+ * Sets the default {@link Executor} for request handling.
+ *
+ * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link
+ * ForkJoinPool#commonPool()}).
+ * @since 3.3
+ * @deprecated Use {@link #sendAsync(Message, Executor)}.
+ */
+ @Deprecated
+ public void setExecutor(Executor executor) {
+ this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor;
+ }
+
+ @Override
+ public String toString() {
+ return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}";
+ }
+
+ protected abstract CompletableFuture failedFuture(Throwable e);
+
+ protected final CompletableFuture timeoutFailedFuture(Message query, Throwable inner) {
+ return timeoutFailedFuture(query, null, inner);
+ }
+
+ protected final CompletableFuture timeoutFailedFuture(
+ Message query, String message, Throwable inner) {
+ return failedFuture(
+ new TimeoutException(
+ "Query "
+ + query.getHeader().getID()
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out"
+ + (message != null ? ": " + message : "")
+ + (inner != null && inner.getMessage() != null ? ", " + inner.getMessage() : "")));
+ }
+}
diff --git a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
index 24796df6..e2ae60f6 100644
--- a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
+++ b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
@@ -1,8 +1,6 @@
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -10,57 +8,28 @@
import java.util.concurrent.TimeoutException;
import lombok.extern.slf4j.Slf4j;
-/**
- * Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. On Java 9+
- * the built-in method is called.
- */
+/** Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. */
@Slf4j
class TimeoutCompletableFuture extends CompletableFuture {
- private static final Method orTimeoutMethod;
-
- static {
- Method localOrTimeoutMethod;
- if (!System.getProperty("java.version").startsWith("1.")) {
- try {
- localOrTimeoutMethod =
- CompletableFuture.class.getMethod("orTimeout", long.class, TimeUnit.class);
- } catch (NoSuchMethodException e) {
- localOrTimeoutMethod = null;
- log.warn(
- "CompletableFuture.orTimeout method not found in Java 9+, using custom implementation",
- e);
- }
- } else {
- localOrTimeoutMethod = null;
- }
- orTimeoutMethod = localOrTimeoutMethod;
- }
-
public CompletableFuture compatTimeout(long timeout, TimeUnit unit) {
return compatTimeout(this, timeout, unit);
}
- @SuppressWarnings("unchecked")
public static CompletableFuture compatTimeout(
CompletableFuture f, long timeout, TimeUnit unit) {
- if (orTimeoutMethod == null) {
- return orTimeout(f, timeout, unit);
- } else {
- try {
- return (CompletableFuture) orTimeoutMethod.invoke(f, timeout, unit);
- } catch (IllegalAccessException | InvocationTargetException e) {
- return orTimeout(f, timeout, unit);
- }
+ if (timeout <= 0) {
+ f.completeExceptionally(new TimeoutException("timeout is " + timeout + ", but must be > 0"));
}
- }
- private static CompletableFuture orTimeout(
- CompletableFuture f, long timeout, TimeUnit unit) {
ScheduledFuture> sf =
TimeoutScheduler.executor.schedule(
() -> {
if (!f.isDone()) {
- f.completeExceptionally(new TimeoutException());
+ f.completeExceptionally(
+ new TimeoutException(
+ "Timeout of "
+ + unit.toMillis(timeout)
+ + "ms has elapsed before the task completed"));
}
},
timeout,
diff --git a/src/main/java11/org/xbill/DNS/AsyncSemaphore.java b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java
new file mode 100644
index 00000000..d2dcf6dd
--- /dev/null
+++ b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.time.Duration;
+import java.util.ArrayDeque;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import lombok.extern.slf4j.Slf4j;
+
+@Slf4j
+final class AsyncSemaphore {
+ private final Queue> queue = new ArrayDeque<>();
+ private final Permit singletonPermit = new Permit();
+ private final String name;
+ private volatile int permits;
+
+ final class Permit {
+ public void release(int id, Executor executor) {
+ synchronized (queue) {
+ CompletableFuture next = queue.poll();
+ if (next == null) {
+ permits++;
+ log.trace("{} permit released id={}, available={}", name, id, permits);
+ } else {
+ log.trace("{} permit released id={}, available={}, immediate next", name, id, permits);
+ next.completeAsync(() -> this, executor);
+ }
+ }
+ }
+ }
+
+ AsyncSemaphore(int permits, String name) {
+ this.permits = permits;
+ this.name = name;
+ log.debug("Using Java 11+ implementation for {}", name);
+ }
+
+ CompletionStage acquire(Duration timeout, int id, Executor executor) {
+ synchronized (queue) {
+ if (permits > 0) {
+ permits--;
+ log.trace("{} permit acquired id={}, available={}", name, id, permits);
+ return CompletableFuture.completedFuture(singletonPermit);
+ } else {
+ CompletableFuture f = new CompletableFuture<>();
+ f.orTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS)
+ .whenCompleteAsync(
+ (result, ex) -> {
+ synchronized (queue) {
+ if (ex != null) {
+ log.trace("{} permit timed out id={}, available={}", name, id, permits);
+ }
+ queue.remove(f);
+ }
+ },
+ executor);
+ log.trace("{} permit queued id={}, available={}", name, id, permits);
+ queue.add(f);
+ return f;
+ }
+ }
+ }
+}
diff --git a/src/main/java11/org/xbill/DNS/DohResolver.java b/src/main/java11/org/xbill/DNS/DohResolver.java
new file mode 100644
index 00000000..8e73eca5
--- /dev/null
+++ b/src/main/java11/org/xbill/DNS/DohResolver.java
@@ -0,0 +1,311 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.net.http.HttpTimeoutException;
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.WeakHashMap;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.Function;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import org.xbill.DNS.AsyncSemaphore.Permit;
+
+/**
+ * Proof-of-concept DNS over HTTP (DoH)
+ * resolver. This class is not suitable for high load scenarios because of the shortcomings of
+ * Java's built-in HTTP clients. For more control, implement your own {@link Resolver} using e.g. OkHttp.
+ *
+ * On Java 8, it uses HTTP/1.1, which is against the recommendation of RFC 8484 to use HTTP/2 and
+ * thus slower. On Java 11 or newer, HTTP/2 is always used, but the built-in HttpClient has its own
+ * issues with connection handling.
+ *
+ *
As of 2020-09-13, the following limits of public resolvers for HTTP/2 were observed:
+ *
https://cloudflare-dns.com/dns-query: max streams=250, idle timeout=400s
+ * https://dns.google/dns-query: max streams=100, idle timeout=240s
+ *
+ * @since 3.0
+ */
+@Slf4j
+public final class DohResolver extends DohResolverCommon {
+ private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
+ private static final Map httpClients =
+ Collections.synchronizedMap(new WeakHashMap<>());
+ private static final HttpRequest.Builder defaultHttpRequestBuilder;
+
+ private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1, "initial request");
+
+ private final Duration idleConnectionTimeout;
+
+ static {
+ defaultHttpRequestBuilder = HttpRequest.newBuilder();
+ defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2);
+ defaultHttpRequestBuilder.header("Content-Type", APPLICATION_DNS_MESSAGE);
+ defaultHttpRequestBuilder.header("Accept", APPLICATION_DNS_MESSAGE);
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ */
+ public DohResolver(String uriTemplate) {
+ this(uriTemplate, 100, Duration.ofMinutes(2));
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1
+ * connections for Java 8. On Java 8 this cannot exceed the system property {@code
+ * http.maxConnections}.
+ * @param idleConnectionTimeout Max. idle time for HTTP/2 connections until a request is
+ * serialized. Applies to Java 11+ only.
+ * @since 3.3
+ */
+ public DohResolver(
+ String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) {
+ super(uriTemplate, maxConcurrentRequests);
+ log.debug("Using Java 11+ implementation");
+ this.idleConnectionTimeout = idleConnectionTimeout;
+ }
+
+ @SneakyThrows
+ private HttpClient getHttpClient(Executor executor) {
+ return httpClients.computeIfAbsent(
+ executor,
+ key -> {
+ try {
+ return HttpClient.newBuilder().connectTimeout(timeout).executor(executor).build();
+ } catch (IllegalArgumentException e) {
+ log.warn("Could not create a HttpClient for Executor {}", key, e);
+ return null;
+ }
+ });
+ }
+
+ @Override
+ public void setTimeout(Duration timeout) {
+ this.timeout = timeout;
+ httpClients.clear();
+ }
+
+ /**
+ * Sets the EDNS information on outgoing messages.
+ *
+ * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS.
+ * @param payloadSize ignored
+ * @param flags EDNS extended flags to be set in the OPT record.
+ * @param options EDNS options to be set in the OPT record
+ */
+ @Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
+ public void setEDNS(int version, int payloadSize, int flags, List options) {
+ // required for source- and binary compatibility
+ super.setEDNS(version, payloadSize, flags, options);
+ }
+
+ @Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
+ public CompletionStage sendAsync(Message query) {
+ return this.sendAsync(query, defaultExecutor);
+ }
+
+ @Override
+ public CompletionStage sendAsync(Message query, Executor executor) {
+ long startTime = getNanoTime();
+ byte[] queryBytes = prepareQuery(query).toWire();
+ String url = getUrl(queryBytes);
+
+ var requestBuilder = defaultHttpRequestBuilder.copy();
+ requestBuilder.uri(URI.create(url));
+ if (usePost) {
+ requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes));
+ }
+
+ // check if this request needs to be done synchronously because of HttpClient's stupidity to
+ // not use the connection pool for HTTP/2 until one connection is successfully established,
+ // which could lead to hundreds of connections (and threads with the default executor)
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ return timeoutFailedFuture(query, "no time left to acquire lock for first request", null);
+ }
+
+ return initialRequestLock
+ .acquire(remainingTimeout, query.getHeader().getID(), executor)
+ .handle(
+ (initialRequestPermit, initialRequestEx) -> {
+ if (initialRequestEx != null) {
+ return this.timeoutFailedFuture(query, initialRequestEx);
+ } else {
+ return sendAsyncWithInitialRequestPermit(
+ query, executor, startTime, requestBuilder, initialRequestPermit);
+ }
+ })
+ .thenCompose(Function.identity());
+ }
+
+ private CompletionStage sendAsyncWithInitialRequestPermit(
+ Message query,
+ Executor executor,
+ long startTime,
+ HttpRequest.Builder requestBuilder,
+ Permit initialRequestPermit) {
+ int queryId = query.getHeader().getID();
+ long lastRequestTime = lastRequest.get();
+ long requestDeltaNanos = getNanoTime() - lastRequestTime;
+ boolean isInitialRequest =
+ lastRequestTime == 0 || idleConnectionTimeout.toNanos() < requestDeltaNanos;
+ if (!isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ // check if we already exceeded the query timeout while checking the initial connection
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ return timeoutFailedFuture(
+ query, "no time left to acquire lock for concurrent request", null);
+ }
+
+ // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the
+ // request, but fail with an IOException which also CLOSES the connection... *facepalm*
+ return maxConcurrentRequests
+ .acquire(remainingTimeout, queryId, executor)
+ .handle(
+ (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> {
+ if (maxConcurrentRequestEx != null) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+ return this.timeoutFailedFuture(
+ query,
+ "timed out waiting for a concurrent request lease",
+ maxConcurrentRequestEx);
+ } else {
+ return sendAsyncWithConcurrentRequestPermit(
+ query,
+ executor,
+ startTime,
+ requestBuilder,
+ initialRequestPermit,
+ isInitialRequest,
+ maxConcurrentRequestPermit);
+ }
+ })
+ .thenCompose(Function.identity());
+ }
+
+ private CompletionStage sendAsyncWithConcurrentRequestPermit(
+ Message query,
+ Executor executor,
+ long startTime,
+ HttpRequest.Builder requestBuilder,
+ Permit initialRequestPermit,
+ boolean isInitialRequest,
+ Permit maxConcurrentRequestPermit) {
+ int queryId = query.getHeader().getID();
+
+ // check if the stream lock acquisition took too long
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ maxConcurrentRequestPermit.release(queryId, executor);
+ return timeoutFailedFuture(
+ query, "no time left to acquire lock for concurrent request", null);
+ }
+
+ var httpRequest = requestBuilder.timeout(remainingTimeout).build();
+ var bodyHandler = HttpResponse.BodyHandlers.ofByteArray();
+ return getHttpClient(executor)
+ .sendAsync(httpRequest, bodyHandler)
+ .whenComplete(
+ (result, ex) -> {
+ if (ex == null) {
+ lastRequest.set(startTime);
+ }
+ maxConcurrentRequestPermit.release(queryId, executor);
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+ })
+ .handleAsync(
+ (response, ex) -> {
+ if (ex != null) {
+ if (ex instanceof HttpTimeoutException) {
+ return this.timeoutFailedFuture(
+ query, "http request did not complete", ex.getCause());
+ } else {
+ return CompletableFuture.failedFuture(ex);
+ }
+ } else {
+ try {
+ Message responseMessage;
+ int rc = response.statusCode();
+ if (rc >= 200 && rc < 300) {
+ byte[] responseBytes = response.body();
+ responseMessage = new Message(responseBytes);
+ verifyTSIG(query, responseMessage, responseBytes, tsig);
+ } else {
+ responseMessage = new Message();
+ responseMessage.getHeader().setRcode(Rcode.SERVFAIL);
+ }
+
+ responseMessage.setResolver(this);
+ return CompletableFuture.completedFuture(responseMessage);
+ } catch (IOException e) {
+ return CompletableFuture.failedFuture(e);
+ }
+ }
+ },
+ executor)
+ .thenCompose(Function.identity())
+ .orTimeout(remainingTimeout.toMillis(), TimeUnit.MILLISECONDS)
+ .exceptionally(
+ ex -> {
+ if (ex instanceof TimeoutException) {
+ throw new CompletionException(
+ new TimeoutException(
+ "Query "
+ + query.getHeader().getID()
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out in remaining "
+ + remainingTimeout.toMillis()
+ + "ms"));
+ } else if (ex instanceof CompletionException) {
+ throw (CompletionException) ex;
+ }
+
+ throw new CompletionException(ex);
+ });
+ }
+
+ @Override
+ protected CompletableFuture failedFuture(Throwable e) {
+ return CompletableFuture.failedFuture(e);
+ }
+}
diff --git a/src/test/java/org/xbill/DNS/DohResolverTest.java b/src/test/java/org/xbill/DNS/DohResolverTest.java
index 26daf19c..6d1500d6 100644
--- a/src/test/java/org/xbill/DNS/DohResolverTest.java
+++ b/src/test/java/org/xbill/DNS/DohResolverTest.java
@@ -28,8 +28,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -38,8 +38,8 @@
import org.mockito.stubbing.Answer;
@ExtendWith(VertxExtension.class)
+@Slf4j
class DohResolverTest {
- private DohResolver resolver;
private final Name queryName = Name.fromConstantString("example.com.");
private final Record qr = Record.newRecord(queryName, Type.A, DClass.IN);
private final Message qm = Message.newQuery(qr);
@@ -48,7 +48,6 @@ class DohResolverTest {
@BeforeEach
void beforeEach() throws UnknownHostException {
- resolver = new DohResolver("http://localhost");
Record ar =
new ARecord(
Name.fromConstantString("example.com."),
@@ -59,11 +58,16 @@ void beforeEach() throws UnknownHostException {
a.addRecord(ar, Section.ANSWER);
}
+ private DohResolver getResolver() {
+ return new DohResolver("http://localhost");
+ }
+
@ParameterizedTest
@ValueSource(booleans = {false, true})
void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
resolver.setUsePost(usePost);
- setupResolverWithServer(Duration.ZERO, 200, 1, vertx, context)
+ setupResolverWithServer(resolver, Duration.ZERO, 200, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -79,10 +83,13 @@ void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void timeoutResolve(Vertx vertx, VertxTestContext context) {
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void timeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
resolver.setTimeout(Duration.ofSeconds(1));
- setupResolverWithServer(Duration.ofSeconds(5), 200, 1, vertx, context)
+ resolver.setUsePost(usePost);
+ setupResolverWithServer(resolver, Duration.ofSeconds(5), 200, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -98,9 +105,12 @@ void timeoutResolve(Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void servfailResolve(Vertx vertx, VertxTestContext context) {
- setupResolverWithServer(Duration.ZERO, 301, 1, vertx, context)
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void servfailResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
+ resolver.setUsePost(usePost);
+ setupResolverWithServer(resolver, Duration.ZERO, 301, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -114,12 +124,14 @@ void servfailResolve(Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void limitRequestsResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void limitRequestsResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
int requests = 100;
Checkpoint cpPass = context.checkpoint(requests);
- setupResolverWithServer(Duration.ofMillis(100), 200, 5, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofMillis(100), 200, 5, vertx, context)
.onSuccess(
server -> {
for (int i = 0; i < requests; i++) {
@@ -137,13 +149,15 @@ void limitRequestsResolve(Vertx vertx, VertxTestContext context) {
});
}
- @Test
- void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void initialRequestSlowResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
int requests = 20;
allRequestsUseTimeout = false;
Checkpoint cpPass = context.checkpoint(requests);
- setupResolverWithServer(Duration.ofSeconds(1), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofSeconds(1), 200, 2, vertx, context)
.onSuccess(
server -> {
for (int i = 0; i < requests; i++) {
@@ -161,19 +175,23 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
});
}
- @Test
- void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void initialRequestTimeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
resolver.setTimeout(Duration.ofSeconds(1));
int requests = 20;
allRequestsUseTimeout = false;
Checkpoint cpPass = context.checkpoint(requests - 1);
Checkpoint cpFail = context.checkpoint();
- setupResolverWithServer(Duration.ofSeconds(2), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofSeconds(2), 200, 2, vertx, context)
.onSuccess(
server -> {
+ Message q = qm.clone();
+ q.getHeader().setID(0);
resolver
- .sendAsync(qm)
+ .sendAsync(q)
.whenComplete(
(result, ex) -> {
if (ex == null) {
@@ -185,9 +203,11 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
vertx.setTimer(
1000,
timer -> {
- for (int i = 0; i < requests - 1; i++) {
+ for (int i = 1; i < requests; i++) {
+ Message qq = qm.clone();
+ qq.getHeader().setID(i);
resolver
- .sendAsync(qm)
+ .sendAsync(qq)
.whenComplete(
(result, ex) -> {
if (ex == null) {
@@ -202,6 +222,7 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
}
private Future setupResolverWithServer(
+ DohResolver resolver,
Duration responseDelay,
int statusCode,
int maxConcurrentRequests,
@@ -214,12 +235,14 @@ private Future setupResolverWithServer(
@EnabledForJreRange(
min = JRE.JAVA_9,
disabledReason = "Java 8 implementation doesn't have the initial request guard")
- @Test
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
- Vertx vertx, VertxTestContext context) {
+ boolean usePost, Vertx vertx, VertxTestContext context) {
AtomicLong startNanos = new AtomicLong(System.nanoTime());
- resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
+ DohResolver resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
resolver.setTimeout(Duration.ofSeconds(1));
+ resolver.setUsePost(usePost);
// Simulate a nanoTime value that is lower than the idle timeout
doAnswer((Answer) invocationOnMock -> System.nanoTime() - startNanos.get())
.when(resolver)
@@ -243,11 +266,13 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
AtomicBoolean firstCallCompleted = new AtomicBoolean(false);
- setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofMillis(100L), 200, 2, vertx, context)
.onSuccess(
server -> {
// First call
- CompletionStage firstCall = resolver.sendAsync(qm);
+ CompletionStage firstCall =
+ resolver.sendAsync(qm).whenComplete((msg, ex) -> firstCallCompleted.set(true));
+
// Ensure second call was made after first call and uses a different query
startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20));
CompletionStage secondCall = resolver.sendAsync(Message.newQuery(qr));
@@ -261,7 +286,6 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
- firstCallCompleted.set(true);
})));
Future.fromCompletionStage(secondCall)
@@ -302,10 +326,12 @@ private Future setupServer(
int thisRequestNum = requestCount.incrementAndGet();
int count = concurrentRequests.incrementAndGet();
if (count > maxConcurrentRequests) {
- context.failNow("Concurrent requests exceeded");
+ context.failNow(
+ "Concurrent requests exceeded: " + count + " > " + maxConcurrentRequests);
return;
}
+ httpRequest.endHandler(v -> concurrentRequests.decrementAndGet());
httpRequest.bodyHandler(
body -> {
context.verify(
@@ -332,15 +358,12 @@ private Future setupServer(
&& (thisRequestNum == 1 || allRequestsUseTimeout)) {
vertx.setTimer(
serverProcessingTime.toMillis(),
- timer -> {
- concurrentRequests.decrementAndGet();
- httpRequest
- .response()
- .setStatusCode(statusCode)
- .end(Buffer.buffer(dnsResponseCopy.toWire()));
- });
+ timer ->
+ httpRequest
+ .response()
+ .setStatusCode(statusCode)
+ .end(Buffer.buffer(dnsResponseCopy.toWire())));
} else {
- concurrentRequests.decrementAndGet();
httpRequest
.response()
.setStatusCode(statusCode)