Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
1. Fix 'initialRequest' guard might be incorrect if the initial value…
… of 'idleConnectionTimeout' is set to a value larger than current (and later for a period) System.nanoTime().

2. Fix race condition on setting 'lastRequest' timestamp among concurrent requests, ensuring its value is always be monotonic.
  • Loading branch information
LinZong authored and ibauersachs committed Jan 5, 2025
commit 265adacdf7202e65632e28c3137adfed5175f45c
48 changes: 44 additions & 4 deletions src/main/java/org/xbill/DNS/DohResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import javax.net.ssl.HttpsURLConnection;
Expand Down Expand Up @@ -91,6 +92,8 @@ public final class DohResolver implements Resolver {
private final AsyncSemaphore maxConcurrentRequests;

private final AtomicLong lastRequest = new AtomicLong(0);

private final AtomicBoolean initialRequestSentMark = new AtomicBoolean(false);
private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1);

private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
Expand Down Expand Up @@ -469,15 +472,28 @@ private CompletionStage<Message> sendAsync11(final Message query, Executor execu
.thenCompose(Function.identity());
}

/**
* Check whether current initiating DoH request is initial request of this {@link DohResolver}.
*/
private boolean checkInitialRequest() {
// If initial request haven't been completed successfully yet, just return true.
if (!initialRequestSentMark.get()) {
return true;
}

// Otherwise, check whether such request is happened
// after last successful request plus idle connection timeout.
long lastRequestTime = lastRequest.get();
return (lastRequestTime + idleConnectionTimeout.toNanos() < System.nanoTime());
}

private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
Message query,
Executor executor,
long startTime,
Object requestBuilder,
Permit initialRequestPermit) {
long lastRequestTime = lastRequest.get();
boolean isInitialRequest =
(lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos());
boolean isInitialRequest = checkInitialRequest();
if (!isInitialRequest) {
initialRequestPermit.release();
}
Expand Down Expand Up @@ -516,6 +532,25 @@ private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
.thenCompose(Function.identity());
}

/**
* Set last request time to {@link DohResolver#lastRequest}, which ensures only the largest timestamp could be accepted.
*
* @param startTime start time in nanos of a Doh request.
*/
private void setLastRequestTime(long startTime) {
long current = lastRequest.get();
// Only update value of 'lastRequest' if timestamp in 'lastRequest' is smaller than incoming 'startTime' value.
if (current < startTime) {
while (!lastRequest.compareAndSet(current, startTime)) {
// CAS failed, re-verify the eligibility of timestamp in 'lastRequest' to be updated to the incoming 'startTime' value.
current = lastRequest.get();
if (current > startTime) {
return;
}
}
}
}

private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
Message query,
Executor executor,
Expand Down Expand Up @@ -548,7 +583,12 @@ private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
.whenComplete(
(result, ex) -> {
if (ex == null) {
lastRequest.set(startTime);
setLastRequestTime(startTime);
if (isInitialRequest) {
// initial request was completed successfully, so toggle initialRequestSentMark to true.
// it's very safe to toggle initialRequestSentMark here, since this code had been guarded by initialRequestLock and its permit outside.
initialRequestSentMark.compareAndSet(false, true);
}
}
maxConcurrentRequestPermit.release();
if (isInitialRequest) {
Expand Down
95 changes: 91 additions & 4 deletions src/test/java/org/xbill/DNS/DohResolverTest.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.netty.handler.codec.http.HttpHeaderNames;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
Expand All @@ -20,14 +17,18 @@
import java.time.Duration;
import java.util.Base64;
import java.util.Collections;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(VertxExtension.class)
class DohResolverTest {
private DohResolver resolver;
Expand Down Expand Up @@ -152,6 +153,88 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
});
}


@Test
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(Vertx vertx, VertxTestContext context) {
if (isPreJava9()) {
System.out.println("Current JVM is PreJava9, no need to run such test.");
context.completeNow();
return;
}
resolver = new DohResolver("http://localhost",
2,
// so long idleConnectionTimeout
// in order to hack the condition for checking initial request in org.xbill.DNS.DohResolver.checkInitialRequest
Duration.ofNanos(System.nanoTime() + Duration.ofSeconds(100L).toNanos()));
resolver.setTimeout(Duration.ofSeconds(1));
// Just add a 100ms delay before responding to the 1st call
// to simulate a 'concurrent doh request' for the 2nd call,
// then let the fake dns server respond to the 2nd call ASAP.
allRequestsUseTimeout = false;

// idleConnectionTimeout = 2s, lastRequest = 0L
// Ensure lastRequest + idleConnectionTimeout < System.nanoTime() (3s)

// Timeline:
// |<-------- 100ms -------->|
// ↑ ↑
// 1st call sent response of 1st call
// |20ms|<------ 80ms ------>|<------ few millis ------->|
// ↑ wait until 1st call ↑ ↑
// 2nd call begin 2nd call sent response of 2nd call

AtomicBoolean firstCallCompleted = new AtomicBoolean(false);

setupResolverWithServer(Duration.ofMillis(100L),
200,
2,
vertx,
context)
.onSuccess(
server -> {
// First call
CompletionStage<Message> firstCall = resolver.sendAsync(qm);
// Ensure second call was made after first call.
sleepNotThrown(20L);
CompletionStage<Message> secondCall = resolver.sendAsync(qm);

Future.fromCompletionStage(firstCall)
.onComplete(
context.succeeding(
result ->
context.verify(
() -> {
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
firstCallCompleted.set(true);
})));

Future.fromCompletionStage(secondCall)
.onComplete(
context.succeeding(
result ->
context.verify(
() -> {
assertTrue(firstCallCompleted.get());
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
// Complete context after the 2nd call was completed.
context.completeNow();
})));
}
);
}

private static void sleepNotThrown(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

@Test
void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
Expand Down Expand Up @@ -192,6 +275,10 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
});
}

private static boolean isPreJava9() {
return System.getProperty("java.version").startsWith("1.");
}

private Future<HttpServer> setupResolverWithServer(
Duration responseDelay,
int statusCode,
Expand All @@ -211,7 +298,7 @@ private Future<HttpServer> setupServer(
VertxTestContext context,
Vertx vertx) {
HttpVersion version =
System.getProperty("java.version").startsWith("1.")
isPreJava9()
? HttpVersion.HTTP_1_1
: HttpVersion.HTTP_2;
AtomicInteger requestCount = new AtomicInteger(0);
Expand Down