Skip to content

Commit e1e5c0c

Browse files
committed
1. Fix 'initialRequest' guard might be incorrect if the initial value of 'idleConnectionTimeout' is set to a value larger than current (and later for a period) System.nanoTime().
2. Fix race condition on setting 'lastRequest' timestamp among concurrent requests, ensuring its value is always be monotonic.
1 parent a6371af commit e1e5c0c

File tree

2 files changed

+135
-8
lines changed

2 files changed

+135
-8
lines changed

src/main/java/org/xbill/DNS/DohResolver.java

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.concurrent.Executor;
2424
import java.util.concurrent.ForkJoinPool;
2525
import java.util.concurrent.TimeUnit;
26+
import java.util.concurrent.atomic.AtomicBoolean;
2627
import java.util.concurrent.atomic.AtomicLong;
2728
import java.util.function.Function;
2829
import javax.net.ssl.HttpsURLConnection;
@@ -91,6 +92,8 @@ public final class DohResolver implements Resolver {
9192
private final AsyncSemaphore maxConcurrentRequests;
9293

9394
private final AtomicLong lastRequest = new AtomicLong(0);
95+
96+
private final AtomicBoolean initialRequestSentMark = new AtomicBoolean(false);
9497
private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1);
9598

9699
private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
@@ -469,15 +472,28 @@ private CompletionStage<Message> sendAsync11(final Message query, Executor execu
469472
.thenCompose(Function.identity());
470473
}
471474

475+
/**
476+
* Check whether current initiating DoH request is initial request of this {@link DohResolver}.
477+
*/
478+
private boolean checkInitialRequest() {
479+
// If initial request haven't been completed successfully yet, just return true.
480+
if (!initialRequestSentMark.get()) {
481+
return true;
482+
}
483+
484+
// Otherwise, check whether such request is happened
485+
// after last successful request plus idle connection timeout.
486+
long lastRequestTime = lastRequest.get();
487+
return (lastRequestTime + idleConnectionTimeout.toNanos() < System.nanoTime());
488+
}
489+
472490
private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
473491
Message query,
474492
Executor executor,
475493
long startTime,
476494
Object requestBuilder,
477495
Permit initialRequestPermit) {
478-
long lastRequestTime = lastRequest.get();
479-
boolean isInitialRequest =
480-
(lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos());
496+
boolean isInitialRequest = checkInitialRequest();
481497
if (!isInitialRequest) {
482498
initialRequestPermit.release();
483499
}
@@ -516,6 +532,25 @@ private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
516532
.thenCompose(Function.identity());
517533
}
518534

535+
/**
536+
* Set last request time to {@link DohResolver#lastRequest}, which ensures only the largest timestamp could be accepted.
537+
*
538+
* @param startTime start time in nanos of a Doh request.
539+
*/
540+
private void setLastRequestTime(long startTime) {
541+
long current = lastRequest.get();
542+
// Only update value of 'lastRequest' if timestamp in 'lastRequest' is smaller than incoming 'startTime' value.
543+
if (current < startTime) {
544+
while (!lastRequest.compareAndSet(current, startTime)) {
545+
// CAS failed, re-verify the eligibility of timestamp in 'lastRequest' to be updated to the incoming 'startTime' value.
546+
current = lastRequest.get();
547+
if (current > startTime) {
548+
return;
549+
}
550+
}
551+
}
552+
}
553+
519554
private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
520555
Message query,
521556
Executor executor,
@@ -548,7 +583,12 @@ private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
548583
.whenComplete(
549584
(result, ex) -> {
550585
if (ex == null) {
551-
lastRequest.set(startTime);
586+
setLastRequestTime(startTime);
587+
if (isInitialRequest) {
588+
// initial request was completed successfully, so toggle initialRequestSentMark to true.
589+
// it's very safe to toggle initialRequestSentMark here, since this code had been guarded by initialRequestLock and its permit outside.
590+
initialRequestSentMark.compareAndSet(false, true);
591+
}
552592
}
553593
maxConcurrentRequestPermit.release();
554594
if (isInitialRequest) {

src/test/java/org/xbill/DNS/DohResolverTest.java

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// SPDX-License-Identifier: BSD-3-Clause
22
package org.xbill.DNS;
33

4-
import static org.junit.jupiter.api.Assertions.assertEquals;
5-
import static org.junit.jupiter.api.Assertions.assertTrue;
6-
74
import io.netty.handler.codec.http.HttpHeaderNames;
85
import io.vertx.core.Future;
96
import io.vertx.core.Vertx;
@@ -20,14 +17,18 @@
2017
import java.time.Duration;
2118
import java.util.Base64;
2219
import java.util.Collections;
20+
import java.util.concurrent.CompletionStage;
2321
import java.util.concurrent.TimeoutException;
22+
import java.util.concurrent.atomic.AtomicBoolean;
2423
import java.util.concurrent.atomic.AtomicInteger;
2524
import org.junit.jupiter.api.BeforeEach;
2625
import org.junit.jupiter.api.Test;
2726
import org.junit.jupiter.api.extension.ExtendWith;
2827
import org.junit.jupiter.params.ParameterizedTest;
2928
import org.junit.jupiter.params.provider.ValueSource;
3029

30+
import static org.junit.jupiter.api.Assertions.*;
31+
3132
@ExtendWith(VertxExtension.class)
3233
class DohResolverTest {
3334
private DohResolver resolver;
@@ -152,6 +153,88 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
152153
});
153154
}
154155

156+
157+
@Test
158+
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(Vertx vertx, VertxTestContext context) {
159+
if (isPreJava9()) {
160+
System.out.println("Current JVM is PreJava9, no need to run such test.");
161+
context.completeNow();
162+
return;
163+
}
164+
resolver = new DohResolver("http://localhost",
165+
2,
166+
// so long idleConnectionTimeout
167+
// in order to hack the condition for checking initial request in org.xbill.DNS.DohResolver.checkInitialRequest
168+
Duration.ofNanos(System.nanoTime() + Duration.ofSeconds(100L).toNanos()));
169+
resolver.setTimeout(Duration.ofSeconds(1));
170+
// Just add a 100ms delay before responding to the 1st call
171+
// to simulate a 'concurrent doh request' for the 2nd call,
172+
// then let the fake dns server respond to the 2nd call ASAP.
173+
allRequestsUseTimeout = false;
174+
175+
// idleConnectionTimeout = 2s, lastRequest = 0L
176+
// Ensure lastRequest + idleConnectionTimeout < System.nanoTime() (3s)
177+
178+
// Timeline:
179+
// |<-------- 100ms -------->|
180+
// ↑ ↑
181+
// 1st call sent response of 1st call
182+
// |20ms|<------ 80ms ------>|<------ few millis ------->|
183+
// ↑ wait until 1st call ↑ ↑
184+
// 2nd call begin 2nd call sent response of 2nd call
185+
186+
AtomicBoolean firstCallCompleted = new AtomicBoolean(false);
187+
188+
setupResolverWithServer(Duration.ofMillis(100L),
189+
200,
190+
2,
191+
vertx,
192+
context)
193+
.onSuccess(
194+
server -> {
195+
// First call
196+
CompletionStage<Message> firstCall = resolver.sendAsync(qm);
197+
// Ensure second call was made after first call.
198+
sleepNotThrown(20L);
199+
CompletionStage<Message> secondCall = resolver.sendAsync(qm);
200+
201+
Future.fromCompletionStage(firstCall)
202+
.onComplete(
203+
context.succeeding(
204+
result ->
205+
context.verify(
206+
() -> {
207+
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
208+
assertEquals(0, result.getHeader().getID());
209+
assertEquals(queryName, result.getQuestion().getName());
210+
firstCallCompleted.set(true);
211+
})));
212+
213+
Future.fromCompletionStage(secondCall)
214+
.onComplete(
215+
context.succeeding(
216+
result ->
217+
context.verify(
218+
() -> {
219+
assertTrue(firstCallCompleted.get());
220+
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
221+
assertEquals(0, result.getHeader().getID());
222+
assertEquals(queryName, result.getQuestion().getName());
223+
// Complete context after the 2nd call was completed.
224+
context.completeNow();
225+
})));
226+
}
227+
);
228+
}
229+
230+
private static void sleepNotThrown(long millis) {
231+
try {
232+
Thread.sleep(millis);
233+
} catch (InterruptedException e) {
234+
throw new RuntimeException(e);
235+
}
236+
}
237+
155238
@Test
156239
void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
157240
resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
@@ -192,6 +275,10 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
192275
});
193276
}
194277

278+
private static boolean isPreJava9() {
279+
return System.getProperty("java.version").startsWith("1.");
280+
}
281+
195282
private Future<HttpServer> setupResolverWithServer(
196283
Duration responseDelay,
197284
int statusCode,
@@ -211,7 +298,7 @@ private Future<HttpServer> setupServer(
211298
VertxTestContext context,
212299
Vertx vertx) {
213300
HttpVersion version =
214-
System.getProperty("java.version").startsWith("1.")
301+
isPreJava9()
215302
? HttpVersion.HTTP_1_1
216303
: HttpVersion.HTTP_2;
217304
AtomicInteger requestCount = new AtomicInteger(0);

0 commit comments

Comments
 (0)