diff --git a/src/main/java/org/xbill/DNS/TCPClient.java b/src/main/java/org/xbill/DNS/TCPClient.java
index 7f65f439..6151e1b9 100644
--- a/src/main/java/org/xbill/DNS/TCPClient.java
+++ b/src/main/java/org/xbill/DNS/TCPClient.java
@@ -14,7 +14,7 @@
import java.time.Duration;
import java.time.temporal.ChronoUnit;
-final class TCPClient {
+class TCPClient implements AutoCloseable {
private final long startTime;
private final Duration timeout;
private final SelectionKey key;
@@ -144,7 +144,7 @@ private void blockUntil(SelectionKey key) throws IOException {
}
}
- void cleanup() throws IOException {
+ public void close() throws IOException {
key.selector().close();
key.channel().close();
}
diff --git a/src/main/java/org/xbill/DNS/TSIG.java b/src/main/java/org/xbill/DNS/TSIG.java
index aabd364b..29ecbea1 100644
--- a/src/main/java/org/xbill/DNS/TSIG.java
+++ b/src/main/java/org/xbill/DNS/TSIG.java
@@ -15,6 +15,7 @@
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
+import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.xbill.DNS.utils.base64;
import org.xbill.DNS.utils.hexdump;
@@ -369,28 +370,35 @@ public TSIGRecord generate(Message m, byte[] b, int error, TSIGRecord old) {
*/
public TSIGRecord generate(
Message m, byte[] b, int error, TSIGRecord old, boolean fullSignature) {
- Instant timeSigned;
- if (error == Rcode.BADTIME) {
- timeSigned = old.getTimeSigned();
- } else {
- timeSigned = clock.instant();
- }
-
- boolean signing = false;
Mac hmac = null;
if (error == Rcode.NOERROR || error == Rcode.BADTIME || error == Rcode.BADTRUNC) {
- signing = true;
hmac = initHmac();
}
- Duration fudge;
- int fudgeOption = Options.intValue("tsigfudge");
- if (fudgeOption < 0 || fudgeOption > 0x7FFF) {
- fudge = FUDGE;
- } else {
- fudge = Duration.ofSeconds(fudgeOption);
- }
+ return generate(m, b, error, old, fullSignature, hmac);
+ }
+ /**
+ * Generates a TSIG record with a specific error for a message that has been rendered.
+ *
+ * @param m The message
+ * @param b The rendered message
+ * @param error The error
+ * @param old If this message is a response, the TSIG from the request
+ * @param fullSignature {@code true} if this {@link TSIGRecord} is the to be added to the first of
+ * many messages in a TCP connection and all TSIG variables (rfc2845, 3.4.2.) should be
+ * included in the signature. {@code false} for subsequent messages with reduced TSIG
+ * variables set (rfc2845, 4.4.).
+ * @param hmac A mac instance to reuse for a stream of messages to sign, e.g. when doing a zone
+ * transfer.
+ * @return The TSIG record to be added to the message
+ */
+ private TSIGRecord generate(
+ Message m, byte[] b, int error, TSIGRecord old, boolean fullSignature, Mac hmac) {
+ Instant timeSigned = getTimeSigned(error, old);
+ Duration fudge = getTsigFudge();
+
+ boolean signing = hmac != null;
if (old != null && signing) {
hmacAddSignature(hmac, old);
}
@@ -413,7 +421,7 @@ public TSIGRecord generate(
alg.toWireCanonical(out);
}
- writeTsigTimersVariables(timeSigned, fudge, out);
+ writeTsigTimerVariables(timeSigned, fudge, out);
if (fullSignature) {
out.writeU16(error);
out.writeU16(0); /* No other data */
@@ -450,6 +458,15 @@ public TSIGRecord generate(
other);
}
+ private Instant getTimeSigned(int error, TSIGRecord old) {
+ return error == Rcode.BADTIME ? old.getTimeSigned() : clock.instant();
+ }
+
+ private static Duration getTsigFudge() {
+ int fudgeOption = Options.intValue("tsigfudge");
+ return fudgeOption < 0 || fudgeOption > 0x7FFF ? FUDGE : Duration.ofSeconds(fudgeOption);
+ }
+
/**
* Generates a TSIG record for a message and adds it to the message
*
@@ -522,6 +539,8 @@ public void applyStream(Message m, TSIGRecord old, boolean fullSignature) {
* TSIG is expected to be present, it is an error if one is not present. After calling this
* routine, Message.isVerified() may be called on this message.
*
+ * Use {@link StreamVerifier} to validate multiple messages in a stream.
+ *
* @param m The message
* @param b An array containing the message in unparsed form. This is necessary since TSIG signs
* the message in wire format, and we can't recreate the exact wire format (with the same name
@@ -542,6 +561,8 @@ public byte verify(Message m, byte[] b, int length, TSIGRecord old) {
* TSIG is expected to be present, it is an error if one is not present. After calling this
* routine, Message.isVerified() may be called on this message.
*
+ *
Use {@link StreamVerifier} to validate multiple messages in a stream.
+ *
* @param m The message to verify
* @param messageBytes An array containing the message in unparsed form. This is necessary since
* TSIG signs the message in wire format, and we can't recreate the exact wire format (with
@@ -559,6 +580,8 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG) {
* TSIG is expected to be present, it is an error if one is not present. After calling this
* routine, Message.isVerified() may be called on this message.
*
+ *
Use {@link StreamVerifier} to validate multiple messages in a stream.
+ *
* @param m The message to verify
* @param messageBytes An array containing the message in unparsed form. This is necessary since
* TSIG signs the message in wire format, and we can't recreate the exact wire format (with
@@ -572,6 +595,27 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG) {
* @since 3.2
*/
public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolean fullSignature) {
+ return verify(m, messageBytes, requestTSIG, fullSignature, null);
+ }
+
+ /**
+ * Verifies a TSIG record on an incoming message. Since this is only called in the context where a
+ * TSIG is expected to be present, it is an error if one is not present. After calling this
+ * routine, Message.isVerified() may be called on this message.
+ *
+ * @param m The message to verify
+ * @param messageBytes An array containing the message in unparsed form. This is necessary since
+ * TSIG signs the message in wire format, and we can't recreate the exact wire format (with
+ * the same name compression).
+ * @param requestTSIG If this message is a response, the TSIG from the request
+ * @param fullSignature {@code true} if this message is the first of many in a TCP connection and
+ * all TSIG variables (rfc2845, 3.4.2.) should be included in the signature. {@code false} for
+ * subsequent messages with reduced TSIG variables set (rfc2845, 4.4.).
+ * @return The result of the verification (as an Rcode)
+ * @see Rcode
+ */
+ private int verify(
+ Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolean fullSignature, Mac hmac) {
m.tsigState = Message.TSIG_FAILED;
TSIGRecord tsig = m.getTSIG();
if (tsig == null) {
@@ -589,7 +633,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
return Rcode.BADKEY;
}
- Mac hmac = initHmac();
+ if (hmac == null) {
+ hmac = initHmac();
+ }
+
if (requestTSIG != null && tsig.getError() != Rcode.BADKEY && tsig.getError() != Rcode.BADSIG) {
hmacAddSignature(hmac, requestTSIG);
}
@@ -608,6 +655,27 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
}
hmac.update(messageBytes, header.length, len);
+ byte[] tsigVariables = getTsigVariables(fullSignature, tsig);
+ hmac.update(tsigVariables);
+
+ byte[] signature = tsig.getSignature();
+ int badsig = verifySignature(hmac, signature);
+ if (badsig != Rcode.NOERROR) {
+ return badsig;
+ }
+
+ // validate time after the signature, as per
+ // https://www.rfc-editor.org/rfc/rfc8945.html#section-5.4
+ int badtime = verifyTime(tsig);
+ if (badtime != Rcode.NOERROR) {
+ return badtime;
+ }
+
+ m.tsigState = Message.TSIG_VERIFIED;
+ return Rcode.NOERROR;
+ }
+
+ private static byte[] getTsigVariables(boolean fullSignature, TSIGRecord tsig) {
DNSOutput out = new DNSOutput();
if (fullSignature) {
tsig.getName().toWireCanonical(out);
@@ -615,7 +683,7 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
out.writeU32(tsig.ttl);
tsig.getAlgorithm().toWireCanonical(out);
}
- writeTsigTimersVariables(tsig.getTimeSigned(), tsig.getFudge(), out);
+ writeTsigTimerVariables(tsig.getTimeSigned(), tsig.getFudge(), out);
if (fullSignature) {
out.writeU16(tsig.getError());
if (tsig.getOther() != null) {
@@ -630,9 +698,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
if (log.isTraceEnabled()) {
log.trace(hexdump.dump("TSIG-HMAC variables", tsigVariables));
}
- hmac.update(tsigVariables);
+ return tsigVariables;
+ }
- byte[] signature = tsig.getSignature();
+ private static int verifySignature(Mac hmac, byte[] signature) {
int digestLength = hmac.getMacLength();
// rfc4635#section-3.1, 4.:
@@ -662,9 +731,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
return Rcode.BADSIG;
}
}
+ return Rcode.NOERROR;
+ }
- // validate time after the signature, as per
- // https://tools.ietf.org/html/draft-ietf-dnsop-rfc2845bis-08#section-5.4.3
+ private int verifyTime(TSIGRecord tsig) {
Instant now = clock.instant();
Duration delta = Duration.between(now, tsig.getTimeSigned()).abs();
if (delta.compareTo(tsig.getFudge()) > 0) {
@@ -675,8 +745,6 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea
tsig.getFudge());
return Rcode.BADTIME;
}
-
- m.tsigState = Message.TSIG_VERIFIED;
return Rcode.NOERROR;
}
@@ -706,7 +774,7 @@ private static void hmacAddSignature(Mac hmac, TSIGRecord tsig) {
hmac.update(tsig.getSignature());
}
- private static void writeTsigTimersVariables(Instant instant, Duration fudge, DNSOutput out) {
+ private static void writeTsigTimerVariables(Instant instant, Duration fudge, DNSOutput out) {
writeTsigTime(instant, out);
out.writeU16((int) fudge.getSeconds());
}
@@ -719,58 +787,198 @@ private static void writeTsigTime(Instant instant, DNSOutput out) {
out.writeU32(timeLow);
}
+ /**
+ * A utility class for generating signed message responses.
+ *
+ * @since 3.5.3
+ */
+ public static class StreamGenerator {
+ private final TSIG key;
+ private final Mac sharedHmac;
+ private final int signEveryNthMessage;
+
+ private int numGenerated;
+ private TSIGRecord lastTsigRecord;
+
+ /**
+ * Creates an instance to sign multiple message for use in a stream.
+ *
+ *
This class creates a {@link TSIGRecord} on every message to conform with RFC 8945, 5.3.1.
+ *
+ * @param key The TSIG key used to create the signature records.
+ * @param queryTsig The initial TSIG records, e.g. from a query to a server.
+ */
+ public StreamGenerator(TSIG key, TSIGRecord queryTsig) {
+ // The TSIG MUST be included on all DNS messages in the response.
+ this(key, queryTsig, 1);
+ }
+
+ /**
+ * This constructor is only for unit-testing {@link StreamVerifier} with responses where
+ * not every message is signed.
+ */
+ StreamGenerator(TSIG key, TSIGRecord queryTsig, int signEveryNthMessage) {
+ if (signEveryNthMessage < 1 || signEveryNthMessage > 100) {
+ throw new IllegalArgumentException("signEveryNthMessage must be between 1 and 100");
+ }
+
+ this.key = key;
+ this.lastTsigRecord = queryTsig;
+ this.signEveryNthMessage = signEveryNthMessage;
+ sharedHmac = this.key.initHmac();
+ }
+
+ /**
+ * Generate TSIG a signature for use of the message in a stream.
+ *
+ * @param message The message to sign.
+ */
+ public void generate(Message message) {
+ generate(message, true);
+ }
+
+ void generate(Message message, boolean isLastMessage) {
+ boolean isNthMessage = numGenerated % signEveryNthMessage == 0;
+ boolean isFirstMessage = numGenerated == 0;
+ if (isFirstMessage || isNthMessage || isLastMessage) {
+ TSIGRecord r =
+ key.generate(
+ message,
+ message.toWire(),
+ Rcode.NOERROR,
+ isFirstMessage ? lastTsigRecord : null,
+ isFirstMessage,
+ sharedHmac);
+ message.addRecord(r, Section.ADDITIONAL);
+ message.tsigState = Message.TSIG_SIGNED;
+ lastTsigRecord = r;
+ hmacAddSignature(sharedHmac, r);
+ } else {
+ byte[] responseBytes = message.toWire(Message.MAXLENGTH);
+ sharedHmac.update(responseBytes);
+ }
+
+ numGenerated++;
+ }
+ }
+
+ /** A utility class for verifying multiple message responses. */
public static class StreamVerifier {
- /** A helper class for verifying multiple message responses. */
private final TSIG key;
+ private final Mac sharedHmac;
+ private final TSIGRecord queryTsig;
private int nresponses;
private int lastsigned;
- private TSIGRecord lastTSIG;
+
+ /** {@code null} or the detailed error when validation failed due to a {@link Rcode#FORMERR}. */
+ @Getter private String errorMessage;
/** Creates an object to verify a multiple message response */
public StreamVerifier(TSIG tsig, TSIGRecord queryTsig) {
key = tsig;
+ sharedHmac = key.initHmac();
nresponses = 0;
- lastTSIG = queryTsig;
+ this.queryTsig = queryTsig;
}
/**
* Verifies a TSIG record on an incoming message that is part of a multiple message response.
* TSIG records must be present on the first and last messages, and at least every 100 records
- * in between. After calling this routine, Message.isVerified() may be called on this message.
+ * in between. After calling this routine,{@link Message#isVerified()} may be called on this
+ * message.
*
- * @param m The message
- * @param b The message in unparsed form
+ *
This overload assumes that the verified message is not the last one, which is required to
+ * have a {@link TSIGRecord}. Use {@link #verify(Message, byte[], boolean)} to explicitly
+ * specify the last message or check that the message is verified with {@link
+ * Message#isVerified()}.
+ *
+ * @param message The message
+ * @param messageBytes The message in unparsed form
* @return The result of the verification (as an Rcode)
* @see Rcode
*/
- public int verify(Message m, byte[] b) {
- TSIGRecord tsig = m.getTSIG();
+ public int verify(Message message, byte[] messageBytes) {
+ return verify(message, messageBytes, false);
+ }
+ /**
+ * Verifies a TSIG record on an incoming message that is part of a multiple message response.
+ * TSIG records must be present on the first and last messages, and at least every 100 records
+ * in between. After calling this routine, {@link Message#isVerified()} may be called on this
+ * message.
+ *
+ * @param message The message
+ * @param messageBytes The message in unparsed form
+ * @param isLastMessage If true, verifies that the {@link Message} has an {@link TSIGRecord}.
+ * @return The result of the verification (as an Rcode)
+ * @see Rcode
+ * @since 3.5.3
+ */
+ public int verify(Message message, byte[] messageBytes, boolean isLastMessage) {
+ final String warningPrefix = "FORMERR: {}";
+ TSIGRecord tsig = message.getTSIG();
+
+ // https://datatracker.ietf.org/doc/html/rfc8945#section-5.3.1
+ // [...] a client that receives DNS messages and verifies TSIG MUST accept up to 99
+ // intermediary messages without a TSIG and MUST verify that both the first and last message
+ // contain a TSIG.
nresponses++;
if (nresponses == 1) {
- int result = key.verify(m, b, lastTSIG);
- lastTSIG = tsig;
- return result;
+ if (tsig != null) {
+ int result = key.verify(message, messageBytes, queryTsig, true, sharedHmac);
+ hmacAddSignature(sharedHmac, tsig);
+ lastsigned = nresponses;
+ return result;
+ } else {
+ errorMessage = "missing required signature on first message";
+ log.debug(warningPrefix, errorMessage);
+ message.tsigState = Message.TSIG_FAILED;
+ return Rcode.FORMERR;
+ }
}
if (tsig != null) {
- int result = key.verify(m, b, lastTSIG, false);
+ int result = key.verify(message, messageBytes, null, false, sharedHmac);
lastsigned = nresponses;
- lastTSIG = tsig;
+ hmacAddSignature(sharedHmac, tsig);
return result;
} else {
boolean required = nresponses - lastsigned >= 100;
if (required) {
- log.debug("FORMERR: missing required signature on {}th message", nresponses);
- m.tsigState = Message.TSIG_FAILED;
+ errorMessage = "Missing required signature on message #" + nresponses;
+ log.debug(warningPrefix, errorMessage);
+ message.tsigState = Message.TSIG_FAILED;
+ return Rcode.FORMERR;
+ } else if (isLastMessage) {
+ errorMessage = "Missing required signature on last message";
+ log.debug(warningPrefix, errorMessage);
+ message.tsigState = Message.TSIG_FAILED;
return Rcode.FORMERR;
} else {
- log.trace("Intermediate message {} without signature", nresponses);
- m.tsigState = Message.TSIG_INTERMEDIATE;
+ errorMessage = "Intermediate message #" + nresponses + " without signature";
+ log.debug(warningPrefix, errorMessage);
+ addUnsignedMessageToMac(message, messageBytes, sharedHmac);
return Rcode.NOERROR;
}
}
}
+
+ private void addUnsignedMessageToMac(Message m, byte[] messageBytes, Mac hmac) {
+ byte[] header = m.getHeader().toWire();
+ if (log.isTraceEnabled()) {
+ log.trace(hexdump.dump("TSIG-HMAC header", header));
+ }
+
+ hmac.update(header);
+ int len = messageBytes.length - header.length;
+ if (log.isTraceEnabled()) {
+ log.trace(hexdump.dump("TSIG-HMAC message after header", messageBytes, header.length, len));
+ }
+
+ hmac.update(messageBytes, header.length, len);
+ m.tsigState = Message.TSIG_INTERMEDIATE;
+ }
}
}
diff --git a/src/main/java/org/xbill/DNS/ZoneTransferIn.java b/src/main/java/org/xbill/DNS/ZoneTransferIn.java
index 177460ff..91fd73c8 100644
--- a/src/main/java/org/xbill/DNS/ZoneTransferIn.java
+++ b/src/main/java/org/xbill/DNS/ZoneTransferIn.java
@@ -49,17 +49,17 @@ public class ZoneTransferIn {
private static final int AXFR = 6;
private static final int END = 7;
- private Name zname;
+ private final Name zname;
private int qtype;
private int dclass;
- private long ixfr_serial;
- private boolean want_fallback;
+ private final long ixfr_serial;
+ private final boolean want_fallback;
private ZoneTransferHandler handler;
private SocketAddress localAddress;
- private SocketAddress address;
+ private final SocketAddress address;
private TCPClient client;
- private TSIG tsig;
+ private final TSIG tsig;
private TSIG.StreamVerifier verifier;
private Duration timeout = Duration.ofMinutes(15);
@@ -155,7 +155,7 @@ public void startIXFRAdds(Record soa) {
public void handleRecord(Record r) {
if (ixfr != null) {
Delta delta = ixfr.get(ixfr.size() - 1);
- if (delta.adds.size() > 0) {
+ if (!delta.adds.isEmpty()) {
delta.adds.add(r);
} else {
delta.deletes.add(r);
@@ -166,9 +166,7 @@ public void handleRecord(Record r) {
}
}
- private ZoneTransferIn() {}
-
- private ZoneTransferIn(
+ ZoneTransferIn(
Name zone, int xfrtype, long serial, boolean fallback, SocketAddress address, TSIG key) {
this.address = address;
this.tsig = key;
@@ -330,13 +328,17 @@ public void setLocalAddress(SocketAddress addr) {
}
private void openConnection() throws IOException {
- client = new TCPClient(timeout);
+ client = createTcpClient(timeout);
if (localAddress != null) {
client.bind(localAddress);
}
client.connect(address);
}
+ TCPClient createTcpClient(Duration timeout) throws IOException {
+ return new TCPClient(timeout);
+ }
+
private void sendQuery() throws IOException {
Record question = Record.newRecord(zname, qtype, dclass);
@@ -477,9 +479,10 @@ private void parseRR(Record rec) throws ZoneTransferException {
private void closeConnection() {
try {
if (client != null) {
- client.cleanup();
+ client.close();
}
} catch (IOException e) {
+ // Ignore
}
}
@@ -490,7 +493,7 @@ private Message parseMessage(byte[] b) throws WireParseException {
if (e instanceof WireParseException) {
throw (WireParseException) e;
}
- throw new WireParseException("Error parsing message");
+ throw new WireParseException("Error parsing message", e);
}
}
@@ -499,14 +502,24 @@ private void doxfr() throws IOException, ZoneTransferException {
while (state != END) {
byte[] in = client.recv();
Message response = parseMessage(in);
+ List answers = response.getSection(Section.ANSWER);
if (response.getHeader().getRcode() == Rcode.NOERROR && verifier != null) {
- int error = verifier.verify(response, in);
+ int error =
+ verifier.verify(response, in, answers.get(answers.size() - 1).getType() == Type.SOA);
if (error != Rcode.NOERROR) {
- fail("TSIG failure: " + Rcode.TSIGstring(error));
+ if (verifier.getErrorMessage() != null) {
+ fail(
+ "TSIG failure: "
+ + Rcode.TSIGstring(error)
+ + " ("
+ + verifier.getErrorMessage()
+ + ")");
+ } else {
+ fail("TSIG failure: " + Rcode.TSIGstring(error));
+ }
}
}
- List answers = response.getSection(Section.ANSWER);
if (state == INITIALSOA) {
int rcode = response.getRcode();
if (rcode != Rcode.NOERROR) {
@@ -533,10 +546,6 @@ private void doxfr() throws IOException, ZoneTransferException {
for (Record answer : answers) {
parseRR(answer);
}
-
- if (state == END && verifier != null && !response.isVerified()) {
- fail("last message must be signed");
- }
}
}
diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java
index d4e52b0e..66da6dd9 100644
--- a/src/test/java/org/xbill/DNS/TSIGTest.java
+++ b/src/test/java/org/xbill/DNS/TSIGTest.java
@@ -11,33 +11,49 @@
import static org.mockito.ArgumentMatchers.anyInt;
import java.io.IOException;
+import java.lang.reflect.Field;
import java.net.InetAddress;
import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
+import java.time.ZoneId;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
+import javax.crypto.spec.SecretKeySpec;
+import lombok.Getter;
+import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
+import org.xbill.DNS.TSIG.StreamGenerator;
import org.xbill.DNS.utils.base64;
class TSIGTest {
+ private final TSIG defaultKey = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
+
@Test
void signedQuery() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Record question = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN);
Message query = Message.newQuery(question);
- query.setTSIG(key);
+ query.setTSIG(defaultKey);
byte[] qbytes = query.toWire(512);
assertEquals(1, qbytes[11]);
Message qparsed = new Message(qbytes);
- int result = key.verify(qparsed, qbytes, null);
+ int result = defaultKey.verify(qparsed, qbytes, null);
assertEquals(Rcode.NOERROR, result);
assertTrue(qparsed.isSigned());
assertTrue(qparsed.isVerified());
@@ -87,12 +103,10 @@ void queryStringAlgError() {
@Test
void queryIsLastAddMessageRecord() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Record rec = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN);
OPTRecord opt = new OPTRecord(SimpleResolver.DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0);
Message msg = Message.newQuery(rec);
- msg.setTSIG(key);
+ msg.setTSIG(defaultKey);
msg.addRecord(opt, Section.ADDITIONAL);
byte[] bytes = msg.toWire(512);
assertEquals(2, bytes[11]); // additional RR count, lower byte
@@ -101,7 +115,7 @@ void queryIsLastAddMessageRecord() throws IOException {
List additionalSection = parsed.getSection(Section.ADDITIONAL);
assertEquals(Type.string(Type.OPT), Type.string(additionalSection.get(0).getType()));
assertEquals(Type.string(Type.TSIG), Type.string(additionalSection.get(1).getType()));
- int result = key.verify(parsed, bytes, null);
+ int result = defaultKey.verify(parsed, bytes, null);
assertEquals(Rcode.NOERROR, result);
assertTrue(parsed.isSigned());
assertTrue(parsed.isVerified());
@@ -116,8 +130,7 @@ void queryAndTsigApplyMisbehaves() throws IOException {
assertFalse(msg.isSigned());
assertFalse(msg.isVerified());
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
- key.apply(msg, null); // additional RR count, lower byte
+ defaultKey.apply(msg, null); // additional RR count, lower byte
byte[] bytes = msg.toWire(Message.MAXLENGTH);
assertThrows(WireParseException.class, () -> new Message(bytes), "Expected TSIG error");
@@ -125,7 +138,6 @@ void queryAndTsigApplyMisbehaves() throws IOException {
@Test
void tsigInQueryIsLastViaResolver() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
SimpleResolver res =
new SimpleResolver("127.0.0.1") {
@Override
@@ -140,7 +152,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e
}
}
};
- res.setTSIGKey(key);
+ res.setTSIGKey(defaultKey);
Name qname = Name.fromString("www.example.com.");
Record question = Record.newRecord(qname, Type.A, DClass.IN);
@@ -150,7 +162,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e
List additionalSection = response.getSection(Section.ADDITIONAL);
assertEquals(Type.string(Type.OPT), Type.string(additionalSection.get(0).getType()));
assertEquals(Type.string(Type.TSIG), Type.string(additionalSection.get(1).getType()));
- int result = key.verify(response, response.toWire(), null);
+ int result = defaultKey.verify(response, response.toWire(), null);
assertEquals(Rcode.NOERROR, result);
assertTrue(response.isSigned());
assertTrue(response.isVerified());
@@ -158,14 +170,12 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e
@Test
void unsignedQuerySignedResponse() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Name qname = Name.fromString("www.example.");
Record question = Record.newRecord(qname, Type.A, DClass.IN);
Message query = Message.newQuery(question);
Message response = new Message(query.getHeader().getID());
- response.setTSIG(key, Rcode.NOERROR, null);
+ response.setTSIG(defaultKey, Rcode.NOERROR, null);
response.getHeader().setFlag(Flags.QR);
response.addRecord(question, Section.QUESTION);
Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null);
@@ -173,7 +183,7 @@ void unsignedQuerySignedResponse() throws IOException {
byte[] rbytes = response.toWire(Message.MAXLENGTH);
Message rparsed = new Message(rbytes);
- int result = key.verify(rparsed, rbytes, null);
+ int result = defaultKey.verify(rparsed, rbytes, null);
assertEquals(Rcode.NOERROR, result);
assertTrue(rparsed.isSigned());
assertTrue(rparsed.isVerified());
@@ -181,19 +191,17 @@ void unsignedQuerySignedResponse() throws IOException {
@Test
void signedQuerySignedResponse() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Name qname = Name.fromString("www.example.");
Record question = Record.newRecord(qname, Type.A, DClass.IN);
Message query = Message.newQuery(question);
- query.setTSIG(key);
+ query.setTSIG(defaultKey);
byte[] qbytes = query.toWire(Message.MAXLENGTH);
Message qparsed = new Message(qbytes);
assertNotNull(query.getGeneratedTSIG());
assertEquals(query.getGeneratedTSIG(), qparsed.getTSIG());
Message response = new Message(query.getHeader().getID());
- response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG());
+ response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG());
response.getHeader().setFlag(Flags.QR);
response.addRecord(question, Section.QUESTION);
Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null);
@@ -201,7 +209,7 @@ void signedQuerySignedResponse() throws IOException {
byte[] rbytes = response.toWire(Message.MAXLENGTH);
Message rparsed = new Message(rbytes);
- int result = key.verify(rparsed, rbytes, query.getGeneratedTSIG());
+ int result = defaultKey.verify(rparsed, rbytes, query.getGeneratedTSIG());
assertEquals(Rcode.NOERROR, result);
assertTrue(rparsed.isSigned());
assertTrue(rparsed.isVerified());
@@ -209,8 +217,6 @@ void signedQuerySignedResponse() throws IOException {
@Test
void signedQuerySignedResponseViaResolver() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Name qname = Name.fromString("www.example.");
Record question = Record.newRecord(qname, Type.A, DClass.IN);
Message query = Message.newQuery(question);
@@ -231,7 +237,7 @@ void signedQuerySignedResponseViaResolver() throws IOException {
Message qparsed = new Message(a.getArgument(3, byte[].class));
Message response = new Message(qparsed.getHeader().getID());
- response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG());
+ response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG());
response.getHeader().setFlag(Flags.QR);
response.addRecord(question, Section.QUESTION);
Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null);
@@ -243,7 +249,7 @@ void signedQuerySignedResponseViaResolver() throws IOException {
return f;
});
SimpleResolver res = new SimpleResolver("127.0.0.1");
- res.setTSIGKey(key);
+ res.setTSIGKey(defaultKey);
Message responseFromResolver = res.send(query);
assertTrue(responseFromResolver.isSigned());
@@ -253,17 +259,15 @@ void signedQuerySignedResponseViaResolver() throws IOException {
@Test
void truncated() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
-
Name qname = Name.fromString("www.example.");
Record question = Record.newRecord(qname, Type.A, DClass.IN);
Message query = Message.newQuery(question);
- query.setTSIG(key, Rcode.NOERROR, null);
+ query.setTSIG(defaultKey, Rcode.NOERROR, null);
byte[] qbytes = query.toWire(512);
Message qparsed = new Message(qbytes);
Message response = new Message(query.getHeader().getID());
- response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG());
+ response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG());
response.getHeader().setFlag(Flags.QR);
response.addRecord(question, Section.QUESTION);
for (int i = 0; i < 40; i++) {
@@ -274,7 +278,7 @@ void truncated() throws IOException {
Message rparsed = new Message(rbytes);
assertTrue(rparsed.getHeader().getFlag(Flags.TC));
- int result = key.verify(rparsed, rbytes, qparsed.getTSIG());
+ int result = defaultKey.verify(rparsed, rbytes, qparsed.getTSIG());
assertEquals(Rcode.NOERROR, result);
assertTrue(rparsed.isSigned());
assertTrue(rparsed.isVerified());
@@ -291,7 +295,6 @@ void rdataFromString() {
@Test
void testTSIGMessageClone() throws IOException {
- TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678");
TSIGRecord old =
new TSIGRecord(
Name.fromConstantString("example."),
@@ -312,7 +315,7 @@ void testTSIGMessageClone() throws IOException {
response.addRecord(question, Section.QUESTION);
response.addRecord(
new ARecord(qname, DClass.IN, 0, InetAddress.getByName("127.0.0.1")), Section.ANSWER);
- response.setTSIG(key, Rcode.NOERROR, old);
+ response.setTSIG(defaultKey, Rcode.NOERROR, old);
byte[] responseBytes = response.toWire(Message.MAXLENGTH);
assertNotNull(responseBytes);
assertNotEquals(0, responseBytes.length);
@@ -325,4 +328,326 @@ void testTSIGMessageClone() throws IOException {
assertNotNull(cloneBytes);
assertNotEquals(0, cloneBytes.length);
}
+
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 0, 101})
+ void testStreamGeneratorNthMessageArgument(int signEvery) {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> new TSIG.StreamGenerator(defaultKey, null, signEvery));
+ }
+
+ @Test
+ void testTSIGStreamVerifierMissingMinimumTsig() throws Exception {
+ MockMessageClient client = new MockMessageClient(defaultKey);
+ int numResponses = 200;
+ byte[] query = client.createQuery();
+ List response;
+ try (MockMessageServer server = new MockMessageServer(defaultKey, numResponses, 200, false)) {
+ server.send(query);
+ response = server.getMessages();
+ }
+ Map expectedRcodes = new HashMap<>();
+ for (int i = 0; i < numResponses; i++) {
+ expectedRcodes.put(i, i < 100 ? Rcode.NOERROR : Rcode.FORMERR);
+ }
+ expectedRcodes.put(numResponses - 1, Rcode.BADSIG);
+ client.validateResponse(query, response, expectedRcodes, false);
+ }
+
+ @ParameterizedTest(name = "testTSIGStreamVerifier(numResponses: {0}, signEvery: {1})")
+ @CsvSource({
+ "20,1",
+ "53,6",
+ "105,7",
+ "1000,100",
+ })
+ void testTSIGStreamVerifier(int numResponses, int signEvery) throws Exception {
+ MockMessageClient client = new MockMessageClient(defaultKey);
+ byte[] query = client.createQuery();
+ List response;
+ try (MockMessageServer server =
+ new MockMessageServer(defaultKey, numResponses, signEvery, false)) {
+
+ server.send(query);
+ response = server.getMessages();
+ }
+ Map expectedRcodes = new HashMap<>();
+ for (int i = 0; i < numResponses; i++) {
+ expectedRcodes.put(i, Rcode.NOERROR);
+ }
+ client.validateResponse(query, response, expectedRcodes, true);
+ }
+
+ @ParameterizedTest(name = "testTSIGStreamVerifierLastMessage(numResponses: {0}, signEvery: {1})")
+ @CsvSource({
+ "53,6",
+ "105,7",
+ "1000,100",
+ })
+ void testTSIGStreamVerifierLastMessage(int numResponses, int signEvery) throws Exception {
+ MockMessageClient client = new MockMessageClient(defaultKey);
+ byte[] query = client.createQuery();
+ List response;
+ try (MockMessageServer server =
+ new MockMessageServer(defaultKey, numResponses, signEvery, true)) {
+
+ server.send(query);
+ response = server.getMessages();
+ }
+ Map expectedRcodes = new HashMap<>();
+ for (int i = 0; i < numResponses; i++) {
+ expectedRcodes.put(i, Rcode.NOERROR);
+ }
+
+ expectedRcodes.put(numResponses - 1, Rcode.FORMERR);
+ client.validateResponse(query, response, expectedRcodes, false);
+ }
+
+ @Test
+ void testFromTcpStream() throws IOException {
+ DNSInput request = new DNSInput(IOUtils.resourceToByteArray("/tsig-axfr/request.bin"));
+ byte[] queryBytes = request.readByteArray(request.readU16());
+ Message query = new Message(queryBytes);
+ assertNotNull(query.getTSIG());
+ TSIG key =
+ new TSIG(
+ TSIG.HMAC_SHA256,
+ Name.fromConstantString("dnssecishardtest."),
+ new SecretKeySpec(
+ Objects.requireNonNull(
+ base64.fromString("q4Gsu0nYoyub20//PATXhABobmrVUQyqq5TFzYHfC7o=")),
+ "HmacSHA256"),
+ Clock.fixed(Instant.parse("2023-11-01T20:52:08Z"), ZoneId.of("UTC")));
+
+ TSIG.StreamVerifier verifier = new TSIG.StreamVerifier(key, query.getTSIG());
+ DNSInput response = new DNSInput(IOUtils.resourceToByteArray("/tsig-axfr/response.bin"));
+
+ // Use a list, not a map, to keep the message order intact
+ List> messages = new ArrayList<>();
+ while (response.remaining() > 0) {
+ byte[] messageBytes = response.readByteArray(response.readU16());
+ Message message = new Message(messageBytes);
+ messages.add(new AbstractMap.SimpleEntry<>(message, messageBytes));
+ }
+
+ for (int i = 0; i < messages.size(); i++) {
+ Map.Entry e = messages.get(i);
+ assertEquals(
+ Rcode.NOERROR, verifier.verify(e.getKey(), e.getValue(), i == messages.size() - 1));
+ }
+ }
+
+ @Test
+ void testAxfrLastNotSignedError() throws Exception {
+ Name name = Name.fromConstantString("example.com.");
+ ZoneTransferIn client =
+ new ZoneTransferIn(
+ name,
+ Type.AXFR,
+ 0,
+ false,
+ new InetSocketAddress(InetAddress.getLocalHost(), 53),
+ defaultKey) {
+ @Override
+ TCPClient createTcpClient(Duration timeout) throws IOException {
+ return new MockMessageServer(defaultKey, 200, 20, true);
+ }
+ };
+
+ ZoneTransferException exception =
+ assertThrows(ZoneTransferException.class, () -> client.run(new ZoneBuilderAxfrHandler()));
+ assertTrue(exception.getMessage().contains(Rcode.TSIGstring(Rcode.FORMERR)));
+ assertTrue(exception.getMessage().contains("last"));
+ }
+
+ @Test
+ void testAxfr() throws Exception {
+ Name name = Name.fromConstantString("example.com.");
+ ZoneTransferIn client =
+ new ZoneTransferIn(
+ name,
+ Type.AXFR,
+ 0,
+ false,
+ new InetSocketAddress(InetAddress.getLocalHost(), 53),
+ defaultKey) {
+ @Override
+ TCPClient createTcpClient(Duration timeout) throws IOException {
+ return new MockMessageServer(defaultKey, 200, 20, false);
+ }
+ };
+
+ ZoneBuilderAxfrHandler handler = new ZoneBuilderAxfrHandler();
+ client.run(handler);
+ // soa on first message, + a record on every message, +soa on last message
+ assertEquals(202, handler.getRecords().size());
+ }
+
+ @Getter
+ private static class ZoneBuilderAxfrHandler implements ZoneTransferIn.ZoneTransferHandler {
+ private final List records = new ArrayList<>();
+
+ @Override
+ public void startAXFR() {}
+
+ @Override
+ public void startIXFR() {}
+
+ @Override
+ public void startIXFRDeletes(Record soa) {}
+
+ @Override
+ public void startIXFRAdds(Record soa) {}
+
+ @Override
+ public void handleRecord(Record r) {
+ records.add(r);
+ }
+ }
+
+ private static class MockMessageClient {
+ private final TSIG key;
+
+ MockMessageClient(TSIG key) {
+ this.key = key;
+ }
+
+ byte[] createQuery() throws TextParseException {
+ Name qname = Name.fromString("www.example.");
+ Record question = Record.newRecord(qname, Type.A, DClass.IN);
+ Message query = Message.newQuery(question);
+ query.setTSIG(key);
+ return query.toWire(Message.MAXLENGTH);
+ }
+
+ public void validateResponse(
+ byte[] query,
+ List responses,
+ Map expectedRcodes,
+ boolean lastResponseSignedState)
+ throws IOException {
+ Message queryMessage = new Message(query);
+ TSIG.StreamVerifier verifier = new TSIG.StreamVerifier(key, queryMessage.getTSIG());
+
+ Map actualRcodes = new HashMap<>();
+ for (int i = 0; i < responses.size(); i++) {
+ boolean isLastMessage = i == responses.size() - 1;
+ byte[] renderedMessage = responses.get(i).toWire(Message.MAXLENGTH);
+ Message messageFromWire = new Message(renderedMessage);
+ actualRcodes.put(i, verifier.verify(messageFromWire, renderedMessage, isLastMessage));
+ if (isLastMessage) {
+ assertEquals(messageFromWire.isVerified(), lastResponseSignedState);
+ }
+ }
+
+ assertEquals(expectedRcodes, actualRcodes);
+ }
+ }
+
+ private static class MockMessageServer extends TCPClient {
+ private final TSIG key;
+ private final int responseMessageCount;
+ private final int signEvery;
+ private final boolean skipLast;
+ @Getter private List messages;
+ private int recvCalls;
+
+ MockMessageServer(TSIG key, int responseMessageCount, int signEvery, boolean skipLast)
+ throws IOException {
+ super(Duration.ZERO);
+ this.key = key;
+ this.responseMessageCount = responseMessageCount;
+ this.signEvery = signEvery;
+ this.skipLast = skipLast;
+ }
+
+ @Override
+ void bind(SocketAddress addr) {
+ // do nothing
+ }
+
+ @Override
+ void connect(SocketAddress addr) {
+ // do nothing
+ }
+
+ @Override
+ public void close() {
+ // do nothing
+ }
+
+ @Override
+ void send(byte[] queryMessageBytes) throws IOException {
+ Message parsedQueryMessage = new Message(queryMessageBytes);
+ assertNotNull(parsedQueryMessage.getTSIG());
+
+ messages = new LinkedList<>();
+ StreamGenerator generator;
+ try {
+ generator = getStreamGenerator(signEvery, parsedQueryMessage);
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new IOException(e);
+ }
+
+ Name queryName = parsedQueryMessage.getQuestion().getName();
+ Record soa =
+ new SOARecord(
+ queryName,
+ DClass.IN,
+ 300,
+ new Name("ns1", queryName),
+ new Name("admin", queryName),
+ 1,
+ 3600,
+ 1,
+ 3600,
+ 1800);
+ for (int i = 0; i < responseMessageCount; i++) {
+ Message response = new Message(parsedQueryMessage.getHeader().getID());
+ response.getHeader().setFlag(Flags.QR);
+ response.addRecord(parsedQueryMessage.getQuestion(), Section.QUESTION);
+ if (i == 0) {
+ response.addRecord(soa, Section.ANSWER);
+ }
+ Record answer =
+ new ARecord(
+ parsedQueryMessage.getQuestion().getName(),
+ DClass.IN,
+ 300,
+ InetAddress.getByAddress(ByteBuffer.allocate(4).putInt(i).array()));
+ response.addRecord(answer, Section.ANSWER);
+
+ if (i == responseMessageCount - 1) {
+ response.addRecord(soa, Section.ANSWER);
+ }
+
+ generator.generate(response, !skipLast && i == responseMessageCount - 1);
+ messages.add(response);
+ }
+ }
+
+ @Override
+ byte[] recv() {
+ return messages.get(recvCalls++).toWire(Message.MAXLENGTH);
+ }
+
+ private StreamGenerator getStreamGenerator(int signEvery, Message parsedQueryMessage)
+ throws NoSuchFieldException, IllegalAccessException {
+ TSIGRecord queryMessageTSIG = parsedQueryMessage.getTSIG();
+ StreamGenerator generator;
+
+ // Hack for testing invalid server responses, the constructor would normally prevent such an
+ // invalid argument
+ if (signEvery > 100) {
+ generator = new StreamGenerator(key, queryMessageTSIG, 1);
+ Field signEveryNthMessage = StreamGenerator.class.getDeclaredField("signEveryNthMessage");
+ signEveryNthMessage.setAccessible(true);
+ signEveryNthMessage.set(generator, signEvery);
+ } else {
+ generator = new StreamGenerator(key, queryMessageTSIG, signEvery);
+ }
+ return generator;
+ }
+ }
}
diff --git a/src/test/resources/tsig-axfr/request.bin b/src/test/resources/tsig-axfr/request.bin
new file mode 100644
index 00000000..9e110f7e
Binary files /dev/null and b/src/test/resources/tsig-axfr/request.bin differ
diff --git a/src/test/resources/tsig-axfr/response-messages.txt b/src/test/resources/tsig-axfr/response-messages.txt
new file mode 100644
index 00000000..492d3d2a
--- /dev/null
+++ b/src/test/resources/tsig-axfr/response-messages.txt
@@ -0,0 +1 @@
+dig @dnssec1.stage.arin.net -p 53 AXFR dnssecishard.com. -y 'hmac-sha256:dnssecishardtest:q4Gsu0nYoyub20//PATXhABobmrVUQyqq5TFzYHfC7o='
diff --git a/src/test/resources/tsig-axfr/response.bin b/src/test/resources/tsig-axfr/response.bin
new file mode 100644
index 00000000..f861f8b1
Binary files /dev/null and b/src/test/resources/tsig-axfr/response.bin differ