Skip to content
Prev Previous commit
Next Next commit
Update NioUdpClient.java
Add support to wait for multiple answers to arrive if and only if the address of the DNS server is a multicast address.
  • Loading branch information
ka2ddo authored Jan 6, 2021
commit fce0cf8b18dbddc65b8613cdd319dd3c281fa7df
100 changes: 86 additions & 14 deletions src/main/java/org/xbill/DNS/NioUdpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
Expand All @@ -12,6 +13,7 @@
import java.nio.channels.Selector;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -71,8 +73,7 @@ private static void checkTransactionTimeouts() {
for (Iterator<Transaction> it = pendingTransactions.iterator(); it.hasNext(); ) {
Transaction t = it.next();
if (t.endTime - System.nanoTime() < 0) {
t.silentCloseChannel();
t.f.completeExceptionally(new SocketTimeoutException("Query timed out"));
t.closeTransaction();
it.remove();
}
}
Expand All @@ -81,19 +82,20 @@ private static void checkTransactionTimeouts() {
@RequiredArgsConstructor
private static class Transaction implements KeyProcessor {
private final byte[] data;
private final int max;
final int max;
private final long endTime;
private final DatagramChannel channel;
private final CompletableFuture<byte[]> f;
private final SocketAddress remoteSocketAddress;
final CompletableFuture<Object> f;

void send() throws IOException {
ByteBuffer buffer = ByteBuffer.wrap(data);
verboseLog(
"UDP write",
channel.socket().getLocalSocketAddress(),
channel.socket().getRemoteSocketAddress(),
remoteSocketAddress,
data);
int n = channel.send(buffer, channel.socket().getRemoteSocketAddress());
int n = channel.send(buffer, remoteSocketAddress);
if (n <= 0) {
throw new EOFException();
}
Expand All @@ -109,10 +111,12 @@ public void processReadyKey(SelectionKey key) {

DatagramChannel channel = (DatagramChannel) key.channel();
ByteBuffer buffer = ByteBuffer.allocate(max);
SocketAddress source;
int read;
try {
read = channel.read(buffer);
if (read <= 0) {
source = channel.receive(buffer);
read = buffer.position();
if (read <= 0 || source == null) {
throw new EOFException();
}
} catch (IOException e) {
Expand All @@ -128,26 +132,88 @@ public void processReadyKey(SelectionKey key) {
verboseLog(
"UDP read",
channel.socket().getLocalSocketAddress(),
channel.socket().getRemoteSocketAddress(),
remoteSocketAddress,
data);
silentCloseChannel();
f.complete(data);
pendingTransactions.remove(this);
}

private void silentCloseChannel() {
void silentCloseChannel() {
try {
channel.disconnect();
channel.close();
} catch (IOException e) {
// ignore, we either already have everything we need or can't do anything
}
}

void closeTransaction() {
silentCloseChannel();
f.completeExceptionally(new SocketTimeoutException("Query timed out"));
}
}

private static class MultiAnswerTransaction extends Transaction {
MultiAnswerTransaction(byte[] query, int max, long endTime, DatagramChannel channel,
SocketAddress remoteSocketAddress,
CompletableFuture<Object> f) {
super(query, max, endTime, channel, remoteSocketAddress, f);
}

public void processReadyKey(SelectionKey key) {
if (!key.isReadable()) {
silentCloseChannel();
f.completeExceptionally(new EOFException("channel not readable"));
pendingTransactions.remove(this);
return;
}

DatagramChannel channel = (DatagramChannel) key.channel();
ByteBuffer buffer = ByteBuffer.allocate(max);
SocketAddress source;
int read;
try {
source = channel.receive(buffer);
read = buffer.position();
if (read <= 0 || source == null) {
return; // ignore this datagram
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of code duplication. The return here vs. EOFException could be:

if (remoteSocketAddress.getAddress().isMulticast()) {
  return;
}
throw new EOFException();

}
} catch (IOException e) {
silentCloseChannel();
f.completeExceptionally(e);
pendingTransactions.remove(this);
return;
}

buffer.flip();
byte[] data = new byte[read];
System.arraycopy(buffer.array(), 0, data, 0, read);
verboseLog(
"UDP read",
channel.socket().getLocalSocketAddress(),
source,
data);
answers.add(data);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only other line in this method that differs for mDNS. Instead of duplicating everything, move the processing of the data byte-array into a separate method, e.g.

// in Transaction:
...
  System.arraycopy(buffer.array(), 0, data, 0, read);
  verboseLog("UDP read", channel.socket().getLocalSocketAddress(), source, data);
  processAnswer(data);
}

processAnswer(byte[] data) {
  silentCloseChannel();
  f.complete(Collections.singletonList(data));
  pendingTransactions.remove(this);
}

// in MultiAnswerTransaction:
processAnswer(byte[] data) {
  answers.add(data);
}

}

private ArrayList<byte[]> answers = new ArrayList<>();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the top of the class definition and make it final, fields randomly in the middle of classes is weird.


@Override
void closeTransaction() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

closeTransaction is only triggered when the timeout has expired. Is this really the only condition on which you want to return answers? What about the case where one is only interested in the first answer, as fast as possible?

The timeout processing currently only runs every second. Is that still sufficient?

if (answers.size() > 0) {
silentCloseChannel();
f.complete(answers);
} else {
// we failed, no answers
super.closeTransaction();
}
}
}

static CompletableFuture<byte[]> sendrecv(
static CompletableFuture<Object> sendrecv(
InetSocketAddress local, InetSocketAddress remote, byte[] data, int max, Duration timeout) {
CompletableFuture<byte[]> f = new CompletableFuture<>();
CompletableFuture<Object> f = new CompletableFuture<>();
try {
final Selector selector = selector();
DatagramChannel channel = DatagramChannel.open();
Expand Down Expand Up @@ -185,9 +251,15 @@ static CompletableFuture<byte[]> sendrecv(
}
}

channel.connect(remote);
long endTime = System.nanoTime() + timeout.toNanos();
Transaction t = new Transaction(data, max, endTime, channel, f);
Transaction t;
if (!remote.getAddress().isMulticastAddress()) {
channel.connect(remote);
t = new Transaction(data, max, endTime, channel, f);
} else {
// stop this a little before the timeout so we can report what answers we did get
t = new MultiAnswerTransaction(data, max, endTime - 1000000000L, channel, f);
}
pendingTransactions.add(t);
registrationQueue.add(t);
selector.wakeup();
Expand Down