diff --git a/.gitattributes b/.gitattributes index fcadb2cf..3fef5261 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ * text eol=lf +*.bin binary diff --git a/pom.xml b/pom.xml index bcf6c9f1..62fcc432 100644 --- a/pom.xml +++ b/pom.xml @@ -433,6 +433,12 @@ ${vertx.version} test + + commons-io + commons-io + 2.15.0 + test + diff --git a/src/main/java/org/xbill/DNS/TCPClient.java b/src/main/java/org/xbill/DNS/TCPClient.java index 7f65f439..6151e1b9 100644 --- a/src/main/java/org/xbill/DNS/TCPClient.java +++ b/src/main/java/org/xbill/DNS/TCPClient.java @@ -14,7 +14,7 @@ import java.time.Duration; import java.time.temporal.ChronoUnit; -final class TCPClient { +class TCPClient implements AutoCloseable { private final long startTime; private final Duration timeout; private final SelectionKey key; @@ -144,7 +144,7 @@ private void blockUntil(SelectionKey key) throws IOException { } } - void cleanup() throws IOException { + public void close() throws IOException { key.selector().close(); key.channel().close(); } diff --git a/src/main/java/org/xbill/DNS/TSIG.java b/src/main/java/org/xbill/DNS/TSIG.java index aabd364b..29ecbea1 100644 --- a/src/main/java/org/xbill/DNS/TSIG.java +++ b/src/main/java/org/xbill/DNS/TSIG.java @@ -15,6 +15,7 @@ import javax.crypto.Mac; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.xbill.DNS.utils.base64; import org.xbill.DNS.utils.hexdump; @@ -369,28 +370,35 @@ public TSIGRecord generate(Message m, byte[] b, int error, TSIGRecord old) { */ public TSIGRecord generate( Message m, byte[] b, int error, TSIGRecord old, boolean fullSignature) { - Instant timeSigned; - if (error == Rcode.BADTIME) { - timeSigned = old.getTimeSigned(); - } else { - timeSigned = clock.instant(); - } - - boolean signing = false; Mac hmac = null; if (error == Rcode.NOERROR || error == Rcode.BADTIME || error == Rcode.BADTRUNC) { - signing = true; hmac = initHmac(); } - Duration fudge; - int fudgeOption = Options.intValue("tsigfudge"); - if (fudgeOption < 0 || fudgeOption > 0x7FFF) { - fudge = FUDGE; - } else { - fudge = Duration.ofSeconds(fudgeOption); - } + return generate(m, b, error, old, fullSignature, hmac); + } + /** + * Generates a TSIG record with a specific error for a message that has been rendered. + * + * @param m The message + * @param b The rendered message + * @param error The error + * @param old If this message is a response, the TSIG from the request + * @param fullSignature {@code true} if this {@link TSIGRecord} is the to be added to the first of + * many messages in a TCP connection and all TSIG variables (rfc2845, 3.4.2.) should be + * included in the signature. {@code false} for subsequent messages with reduced TSIG + * variables set (rfc2845, 4.4.). + * @param hmac A mac instance to reuse for a stream of messages to sign, e.g. when doing a zone + * transfer. + * @return The TSIG record to be added to the message + */ + private TSIGRecord generate( + Message m, byte[] b, int error, TSIGRecord old, boolean fullSignature, Mac hmac) { + Instant timeSigned = getTimeSigned(error, old); + Duration fudge = getTsigFudge(); + + boolean signing = hmac != null; if (old != null && signing) { hmacAddSignature(hmac, old); } @@ -413,7 +421,7 @@ public TSIGRecord generate( alg.toWireCanonical(out); } - writeTsigTimersVariables(timeSigned, fudge, out); + writeTsigTimerVariables(timeSigned, fudge, out); if (fullSignature) { out.writeU16(error); out.writeU16(0); /* No other data */ @@ -450,6 +458,15 @@ public TSIGRecord generate( other); } + private Instant getTimeSigned(int error, TSIGRecord old) { + return error == Rcode.BADTIME ? old.getTimeSigned() : clock.instant(); + } + + private static Duration getTsigFudge() { + int fudgeOption = Options.intValue("tsigfudge"); + return fudgeOption < 0 || fudgeOption > 0x7FFF ? FUDGE : Duration.ofSeconds(fudgeOption); + } + /** * Generates a TSIG record for a message and adds it to the message * @@ -522,6 +539,8 @@ public void applyStream(Message m, TSIGRecord old, boolean fullSignature) { * TSIG is expected to be present, it is an error if one is not present. After calling this * routine, Message.isVerified() may be called on this message. * + *

Use {@link StreamVerifier} to validate multiple messages in a stream. + * * @param m The message * @param b An array containing the message in unparsed form. This is necessary since TSIG signs * the message in wire format, and we can't recreate the exact wire format (with the same name @@ -542,6 +561,8 @@ public byte verify(Message m, byte[] b, int length, TSIGRecord old) { * TSIG is expected to be present, it is an error if one is not present. After calling this * routine, Message.isVerified() may be called on this message. * + *

Use {@link StreamVerifier} to validate multiple messages in a stream. + * * @param m The message to verify * @param messageBytes An array containing the message in unparsed form. This is necessary since * TSIG signs the message in wire format, and we can't recreate the exact wire format (with @@ -559,6 +580,8 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG) { * TSIG is expected to be present, it is an error if one is not present. After calling this * routine, Message.isVerified() may be called on this message. * + *

Use {@link StreamVerifier} to validate multiple messages in a stream. + * * @param m The message to verify * @param messageBytes An array containing the message in unparsed form. This is necessary since * TSIG signs the message in wire format, and we can't recreate the exact wire format (with @@ -572,6 +595,27 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG) { * @since 3.2 */ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolean fullSignature) { + return verify(m, messageBytes, requestTSIG, fullSignature, null); + } + + /** + * Verifies a TSIG record on an incoming message. Since this is only called in the context where a + * TSIG is expected to be present, it is an error if one is not present. After calling this + * routine, Message.isVerified() may be called on this message. + * + * @param m The message to verify + * @param messageBytes An array containing the message in unparsed form. This is necessary since + * TSIG signs the message in wire format, and we can't recreate the exact wire format (with + * the same name compression). + * @param requestTSIG If this message is a response, the TSIG from the request + * @param fullSignature {@code true} if this message is the first of many in a TCP connection and + * all TSIG variables (rfc2845, 3.4.2.) should be included in the signature. {@code false} for + * subsequent messages with reduced TSIG variables set (rfc2845, 4.4.). + * @return The result of the verification (as an Rcode) + * @see Rcode + */ + private int verify( + Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolean fullSignature, Mac hmac) { m.tsigState = Message.TSIG_FAILED; TSIGRecord tsig = m.getTSIG(); if (tsig == null) { @@ -589,7 +633,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea return Rcode.BADKEY; } - Mac hmac = initHmac(); + if (hmac == null) { + hmac = initHmac(); + } + if (requestTSIG != null && tsig.getError() != Rcode.BADKEY && tsig.getError() != Rcode.BADSIG) { hmacAddSignature(hmac, requestTSIG); } @@ -608,6 +655,27 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea } hmac.update(messageBytes, header.length, len); + byte[] tsigVariables = getTsigVariables(fullSignature, tsig); + hmac.update(tsigVariables); + + byte[] signature = tsig.getSignature(); + int badsig = verifySignature(hmac, signature); + if (badsig != Rcode.NOERROR) { + return badsig; + } + + // validate time after the signature, as per + // https://www.rfc-editor.org/rfc/rfc8945.html#section-5.4 + int badtime = verifyTime(tsig); + if (badtime != Rcode.NOERROR) { + return badtime; + } + + m.tsigState = Message.TSIG_VERIFIED; + return Rcode.NOERROR; + } + + private static byte[] getTsigVariables(boolean fullSignature, TSIGRecord tsig) { DNSOutput out = new DNSOutput(); if (fullSignature) { tsig.getName().toWireCanonical(out); @@ -615,7 +683,7 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea out.writeU32(tsig.ttl); tsig.getAlgorithm().toWireCanonical(out); } - writeTsigTimersVariables(tsig.getTimeSigned(), tsig.getFudge(), out); + writeTsigTimerVariables(tsig.getTimeSigned(), tsig.getFudge(), out); if (fullSignature) { out.writeU16(tsig.getError()); if (tsig.getOther() != null) { @@ -630,9 +698,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea if (log.isTraceEnabled()) { log.trace(hexdump.dump("TSIG-HMAC variables", tsigVariables)); } - hmac.update(tsigVariables); + return tsigVariables; + } - byte[] signature = tsig.getSignature(); + private static int verifySignature(Mac hmac, byte[] signature) { int digestLength = hmac.getMacLength(); // rfc4635#section-3.1, 4.: @@ -662,9 +731,10 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea return Rcode.BADSIG; } } + return Rcode.NOERROR; + } - // validate time after the signature, as per - // https://tools.ietf.org/html/draft-ietf-dnsop-rfc2845bis-08#section-5.4.3 + private int verifyTime(TSIGRecord tsig) { Instant now = clock.instant(); Duration delta = Duration.between(now, tsig.getTimeSigned()).abs(); if (delta.compareTo(tsig.getFudge()) > 0) { @@ -675,8 +745,6 @@ public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolea tsig.getFudge()); return Rcode.BADTIME; } - - m.tsigState = Message.TSIG_VERIFIED; return Rcode.NOERROR; } @@ -706,7 +774,7 @@ private static void hmacAddSignature(Mac hmac, TSIGRecord tsig) { hmac.update(tsig.getSignature()); } - private static void writeTsigTimersVariables(Instant instant, Duration fudge, DNSOutput out) { + private static void writeTsigTimerVariables(Instant instant, Duration fudge, DNSOutput out) { writeTsigTime(instant, out); out.writeU16((int) fudge.getSeconds()); } @@ -719,58 +787,198 @@ private static void writeTsigTime(Instant instant, DNSOutput out) { out.writeU32(timeLow); } + /** + * A utility class for generating signed message responses. + * + * @since 3.5.3 + */ + public static class StreamGenerator { + private final TSIG key; + private final Mac sharedHmac; + private final int signEveryNthMessage; + + private int numGenerated; + private TSIGRecord lastTsigRecord; + + /** + * Creates an instance to sign multiple message for use in a stream. + * + *

This class creates a {@link TSIGRecord} on every message to conform with RFC 8945, 5.3.1. + * + * @param key The TSIG key used to create the signature records. + * @param queryTsig The initial TSIG records, e.g. from a query to a server. + */ + public StreamGenerator(TSIG key, TSIGRecord queryTsig) { + // The TSIG MUST be included on all DNS messages in the response. + this(key, queryTsig, 1); + } + + /** + * This constructor is only for unit-testing {@link StreamVerifier} with responses where + * not every message is signed. + */ + StreamGenerator(TSIG key, TSIGRecord queryTsig, int signEveryNthMessage) { + if (signEveryNthMessage < 1 || signEveryNthMessage > 100) { + throw new IllegalArgumentException("signEveryNthMessage must be between 1 and 100"); + } + + this.key = key; + this.lastTsigRecord = queryTsig; + this.signEveryNthMessage = signEveryNthMessage; + sharedHmac = this.key.initHmac(); + } + + /** + * Generate TSIG a signature for use of the message in a stream. + * + * @param message The message to sign. + */ + public void generate(Message message) { + generate(message, true); + } + + void generate(Message message, boolean isLastMessage) { + boolean isNthMessage = numGenerated % signEveryNthMessage == 0; + boolean isFirstMessage = numGenerated == 0; + if (isFirstMessage || isNthMessage || isLastMessage) { + TSIGRecord r = + key.generate( + message, + message.toWire(), + Rcode.NOERROR, + isFirstMessage ? lastTsigRecord : null, + isFirstMessage, + sharedHmac); + message.addRecord(r, Section.ADDITIONAL); + message.tsigState = Message.TSIG_SIGNED; + lastTsigRecord = r; + hmacAddSignature(sharedHmac, r); + } else { + byte[] responseBytes = message.toWire(Message.MAXLENGTH); + sharedHmac.update(responseBytes); + } + + numGenerated++; + } + } + + /** A utility class for verifying multiple message responses. */ public static class StreamVerifier { - /** A helper class for verifying multiple message responses. */ private final TSIG key; + private final Mac sharedHmac; + private final TSIGRecord queryTsig; private int nresponses; private int lastsigned; - private TSIGRecord lastTSIG; + + /** {@code null} or the detailed error when validation failed due to a {@link Rcode#FORMERR}. */ + @Getter private String errorMessage; /** Creates an object to verify a multiple message response */ public StreamVerifier(TSIG tsig, TSIGRecord queryTsig) { key = tsig; + sharedHmac = key.initHmac(); nresponses = 0; - lastTSIG = queryTsig; + this.queryTsig = queryTsig; } /** * Verifies a TSIG record on an incoming message that is part of a multiple message response. * TSIG records must be present on the first and last messages, and at least every 100 records - * in between. After calling this routine, Message.isVerified() may be called on this message. + * in between. After calling this routine,{@link Message#isVerified()} may be called on this + * message. * - * @param m The message - * @param b The message in unparsed form + *

This overload assumes that the verified message is not the last one, which is required to + * have a {@link TSIGRecord}. Use {@link #verify(Message, byte[], boolean)} to explicitly + * specify the last message or check that the message is verified with {@link + * Message#isVerified()}. + * + * @param message The message + * @param messageBytes The message in unparsed form * @return The result of the verification (as an Rcode) * @see Rcode */ - public int verify(Message m, byte[] b) { - TSIGRecord tsig = m.getTSIG(); + public int verify(Message message, byte[] messageBytes) { + return verify(message, messageBytes, false); + } + /** + * Verifies a TSIG record on an incoming message that is part of a multiple message response. + * TSIG records must be present on the first and last messages, and at least every 100 records + * in between. After calling this routine, {@link Message#isVerified()} may be called on this + * message. + * + * @param message The message + * @param messageBytes The message in unparsed form + * @param isLastMessage If true, verifies that the {@link Message} has an {@link TSIGRecord}. + * @return The result of the verification (as an Rcode) + * @see Rcode + * @since 3.5.3 + */ + public int verify(Message message, byte[] messageBytes, boolean isLastMessage) { + final String warningPrefix = "FORMERR: {}"; + TSIGRecord tsig = message.getTSIG(); + + // https://datatracker.ietf.org/doc/html/rfc8945#section-5.3.1 + // [...] a client that receives DNS messages and verifies TSIG MUST accept up to 99 + // intermediary messages without a TSIG and MUST verify that both the first and last message + // contain a TSIG. nresponses++; if (nresponses == 1) { - int result = key.verify(m, b, lastTSIG); - lastTSIG = tsig; - return result; + if (tsig != null) { + int result = key.verify(message, messageBytes, queryTsig, true, sharedHmac); + hmacAddSignature(sharedHmac, tsig); + lastsigned = nresponses; + return result; + } else { + errorMessage = "missing required signature on first message"; + log.debug(warningPrefix, errorMessage); + message.tsigState = Message.TSIG_FAILED; + return Rcode.FORMERR; + } } if (tsig != null) { - int result = key.verify(m, b, lastTSIG, false); + int result = key.verify(message, messageBytes, null, false, sharedHmac); lastsigned = nresponses; - lastTSIG = tsig; + hmacAddSignature(sharedHmac, tsig); return result; } else { boolean required = nresponses - lastsigned >= 100; if (required) { - log.debug("FORMERR: missing required signature on {}th message", nresponses); - m.tsigState = Message.TSIG_FAILED; + errorMessage = "Missing required signature on message #" + nresponses; + log.debug(warningPrefix, errorMessage); + message.tsigState = Message.TSIG_FAILED; + return Rcode.FORMERR; + } else if (isLastMessage) { + errorMessage = "Missing required signature on last message"; + log.debug(warningPrefix, errorMessage); + message.tsigState = Message.TSIG_FAILED; return Rcode.FORMERR; } else { - log.trace("Intermediate message {} without signature", nresponses); - m.tsigState = Message.TSIG_INTERMEDIATE; + errorMessage = "Intermediate message #" + nresponses + " without signature"; + log.debug(warningPrefix, errorMessage); + addUnsignedMessageToMac(message, messageBytes, sharedHmac); return Rcode.NOERROR; } } } + + private void addUnsignedMessageToMac(Message m, byte[] messageBytes, Mac hmac) { + byte[] header = m.getHeader().toWire(); + if (log.isTraceEnabled()) { + log.trace(hexdump.dump("TSIG-HMAC header", header)); + } + + hmac.update(header); + int len = messageBytes.length - header.length; + if (log.isTraceEnabled()) { + log.trace(hexdump.dump("TSIG-HMAC message after header", messageBytes, header.length, len)); + } + + hmac.update(messageBytes, header.length, len); + m.tsigState = Message.TSIG_INTERMEDIATE; + } } } diff --git a/src/main/java/org/xbill/DNS/ZoneTransferIn.java b/src/main/java/org/xbill/DNS/ZoneTransferIn.java index 177460ff..91fd73c8 100644 --- a/src/main/java/org/xbill/DNS/ZoneTransferIn.java +++ b/src/main/java/org/xbill/DNS/ZoneTransferIn.java @@ -49,17 +49,17 @@ public class ZoneTransferIn { private static final int AXFR = 6; private static final int END = 7; - private Name zname; + private final Name zname; private int qtype; private int dclass; - private long ixfr_serial; - private boolean want_fallback; + private final long ixfr_serial; + private final boolean want_fallback; private ZoneTransferHandler handler; private SocketAddress localAddress; - private SocketAddress address; + private final SocketAddress address; private TCPClient client; - private TSIG tsig; + private final TSIG tsig; private TSIG.StreamVerifier verifier; private Duration timeout = Duration.ofMinutes(15); @@ -155,7 +155,7 @@ public void startIXFRAdds(Record soa) { public void handleRecord(Record r) { if (ixfr != null) { Delta delta = ixfr.get(ixfr.size() - 1); - if (delta.adds.size() > 0) { + if (!delta.adds.isEmpty()) { delta.adds.add(r); } else { delta.deletes.add(r); @@ -166,9 +166,7 @@ public void handleRecord(Record r) { } } - private ZoneTransferIn() {} - - private ZoneTransferIn( + ZoneTransferIn( Name zone, int xfrtype, long serial, boolean fallback, SocketAddress address, TSIG key) { this.address = address; this.tsig = key; @@ -330,13 +328,17 @@ public void setLocalAddress(SocketAddress addr) { } private void openConnection() throws IOException { - client = new TCPClient(timeout); + client = createTcpClient(timeout); if (localAddress != null) { client.bind(localAddress); } client.connect(address); } + TCPClient createTcpClient(Duration timeout) throws IOException { + return new TCPClient(timeout); + } + private void sendQuery() throws IOException { Record question = Record.newRecord(zname, qtype, dclass); @@ -477,9 +479,10 @@ private void parseRR(Record rec) throws ZoneTransferException { private void closeConnection() { try { if (client != null) { - client.cleanup(); + client.close(); } } catch (IOException e) { + // Ignore } } @@ -490,7 +493,7 @@ private Message parseMessage(byte[] b) throws WireParseException { if (e instanceof WireParseException) { throw (WireParseException) e; } - throw new WireParseException("Error parsing message"); + throw new WireParseException("Error parsing message", e); } } @@ -499,14 +502,24 @@ private void doxfr() throws IOException, ZoneTransferException { while (state != END) { byte[] in = client.recv(); Message response = parseMessage(in); + List answers = response.getSection(Section.ANSWER); if (response.getHeader().getRcode() == Rcode.NOERROR && verifier != null) { - int error = verifier.verify(response, in); + int error = + verifier.verify(response, in, answers.get(answers.size() - 1).getType() == Type.SOA); if (error != Rcode.NOERROR) { - fail("TSIG failure: " + Rcode.TSIGstring(error)); + if (verifier.getErrorMessage() != null) { + fail( + "TSIG failure: " + + Rcode.TSIGstring(error) + + " (" + + verifier.getErrorMessage() + + ")"); + } else { + fail("TSIG failure: " + Rcode.TSIGstring(error)); + } } } - List answers = response.getSection(Section.ANSWER); if (state == INITIALSOA) { int rcode = response.getRcode(); if (rcode != Rcode.NOERROR) { @@ -533,10 +546,6 @@ private void doxfr() throws IOException, ZoneTransferException { for (Record answer : answers) { parseRR(answer); } - - if (state == END && verifier != null && !response.isVerified()) { - fail("last message must be signed"); - } } } diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java index d4e52b0e..66da6dd9 100644 --- a/src/test/java/org/xbill/DNS/TSIGTest.java +++ b/src/test/java/org/xbill/DNS/TSIGTest.java @@ -11,33 +11,49 @@ import static org.mockito.ArgumentMatchers.anyInt; import java.io.IOException; +import java.lang.reflect.Field; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.time.ZoneId; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import javax.crypto.spec.SecretKeySpec; +import lombok.Getter; +import org.apache.commons.io.IOUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.xbill.DNS.TSIG.StreamGenerator; import org.xbill.DNS.utils.base64; class TSIGTest { + private final TSIG defaultKey = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); + @Test void signedQuery() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Record question = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN); Message query = Message.newQuery(question); - query.setTSIG(key); + query.setTSIG(defaultKey); byte[] qbytes = query.toWire(512); assertEquals(1, qbytes[11]); Message qparsed = new Message(qbytes); - int result = key.verify(qparsed, qbytes, null); + int result = defaultKey.verify(qparsed, qbytes, null); assertEquals(Rcode.NOERROR, result); assertTrue(qparsed.isSigned()); assertTrue(qparsed.isVerified()); @@ -87,12 +103,10 @@ void queryStringAlgError() { @Test void queryIsLastAddMessageRecord() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Record rec = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN); OPTRecord opt = new OPTRecord(SimpleResolver.DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0); Message msg = Message.newQuery(rec); - msg.setTSIG(key); + msg.setTSIG(defaultKey); msg.addRecord(opt, Section.ADDITIONAL); byte[] bytes = msg.toWire(512); assertEquals(2, bytes[11]); // additional RR count, lower byte @@ -101,7 +115,7 @@ void queryIsLastAddMessageRecord() throws IOException { List additionalSection = parsed.getSection(Section.ADDITIONAL); assertEquals(Type.string(Type.OPT), Type.string(additionalSection.get(0).getType())); assertEquals(Type.string(Type.TSIG), Type.string(additionalSection.get(1).getType())); - int result = key.verify(parsed, bytes, null); + int result = defaultKey.verify(parsed, bytes, null); assertEquals(Rcode.NOERROR, result); assertTrue(parsed.isSigned()); assertTrue(parsed.isVerified()); @@ -116,8 +130,7 @@ void queryAndTsigApplyMisbehaves() throws IOException { assertFalse(msg.isSigned()); assertFalse(msg.isVerified()); - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - key.apply(msg, null); // additional RR count, lower byte + defaultKey.apply(msg, null); // additional RR count, lower byte byte[] bytes = msg.toWire(Message.MAXLENGTH); assertThrows(WireParseException.class, () -> new Message(bytes), "Expected TSIG error"); @@ -125,7 +138,6 @@ void queryAndTsigApplyMisbehaves() throws IOException { @Test void tsigInQueryIsLastViaResolver() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); SimpleResolver res = new SimpleResolver("127.0.0.1") { @Override @@ -140,7 +152,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e } } }; - res.setTSIGKey(key); + res.setTSIGKey(defaultKey); Name qname = Name.fromString("www.example.com."); Record question = Record.newRecord(qname, Type.A, DClass.IN); @@ -150,7 +162,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e List additionalSection = response.getSection(Section.ADDITIONAL); assertEquals(Type.string(Type.OPT), Type.string(additionalSection.get(0).getType())); assertEquals(Type.string(Type.TSIG), Type.string(additionalSection.get(1).getType())); - int result = key.verify(response, response.toWire(), null); + int result = defaultKey.verify(response, response.toWire(), null); assertEquals(Rcode.NOERROR, result); assertTrue(response.isSigned()); assertTrue(response.isVerified()); @@ -158,14 +170,12 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e @Test void unsignedQuerySignedResponse() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); Message response = new Message(query.getHeader().getID()); - response.setTSIG(key, Rcode.NOERROR, null); + response.setTSIG(defaultKey, Rcode.NOERROR, null); response.getHeader().setFlag(Flags.QR); response.addRecord(question, Section.QUESTION); Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); @@ -173,7 +183,7 @@ void unsignedQuerySignedResponse() throws IOException { byte[] rbytes = response.toWire(Message.MAXLENGTH); Message rparsed = new Message(rbytes); - int result = key.verify(rparsed, rbytes, null); + int result = defaultKey.verify(rparsed, rbytes, null); assertEquals(Rcode.NOERROR, result); assertTrue(rparsed.isSigned()); assertTrue(rparsed.isVerified()); @@ -181,19 +191,17 @@ void unsignedQuerySignedResponse() throws IOException { @Test void signedQuerySignedResponse() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); - query.setTSIG(key); + query.setTSIG(defaultKey); byte[] qbytes = query.toWire(Message.MAXLENGTH); Message qparsed = new Message(qbytes); assertNotNull(query.getGeneratedTSIG()); assertEquals(query.getGeneratedTSIG(), qparsed.getTSIG()); Message response = new Message(query.getHeader().getID()); - response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG()); + response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG()); response.getHeader().setFlag(Flags.QR); response.addRecord(question, Section.QUESTION); Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); @@ -201,7 +209,7 @@ void signedQuerySignedResponse() throws IOException { byte[] rbytes = response.toWire(Message.MAXLENGTH); Message rparsed = new Message(rbytes); - int result = key.verify(rparsed, rbytes, query.getGeneratedTSIG()); + int result = defaultKey.verify(rparsed, rbytes, query.getGeneratedTSIG()); assertEquals(Rcode.NOERROR, result); assertTrue(rparsed.isSigned()); assertTrue(rparsed.isVerified()); @@ -209,8 +217,6 @@ void signedQuerySignedResponse() throws IOException { @Test void signedQuerySignedResponseViaResolver() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); @@ -231,7 +237,7 @@ void signedQuerySignedResponseViaResolver() throws IOException { Message qparsed = new Message(a.getArgument(3, byte[].class)); Message response = new Message(qparsed.getHeader().getID()); - response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG()); + response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG()); response.getHeader().setFlag(Flags.QR); response.addRecord(question, Section.QUESTION); Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); @@ -243,7 +249,7 @@ void signedQuerySignedResponseViaResolver() throws IOException { return f; }); SimpleResolver res = new SimpleResolver("127.0.0.1"); - res.setTSIGKey(key); + res.setTSIGKey(defaultKey); Message responseFromResolver = res.send(query); assertTrue(responseFromResolver.isSigned()); @@ -253,17 +259,15 @@ void signedQuerySignedResponseViaResolver() throws IOException { @Test void truncated() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); - query.setTSIG(key, Rcode.NOERROR, null); + query.setTSIG(defaultKey, Rcode.NOERROR, null); byte[] qbytes = query.toWire(512); Message qparsed = new Message(qbytes); Message response = new Message(query.getHeader().getID()); - response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG()); + response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG()); response.getHeader().setFlag(Flags.QR); response.addRecord(question, Section.QUESTION); for (int i = 0; i < 40; i++) { @@ -274,7 +278,7 @@ void truncated() throws IOException { Message rparsed = new Message(rbytes); assertTrue(rparsed.getHeader().getFlag(Flags.TC)); - int result = key.verify(rparsed, rbytes, qparsed.getTSIG()); + int result = defaultKey.verify(rparsed, rbytes, qparsed.getTSIG()); assertEquals(Rcode.NOERROR, result); assertTrue(rparsed.isSigned()); assertTrue(rparsed.isVerified()); @@ -291,7 +295,6 @@ void rdataFromString() { @Test void testTSIGMessageClone() throws IOException { - TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); TSIGRecord old = new TSIGRecord( Name.fromConstantString("example."), @@ -312,7 +315,7 @@ void testTSIGMessageClone() throws IOException { response.addRecord(question, Section.QUESTION); response.addRecord( new ARecord(qname, DClass.IN, 0, InetAddress.getByName("127.0.0.1")), Section.ANSWER); - response.setTSIG(key, Rcode.NOERROR, old); + response.setTSIG(defaultKey, Rcode.NOERROR, old); byte[] responseBytes = response.toWire(Message.MAXLENGTH); assertNotNull(responseBytes); assertNotEquals(0, responseBytes.length); @@ -325,4 +328,326 @@ void testTSIGMessageClone() throws IOException { assertNotNull(cloneBytes); assertNotEquals(0, cloneBytes.length); } + + @ParameterizedTest + @ValueSource(ints = {-1, 0, 101}) + void testStreamGeneratorNthMessageArgument(int signEvery) { + assertThrows( + IllegalArgumentException.class, + () -> new TSIG.StreamGenerator(defaultKey, null, signEvery)); + } + + @Test + void testTSIGStreamVerifierMissingMinimumTsig() throws Exception { + MockMessageClient client = new MockMessageClient(defaultKey); + int numResponses = 200; + byte[] query = client.createQuery(); + List response; + try (MockMessageServer server = new MockMessageServer(defaultKey, numResponses, 200, false)) { + server.send(query); + response = server.getMessages(); + } + Map expectedRcodes = new HashMap<>(); + for (int i = 0; i < numResponses; i++) { + expectedRcodes.put(i, i < 100 ? Rcode.NOERROR : Rcode.FORMERR); + } + expectedRcodes.put(numResponses - 1, Rcode.BADSIG); + client.validateResponse(query, response, expectedRcodes, false); + } + + @ParameterizedTest(name = "testTSIGStreamVerifier(numResponses: {0}, signEvery: {1})") + @CsvSource({ + "20,1", + "53,6", + "105,7", + "1000,100", + }) + void testTSIGStreamVerifier(int numResponses, int signEvery) throws Exception { + MockMessageClient client = new MockMessageClient(defaultKey); + byte[] query = client.createQuery(); + List response; + try (MockMessageServer server = + new MockMessageServer(defaultKey, numResponses, signEvery, false)) { + + server.send(query); + response = server.getMessages(); + } + Map expectedRcodes = new HashMap<>(); + for (int i = 0; i < numResponses; i++) { + expectedRcodes.put(i, Rcode.NOERROR); + } + client.validateResponse(query, response, expectedRcodes, true); + } + + @ParameterizedTest(name = "testTSIGStreamVerifierLastMessage(numResponses: {0}, signEvery: {1})") + @CsvSource({ + "53,6", + "105,7", + "1000,100", + }) + void testTSIGStreamVerifierLastMessage(int numResponses, int signEvery) throws Exception { + MockMessageClient client = new MockMessageClient(defaultKey); + byte[] query = client.createQuery(); + List response; + try (MockMessageServer server = + new MockMessageServer(defaultKey, numResponses, signEvery, true)) { + + server.send(query); + response = server.getMessages(); + } + Map expectedRcodes = new HashMap<>(); + for (int i = 0; i < numResponses; i++) { + expectedRcodes.put(i, Rcode.NOERROR); + } + + expectedRcodes.put(numResponses - 1, Rcode.FORMERR); + client.validateResponse(query, response, expectedRcodes, false); + } + + @Test + void testFromTcpStream() throws IOException { + DNSInput request = new DNSInput(IOUtils.resourceToByteArray("/tsig-axfr/request.bin")); + byte[] queryBytes = request.readByteArray(request.readU16()); + Message query = new Message(queryBytes); + assertNotNull(query.getTSIG()); + TSIG key = + new TSIG( + TSIG.HMAC_SHA256, + Name.fromConstantString("dnssecishardtest."), + new SecretKeySpec( + Objects.requireNonNull( + base64.fromString("q4Gsu0nYoyub20//PATXhABobmrVUQyqq5TFzYHfC7o=")), + "HmacSHA256"), + Clock.fixed(Instant.parse("2023-11-01T20:52:08Z"), ZoneId.of("UTC"))); + + TSIG.StreamVerifier verifier = new TSIG.StreamVerifier(key, query.getTSIG()); + DNSInput response = new DNSInput(IOUtils.resourceToByteArray("/tsig-axfr/response.bin")); + + // Use a list, not a map, to keep the message order intact + List> messages = new ArrayList<>(); + while (response.remaining() > 0) { + byte[] messageBytes = response.readByteArray(response.readU16()); + Message message = new Message(messageBytes); + messages.add(new AbstractMap.SimpleEntry<>(message, messageBytes)); + } + + for (int i = 0; i < messages.size(); i++) { + Map.Entry e = messages.get(i); + assertEquals( + Rcode.NOERROR, verifier.verify(e.getKey(), e.getValue(), i == messages.size() - 1)); + } + } + + @Test + void testAxfrLastNotSignedError() throws Exception { + Name name = Name.fromConstantString("example.com."); + ZoneTransferIn client = + new ZoneTransferIn( + name, + Type.AXFR, + 0, + false, + new InetSocketAddress(InetAddress.getLocalHost(), 53), + defaultKey) { + @Override + TCPClient createTcpClient(Duration timeout) throws IOException { + return new MockMessageServer(defaultKey, 200, 20, true); + } + }; + + ZoneTransferException exception = + assertThrows(ZoneTransferException.class, () -> client.run(new ZoneBuilderAxfrHandler())); + assertTrue(exception.getMessage().contains(Rcode.TSIGstring(Rcode.FORMERR))); + assertTrue(exception.getMessage().contains("last")); + } + + @Test + void testAxfr() throws Exception { + Name name = Name.fromConstantString("example.com."); + ZoneTransferIn client = + new ZoneTransferIn( + name, + Type.AXFR, + 0, + false, + new InetSocketAddress(InetAddress.getLocalHost(), 53), + defaultKey) { + @Override + TCPClient createTcpClient(Duration timeout) throws IOException { + return new MockMessageServer(defaultKey, 200, 20, false); + } + }; + + ZoneBuilderAxfrHandler handler = new ZoneBuilderAxfrHandler(); + client.run(handler); + // soa on first message, + a record on every message, +soa on last message + assertEquals(202, handler.getRecords().size()); + } + + @Getter + private static class ZoneBuilderAxfrHandler implements ZoneTransferIn.ZoneTransferHandler { + private final List records = new ArrayList<>(); + + @Override + public void startAXFR() {} + + @Override + public void startIXFR() {} + + @Override + public void startIXFRDeletes(Record soa) {} + + @Override + public void startIXFRAdds(Record soa) {} + + @Override + public void handleRecord(Record r) { + records.add(r); + } + } + + private static class MockMessageClient { + private final TSIG key; + + MockMessageClient(TSIG key) { + this.key = key; + } + + byte[] createQuery() throws TextParseException { + Name qname = Name.fromString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message query = Message.newQuery(question); + query.setTSIG(key); + return query.toWire(Message.MAXLENGTH); + } + + public void validateResponse( + byte[] query, + List responses, + Map expectedRcodes, + boolean lastResponseSignedState) + throws IOException { + Message queryMessage = new Message(query); + TSIG.StreamVerifier verifier = new TSIG.StreamVerifier(key, queryMessage.getTSIG()); + + Map actualRcodes = new HashMap<>(); + for (int i = 0; i < responses.size(); i++) { + boolean isLastMessage = i == responses.size() - 1; + byte[] renderedMessage = responses.get(i).toWire(Message.MAXLENGTH); + Message messageFromWire = new Message(renderedMessage); + actualRcodes.put(i, verifier.verify(messageFromWire, renderedMessage, isLastMessage)); + if (isLastMessage) { + assertEquals(messageFromWire.isVerified(), lastResponseSignedState); + } + } + + assertEquals(expectedRcodes, actualRcodes); + } + } + + private static class MockMessageServer extends TCPClient { + private final TSIG key; + private final int responseMessageCount; + private final int signEvery; + private final boolean skipLast; + @Getter private List messages; + private int recvCalls; + + MockMessageServer(TSIG key, int responseMessageCount, int signEvery, boolean skipLast) + throws IOException { + super(Duration.ZERO); + this.key = key; + this.responseMessageCount = responseMessageCount; + this.signEvery = signEvery; + this.skipLast = skipLast; + } + + @Override + void bind(SocketAddress addr) { + // do nothing + } + + @Override + void connect(SocketAddress addr) { + // do nothing + } + + @Override + public void close() { + // do nothing + } + + @Override + void send(byte[] queryMessageBytes) throws IOException { + Message parsedQueryMessage = new Message(queryMessageBytes); + assertNotNull(parsedQueryMessage.getTSIG()); + + messages = new LinkedList<>(); + StreamGenerator generator; + try { + generator = getStreamGenerator(signEvery, parsedQueryMessage); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new IOException(e); + } + + Name queryName = parsedQueryMessage.getQuestion().getName(); + Record soa = + new SOARecord( + queryName, + DClass.IN, + 300, + new Name("ns1", queryName), + new Name("admin", queryName), + 1, + 3600, + 1, + 3600, + 1800); + for (int i = 0; i < responseMessageCount; i++) { + Message response = new Message(parsedQueryMessage.getHeader().getID()); + response.getHeader().setFlag(Flags.QR); + response.addRecord(parsedQueryMessage.getQuestion(), Section.QUESTION); + if (i == 0) { + response.addRecord(soa, Section.ANSWER); + } + Record answer = + new ARecord( + parsedQueryMessage.getQuestion().getName(), + DClass.IN, + 300, + InetAddress.getByAddress(ByteBuffer.allocate(4).putInt(i).array())); + response.addRecord(answer, Section.ANSWER); + + if (i == responseMessageCount - 1) { + response.addRecord(soa, Section.ANSWER); + } + + generator.generate(response, !skipLast && i == responseMessageCount - 1); + messages.add(response); + } + } + + @Override + byte[] recv() { + return messages.get(recvCalls++).toWire(Message.MAXLENGTH); + } + + private StreamGenerator getStreamGenerator(int signEvery, Message parsedQueryMessage) + throws NoSuchFieldException, IllegalAccessException { + TSIGRecord queryMessageTSIG = parsedQueryMessage.getTSIG(); + StreamGenerator generator; + + // Hack for testing invalid server responses, the constructor would normally prevent such an + // invalid argument + if (signEvery > 100) { + generator = new StreamGenerator(key, queryMessageTSIG, 1); + Field signEveryNthMessage = StreamGenerator.class.getDeclaredField("signEveryNthMessage"); + signEveryNthMessage.setAccessible(true); + signEveryNthMessage.set(generator, signEvery); + } else { + generator = new StreamGenerator(key, queryMessageTSIG, signEvery); + } + return generator; + } + } } diff --git a/src/test/resources/tsig-axfr/request.bin b/src/test/resources/tsig-axfr/request.bin new file mode 100644 index 00000000..9e110f7e Binary files /dev/null and b/src/test/resources/tsig-axfr/request.bin differ diff --git a/src/test/resources/tsig-axfr/response-messages.txt b/src/test/resources/tsig-axfr/response-messages.txt new file mode 100644 index 00000000..492d3d2a --- /dev/null +++ b/src/test/resources/tsig-axfr/response-messages.txt @@ -0,0 +1 @@ +dig @dnssec1.stage.arin.net -p 53 AXFR dnssecishard.com. -y 'hmac-sha256:dnssecishardtest:q4Gsu0nYoyub20//PATXhABobmrVUQyqq5TFzYHfC7o=' diff --git a/src/test/resources/tsig-axfr/response.bin b/src/test/resources/tsig-axfr/response.bin new file mode 100644 index 00000000..f861f8b1 Binary files /dev/null and b/src/test/resources/tsig-axfr/response.bin differ