diff --git a/src/main/java/org/xbill/DNS/DohResolver.java b/src/main/java/org/xbill/DNS/DohResolver.java index d6f27b93..6e7d73f3 100644 --- a/src/main/java/org/xbill/DNS/DohResolver.java +++ b/src/main/java/org/xbill/DNS/DohResolver.java @@ -640,7 +640,7 @@ private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { return; } - int error = tsig.verify(response, b, query.getTSIG()); + int error = tsig.verify(response, b, query.getGeneratedTSIG()); log.debug( "TSIG verify for query {}, {}/{}: {}", query.getHeader().getID(), diff --git a/src/main/java/org/xbill/DNS/Message.java b/src/main/java/org/xbill/DNS/Message.java index bbaf1ef3..1cdfa483 100644 --- a/src/main/java/org/xbill/DNS/Message.java +++ b/src/main/java/org/xbill/DNS/Message.java @@ -32,6 +32,7 @@ public class Message implements Cloneable { private List[] sections; private int size; private TSIG tsigkey; + private TSIGRecord generatedTsig; private TSIGRecord querytsig; private int tsigerror; private Resolver resolver; @@ -274,7 +275,7 @@ public boolean findRRset(Name name, int type) { */ public Record getQuestion() { List l = sections[Section.QUESTION]; - if (l == null || l.size() == 0) { + if (l == null || l.isEmpty()) { return null; } return l.get(0); @@ -300,6 +301,16 @@ public TSIGRecord getTSIG() { return (TSIGRecord) rec; } + /** + * Gets the generated {@link TSIGRecord}. Only valid if the messages has been converted to wire + * format with {@link #toWire(int)} before. + * + * @return A generated TSIG record or {@code null}. + */ + TSIGRecord getGeneratedTSIG() { + return generatedTsig; + } + /** * Was this message signed by a TSIG? * @@ -325,9 +336,9 @@ public boolean isVerified() { * @see Section */ public OPTRecord getOPT() { - for (Record record : getSection(Section.ADDITIONAL)) { - if (record instanceof OPTRecord) { - return (OPTRecord) record; + for (Record r : getSection(Section.ADDITIONAL)) { + if (r instanceof OPTRecord) { + return (OPTRecord) r; } } return null; @@ -516,6 +527,7 @@ private void toWire(DNSOutput out, int maxLength) { TSIGRecord tsigrec = tsigkey.generate(this, out.toByteArray(), tsigerror, querytsig); tsigrec.toWire(out, Section.ADDITIONAL, c); + generatedTsig = tsigrec; out.writeU16At(additionalCount + 1, startpos + 10); } } @@ -536,9 +548,9 @@ public byte[] toWire() { /** * Returns an array containing the wire format representation of the Message with the specified * maximum length. This will generate a truncated message (with the TC bit) if the message doesn't - * fit, and will also sign the message with the TSIG key set by a call to setTSIG(). This method - * may return an empty byte array if the message could not be rendered at all; this could happen - * if maxLength is smaller than a DNS header, for example. + * fit, and will also sign the message with the TSIG key set by a call to {@link #setTSIG(TSIG, + * int, TSIGRecord)}. This method may return an empty byte array if the message could not be + * rendered at all; this could happen if maxLength is smaller than a DNS header, for example. * *

Do NOT use this method in conjunction with {@link TSIG#apply(Message, TSIGRecord)}, it * produces inconsistent results! Use {@link #setTSIG(TSIG, int, TSIGRecord)} instead. @@ -556,6 +568,16 @@ public byte[] toWire(int maxLength) { return out.toByteArray(); } + /** + * Sets the TSIG key to sign a message. + * + * @param key The TSIG key. + * @since 3.5.1 + */ + public void setTSIG(TSIG key) { + setTSIG(key, Rcode.NOERROR, null); + } + /** * Sets the TSIG key and other necessary information to sign a message. * @@ -668,6 +690,9 @@ public Message clone() { if (querytsig != null) { m.querytsig = (TSIGRecord) querytsig.cloneRecord(); } + if (generatedTsig != null) { + m.generatedTsig = (TSIGRecord) generatedTsig.cloneRecord(); + } return m; } diff --git a/src/main/java/org/xbill/DNS/SimpleResolver.java b/src/main/java/org/xbill/DNS/SimpleResolver.java index eeeb28ac..4a564bb1 100644 --- a/src/main/java/org/xbill/DNS/SimpleResolver.java +++ b/src/main/java/org/xbill/DNS/SimpleResolver.java @@ -36,7 +36,8 @@ public class SimpleResolver implements Resolver { private InetSocketAddress address; private InetSocketAddress localAddress; - private boolean useTCP, ignoreTruncation; + private boolean useTCP; + private boolean ignoreTruncation; private OPTRecord queryOPT = new OPTRecord(DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0); private TSIG tsig; private Duration timeoutValue = Duration.ofSeconds(10); @@ -267,12 +268,13 @@ private Message parseMessage(byte[] b) throws WireParseException { } } - private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { + private void verifyTSIG(Message query, Message response, byte[] b) { if (tsig == null) { return; } - int error = tsig.verify(response, b, query.getTSIG()); - log.debug("TSIG verify: {}", Rcode.TSIGstring(error)); + int error = tsig.verify(response, b, query.getGeneratedTSIG()); + log.debug( + "TSIG verify on message id {}: {}", query.getHeader().getID(), Rcode.TSIGstring(error)); } private void applyEDNS(Message query) { @@ -431,7 +433,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e return f; } - verifyTSIG(query, response, in, tsig); + verifyTSIG(query, response, in); if (!tcp && !ignoreTruncation && response.getHeader().getFlag(Flags.TC)) { if (log.isTraceEnabled()) { log.trace( @@ -466,8 +468,8 @@ private Message sendAXFR(Message query) throws IOException { response.getHeader().setFlag(Flags.AA); response.getHeader().setFlag(Flags.QR); response.addRecord(query.getQuestion(), Section.QUESTION); - for (Record record : records) { - response.addRecord(record, Section.ANSWER); + for (Record r : records) { + response.addRecord(r, Section.ANSWER); } return response; } diff --git a/src/main/java/org/xbill/DNS/TSIG.java b/src/main/java/org/xbill/DNS/TSIG.java index f55f4eac..aabd364b 100644 --- a/src/main/java/org/xbill/DNS/TSIG.java +++ b/src/main/java/org/xbill/DNS/TSIG.java @@ -296,7 +296,7 @@ public TSIG(Name algorithm, String name, String key) { * @param key The shared key's data represented as a base64 encoded string. * @throws IllegalArgumentException The key name is an invalid name * @throws IllegalArgumentException The key data is improperly encoded - * @see RFC8945 + * @see RFC8945 */ public TSIG(String algorithm, String name, String key) { this(algorithmToName(algorithm), name, key); @@ -543,15 +543,15 @@ public byte verify(Message m, byte[] b, int length, TSIGRecord old) { * routine, Message.isVerified() may be called on this message. * * @param m The message to verify - * @param b An array containing the message in unparsed form. This is necessary since TSIG signs - * the message in wire format, and we can't recreate the exact wire format (with the same name - * compression). - * @param old If this message is a response, the TSIG from the request + * @param messageBytes An array containing the message in unparsed form. This is necessary since + * TSIG signs the message in wire format, and we can't recreate the exact wire format (with + * the same name compression). + * @param requestTSIG If this message is a response, the TSIG from the request * @return The result of the verification (as an Rcode) * @see Rcode */ - public int verify(Message m, byte[] b, TSIGRecord old) { - return verify(m, b, old, true); + public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG) { + return verify(m, messageBytes, requestTSIG, true); } /** @@ -560,10 +560,10 @@ public int verify(Message m, byte[] b, TSIGRecord old) { * routine, Message.isVerified() may be called on this message. * * @param m The message to verify - * @param b An array containing the message in unparsed form. This is necessary since TSIG signs - * the message in wire format, and we can't recreate the exact wire format (with the same name - * compression). - * @param old If this message is a response, the TSIG from the request + * @param messageBytes An array containing the message in unparsed form. This is necessary since + * TSIG signs the message in wire format, and we can't recreate the exact wire format (with + * the same name compression). + * @param requestTSIG If this message is a response, the TSIG from the request * @param fullSignature {@code true} if this message is the first of many in a TCP connection and * all TSIG variables (rfc2845, 3.4.2.) should be included in the signature. {@code false} for * subsequent messages with reduced TSIG variables set (rfc2845, 4.4.). @@ -571,7 +571,7 @@ public int verify(Message m, byte[] b, TSIGRecord old) { * @see Rcode * @since 3.2 */ - public int verify(Message m, byte[] b, TSIGRecord old, boolean fullSignature) { + public int verify(Message m, byte[] messageBytes, TSIGRecord requestTSIG, boolean fullSignature) { m.tsigState = Message.TSIG_FAILED; TSIGRecord tsig = m.getTSIG(); if (tsig == null) { @@ -580,7 +580,8 @@ public int verify(Message m, byte[] b, TSIGRecord old, boolean fullSignature) { if (!tsig.getName().equals(name) || !tsig.getAlgorithm().equals(alg)) { log.debug( - "BADKEY failure, expected: {}/{}, actual: {}/{}", + "BADKEY failure on message id {}, expected: {}/{}, actual: {}/{}", + m.getHeader().getID(), name, alg, tsig.getName(), @@ -589,8 +590,8 @@ public int verify(Message m, byte[] b, TSIGRecord old, boolean fullSignature) { } Mac hmac = initHmac(); - if (old != null && tsig.getError() != Rcode.BADKEY && tsig.getError() != Rcode.BADSIG) { - hmacAddSignature(hmac, old); + if (requestTSIG != null && tsig.getError() != Rcode.BADKEY && tsig.getError() != Rcode.BADSIG) { + hmacAddSignature(hmac, requestTSIG); } m.getHeader().decCount(Section.ADDITIONAL); @@ -603,9 +604,9 @@ public int verify(Message m, byte[] b, TSIGRecord old, boolean fullSignature) { int len = m.tsigstart - header.length; if (log.isTraceEnabled()) { - log.trace(hexdump.dump("TSIG-HMAC message after header", b, header.length, len)); + log.trace(hexdump.dump("TSIG-HMAC message after header", messageBytes, header.length, len)); } - hmac.update(b, header.length, len); + hmac.update(messageBytes, header.length, len); DNSOutput out = new DNSOutput(); if (fullSignature) { diff --git a/src/main/java/org/xbill/DNS/tools/jnamed.java b/src/main/java/org/xbill/DNS/tools/jnamed.java index e71dcadf..63393862 100644 --- a/src/main/java/org/xbill/DNS/tools/jnamed.java +++ b/src/main/java/org/xbill/DNS/tools/jnamed.java @@ -64,8 +64,8 @@ public jnamed(String conffile) throws IOException, ZoneTransferException { FileInputStream fs; InputStreamReader isr; BufferedReader br; - List ports = new ArrayList(); - List addresses = new ArrayList(); + List ports = new ArrayList<>(); + List addresses = new ArrayList<>(); try { fs = new FileInputStream(conffile); isr = new InputStreamReader(fs); @@ -76,9 +76,9 @@ public jnamed(String conffile) throws IOException, ZoneTransferException { } try { - caches = new HashMap(); + caches = new HashMap<>(); znames = new HashMap<>(); - TSIGs = new HashMap(); + TSIGs = new HashMap<>(); String line; while ((line = br.readLine()) != null) { @@ -127,21 +127,20 @@ public jnamed(String conffile) throws IOException, ZoneTransferException { } } - if (ports.size() == 0) { + if (ports.isEmpty()) { ports.add(53); } - if (addresses.size() == 0) { + if (addresses.isEmpty()) { addresses.add(Address.getByAddress("0.0.0.0")); } - for (Object address : addresses) { - InetAddress addr = (InetAddress) address; - for (Object o : ports) { - int port = (Integer) o; - addUDP(addr, port); - addTCP(addr, port); - System.out.println("jnamed: listening on " + addrport(addr, port)); + for (InetAddress address : addresses) { + for (Integer o : ports) { + int port = o; + addUDP(address, port); + addTCP(address, port); + System.out.println("jnamed: listening on " + addrport(address, port)); } } System.out.println("jnamed: running"); @@ -172,12 +171,7 @@ public void addTSIG(String algstr, String namestr, String key) throws IOExceptio } public Cache getCache(int dclass) { - Cache c = caches.get(dclass); - if (c == null) { - c = new Cache(dclass); - caches.put(dclass, c); - } - return c; + return caches.computeIfAbsent(dclass, Cache::new); } public Zone findBestZone(Name name) { @@ -197,7 +191,7 @@ public Zone findBestZone(Name name) { return null; } - public RRset findExactMatch(Name name, int type, int dclass, boolean glue) { + public RRset findExactMatch(Name name, int type, int dclass, boolean glue) { Zone zone = findBestZone(name); if (zone != null) { return zone.findExactMatch(name, type); @@ -217,8 +211,7 @@ public RRset findExactMatch(Name name, int type, int dclass, } } - void addRRset( - Name name, Message response, RRset rrset, int section, int flags) { + void addRRset(Name name, Message response, RRset rrset, int section, int flags) { for (int s = 1; s <= section; s++) { if (response.findRRset(name, rrset.getType(), s)) { return; @@ -403,6 +396,7 @@ byte[] doAXFR(Name name, Message query, TSIG tsig, TSIGRecord qtsig, Socket s) { try { s.close(); } catch (IOException ex) { + // ignore } return null; } @@ -414,7 +408,6 @@ byte[] doAXFR(Name name, Message query, TSIG tsig, TSIGRecord qtsig, Socket s) { */ byte[] generateReply(Message query, byte[] in, Socket s) { Header header; - boolean badversion; int maxLength; int flags = 0; @@ -515,13 +508,12 @@ public byte[] errorMessage(Message query, int rcode) { } public void TCPclient(Socket s) { - try { + try (InputStream is = s.getInputStream()) { int inLength; DataInputStream dataIn; DataOutputStream dataOut; byte[] in; - InputStream is = s.getInputStream(); dataIn = new DataInputStream(is); inLength = dataIn.readUnsignedShort(); in = new byte[inLength]; @@ -544,11 +536,6 @@ public void TCPclient(Socket s) { } catch (IOException e) { System.out.println( "TCPclient(" + addrport(s.getLocalAddress(), s.getLocalPort()) + "): " + e); - } finally { - try { - s.close(); - } catch (IOException e) { - } } } diff --git a/src/test/java/org/xbill/DNS/MessageTest.java b/src/test/java/org/xbill/DNS/MessageTest.java index a4637996..86100147 100644 --- a/src/test/java/org/xbill/DNS/MessageTest.java +++ b/src/test/java/org/xbill/DNS/MessageTest.java @@ -47,102 +47,124 @@ import org.junit.jupiter.api.Test; import org.xbill.DNS.utils.base64; -public class MessageTest { - static class Test_init { - @Test - void ctor_0arg() { - Message m = new Message(); - assertTrue(m.getSection(0).isEmpty()); - assertTrue(m.getSection(1).isEmpty()); - assertTrue(m.getSection(3).isEmpty()); - assertTrue(m.getSection(2).isEmpty()); - assertThrows(IndexOutOfBoundsException.class, () -> m.getSection(4)); - Header h = m.getHeader(); - assertEquals(0, h.getCount(0)); - assertEquals(0, h.getCount(1)); - assertEquals(0, h.getCount(2)); - assertEquals(0, h.getCount(3)); - } +class MessageTest { + @Test + void ctor_0arg() { + Message m = new Message(); + assertTrue(m.getSection(0).isEmpty()); + assertTrue(m.getSection(1).isEmpty()); + assertTrue(m.getSection(3).isEmpty()); + assertTrue(m.getSection(2).isEmpty()); + assertThrows(IndexOutOfBoundsException.class, () -> m.getSection(4)); + Header h = m.getHeader(); + assertEquals(0, h.getCount(0)); + assertEquals(0, h.getCount(1)); + assertEquals(0, h.getCount(2)); + assertEquals(0, h.getCount(3)); + } - @Test - void ctor_1arg() { - Message m = new Message(10); - assertEquals(new Header(10).toString(), m.getHeader().toString()); - assertTrue(m.getSection(0).isEmpty()); - assertTrue(m.getSection(1).isEmpty()); - assertTrue(m.getSection(2).isEmpty()); - assertTrue(m.getSection(3).isEmpty()); - assertThrows(IndexOutOfBoundsException.class, () -> m.getSection(4)); - Header h = m.getHeader(); - assertEquals(0, h.getCount(0)); - assertEquals(0, h.getCount(1)); - assertEquals(0, h.getCount(2)); - assertEquals(0, h.getCount(3)); - } + @Test + void ctor_1arg() { + Message m = new Message(10); + assertEquals(new Header(10).toString(), m.getHeader().toString()); + assertTrue(m.getSection(0).isEmpty()); + assertTrue(m.getSection(1).isEmpty()); + assertTrue(m.getSection(2).isEmpty()); + assertTrue(m.getSection(3).isEmpty()); + assertThrows(IndexOutOfBoundsException.class, () -> m.getSection(4)); + Header h = m.getHeader(); + assertEquals(0, h.getCount(0)); + assertEquals(0, h.getCount(1)); + assertEquals(0, h.getCount(2)); + assertEquals(0, h.getCount(3)); + } - @Test - void ctor_byteBuffer() throws IOException { - byte[] arr = - base64.fromString( - "EEuBgAABAAEABAAIA3d3dwZnb29nbGUDY29tAAABAAHADAABAAEAAAAaAASO+rokwBAAAgABAAFHCwAGA25zMcAQwBAAAgABAAFHCwAGA25zNMAQwBAAAgABAAFHCwAGA25zM8AQwBAAAgABAAFHCwAGA25zMsAQwDwAAQABAADObwAE2O8gCsByAAEAAQABrVEABNjvIgrAYAABAAEAAVqZAATY7yQKwE4AAQABAAK9RQAE2O8mCsA8ABwAAQAD4a0AECABSGBIAgAyAAAAAAAAAArAcgAcAAEAAtDgABAgAUhgSAIANAAAAAAAAAAKwGAAHAABAACSagAQIAFIYEgCADYAAAAAAAAACsBOABwAAQAErVoAECABSGBIAgA4AAAAAAAAAAo="); + @Test + void ctor_byteBuffer() throws IOException { + byte[] arr = + base64.fromString( + "EEuBgAABAAEABAAIA3d3dwZnb29nbGUDY29tAAABAAHADAABAAEAAAAaAASO+rokwBAAAgABAAFHCwAGA25zMcAQwBAAAgABAAFHCwAGA25zNMAQwBAAAgABAAFHCwAGA25zM8AQwBAAAgABAAFHCwAGA25zMsAQwDwAAQABAADObwAE2O8gCsByAAEAAQABrVEABNjvIgrAYAABAAEAAVqZAATY7yQKwE4AAQABAAK9RQAE2O8mCsA8ABwAAQAD4a0AECABSGBIAgAyAAAAAAAAAArAcgAcAAEAAtDgABAgAUhgSAIANAAAAAAAAAAKwGAAHAABAACSagAQIAFIYEgCADYAAAAAAAAACsBOABwAAQAErVoAECABSGBIAgA4AAAAAAAAAAo="); - ByteBuffer wrap = ByteBuffer.allocate(arr.length + 2); + ByteBuffer wrap = ByteBuffer.allocate(arr.length + 2); - // prepend length, like when reading a response from a TCP channel - wrap.putShort((short) arr.length); - wrap.put(arr); - wrap.flip(); - wrap.getShort(); // read the prepended length + // prepend length, like when reading a response from a TCP channel + wrap.putShort((short) arr.length); + wrap.put(arr); + wrap.flip(); + wrap.getShort(); // read the prepended length - Message m = new Message(wrap); - assertEquals(Name.fromConstantString("www.google.com."), m.getQuestion().getName()); - } + Message m = new Message(wrap); + assertEquals(Name.fromConstantString("www.google.com."), m.getQuestion().getName()); + } - @Test - void newQuery() throws TextParseException, UnknownHostException { - Name n = Name.fromString("The.Name."); - ARecord ar = new ARecord(n, DClass.IN, 1, InetAddress.getByName("192.168.101.110")); - - Message m = Message.newQuery(ar); - assertEquals(1, m.getSection(0).size()); - assertEquals(ar, m.getSection(0).get(0)); - assertTrue(m.getSection(1).isEmpty()); - assertTrue(m.getSection(2).isEmpty()); - assertTrue(m.getSection(3).isEmpty()); - - Header h = m.getHeader(); - assertEquals(1, h.getCount(0)); - assertEquals(0, h.getCount(1)); - assertEquals(0, h.getCount(2)); - assertEquals(0, h.getCount(3)); - assertEquals(Opcode.QUERY, h.getOpcode()); - assertTrue(h.getFlag(Flags.RD)); - } + @Test + void newQuery() throws TextParseException, UnknownHostException { + Name n = Name.fromString("The.Name."); + ARecord ar = new ARecord(n, DClass.IN, 1, InetAddress.getByName("192.168.101.110")); + + Message m = Message.newQuery(ar); + assertEquals(1, m.getSection(0).size()); + assertEquals(ar, m.getSection(0).get(0)); + assertTrue(m.getSection(1).isEmpty()); + assertTrue(m.getSection(2).isEmpty()); + assertTrue(m.getSection(3).isEmpty()); - @Test - void sectionToWire() throws IOException { - Message m = new Message(4711); - Name n2 = Name.fromConstantString("test2.example."); - m.addRecord(new TXTRecord(n2, DClass.IN, 86400, "other record"), Section.ADDITIONAL); - Name n = Name.fromConstantString("test.example."); - m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -1-"), Section.ADDITIONAL); - m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -2-"), Section.ADDITIONAL); - m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -3-"), Section.ADDITIONAL); - m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -4-"), Section.ADDITIONAL); - m.addRecord(new OPTRecord(512, 0, 0, 0), Section.ADDITIONAL); - - for (int i = 5; i < 50; i++) { - m.addRecord( - new TXTRecord(n, DClass.IN, 86400, "example text -" + i + "-"), Section.ADDITIONAL); - } - - byte[] binary = m.toWire(512); - Message m2 = new Message(binary); - assertEquals(2, m2.getHeader().getCount(Section.ADDITIONAL)); - List records = m2.getSection(Section.ADDITIONAL); - assertEquals(2, records.size()); - assertEquals(TXTRecord.class, records.get(0).getClass()); - assertEquals(OPTRecord.class, records.get(1).getClass()); + Header h = m.getHeader(); + assertEquals(1, h.getCount(0)); + assertEquals(0, h.getCount(1)); + assertEquals(0, h.getCount(2)); + assertEquals(0, h.getCount(3)); + assertEquals(Opcode.QUERY, h.getOpcode()); + assertTrue(h.getFlag(Flags.RD)); + } + + @Test + void sectionToWire() throws IOException { + Message m = new Message(4711); + Name n2 = Name.fromConstantString("test2.example."); + m.addRecord(new TXTRecord(n2, DClass.IN, 86400, "other record"), Section.ADDITIONAL); + Name n = Name.fromConstantString("test.example."); + m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -1-"), Section.ADDITIONAL); + m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -2-"), Section.ADDITIONAL); + m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -3-"), Section.ADDITIONAL); + m.addRecord(new TXTRecord(n, DClass.IN, 86400, "example text -4-"), Section.ADDITIONAL); + m.addRecord(new OPTRecord(512, 0, 0, 0), Section.ADDITIONAL); + + for (int i = 5; i < 50; i++) { + m.addRecord( + new TXTRecord(n, DClass.IN, 86400, "example text -" + i + "-"), Section.ADDITIONAL); } + + byte[] binary = m.toWire(512); + Message m2 = new Message(binary); + assertEquals(2, m2.getHeader().getCount(Section.ADDITIONAL)); + List records = m2.getSection(Section.ADDITIONAL); + assertEquals(2, records.size()); + assertEquals(TXTRecord.class, records.get(0).getClass()); + assertEquals(OPTRecord.class, records.get(1).getClass()); + } + + @Test + void testQuestionClone() { + Name qname = Name.fromConstantString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message query = Message.newQuery(question); + Message clone = query.clone(); + assertEquals(query.getHeader().getID(), clone.getHeader().getID()); + assertEquals(query.getQuestion().getName(), clone.getQuestion().getName()); + } + + @Test + void testResponseClone() throws UnknownHostException { + Name qname = Name.fromConstantString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message response = new Message(); + response.getHeader().setFlag(Flags.QR); + response.addRecord(question, Section.QUESTION); + response.addRecord( + new ARecord(qname, DClass.IN, 0, InetAddress.getByName("127.0.0.1")), Section.ANSWER); + Message clone = response.clone(); + assertEquals(clone.getQuestion(), response.getQuestion()); + assertEquals(clone.getSection(Section.ANSWER), response.getSection(Section.ANSWER)); } } diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java index 16d3df55..8ad800b9 100644 --- a/src/test/java/org/xbill/DNS/TSIGTest.java +++ b/src/test/java/org/xbill/DNS/TSIGTest.java @@ -3,33 +3,44 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.time.Instant; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.xbill.DNS.utils.base64; class TSIGTest { @Test - void TSIG_query() throws IOException { + void signedQuery() throws IOException { TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); - Record rec = Record.newRecord(qname, Type.A, DClass.IN); - Message msg = Message.newQuery(rec); - msg.setTSIG(key, Rcode.NOERROR, null); - byte[] bytes = msg.toWire(512); - assertEquals(1, bytes[11]); + Record question = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN); + Message query = Message.newQuery(question); + query.setTSIG(key); + byte[] qbytes = query.toWire(512); + assertEquals(1, qbytes[11]); - Message parsed = new Message(bytes); - int result = key.verify(parsed, bytes, null); + Message qparsed = new Message(qbytes); + int result = key.verify(qparsed, qbytes, null); assertEquals(Rcode.NOERROR, result); - assertTrue(parsed.isSigned()); + assertTrue(qparsed.isSigned()); + assertTrue(qparsed.isVerified()); } /** @@ -51,13 +62,12 @@ void TSIG_query() throws IOException { "HmacMD5", "HmacSHA256" }) - void TSIG_query_stringalg(String alg) throws IOException { + void queryStringAlg(String alg) throws IOException { TSIG key = new TSIG(alg, "example.", "12345678"); - Name qname = Name.fromString("www.example."); - Record rec = Record.newRecord(qname, Type.A, DClass.IN); + Record rec = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN); Message msg = Message.newQuery(rec); - msg.setTSIG(key, Rcode.NOERROR, null); + msg.setTSIG(key); byte[] bytes = msg.toWire(512); assertEquals(1, bytes[11]); @@ -65,24 +75,24 @@ void TSIG_query_stringalg(String alg) throws IOException { int result = key.verify(parsed, bytes, null); assertEquals(Rcode.NOERROR, result); assertTrue(parsed.isSigned()); + assertTrue(parsed.isVerified()); } /** Confirm error thrown with illegal algorithm name. */ @Test - void TSIG_query_stringalg_err() throws IOException { + void queryStringAlgError() { assertThrows( IllegalArgumentException.class, () -> new TSIG("randomalg", "example.", "12345678")); } @Test - void TSIG_queryIsLastAddMessageRecord() throws IOException { + void queryIsLastAddMessageRecord() throws IOException { TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); - Name qname = Name.fromString("www.example."); - Record rec = Record.newRecord(qname, Type.A, DClass.IN); + Record rec = Record.newRecord(Name.fromString("www.example."), Type.A, DClass.IN); OPTRecord opt = new OPTRecord(SimpleResolver.DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0); Message msg = Message.newQuery(rec); - msg.setTSIG(key, Rcode.NOERROR, null); + msg.setTSIG(key); msg.addRecord(opt, Section.ADDITIONAL); byte[] bytes = msg.toWire(512); assertEquals(2, bytes[11]); // additional RR count, lower byte @@ -94,16 +104,17 @@ void TSIG_queryIsLastAddMessageRecord() throws IOException { int result = key.verify(parsed, bytes, null); assertEquals(Rcode.NOERROR, result); assertTrue(parsed.isSigned()); + assertTrue(parsed.isVerified()); } @Test - void TSIG_queryAndTsigApplyMisbehaves() throws IOException { - Name qname = Name.fromString("www.example.com."); - Record rec = Record.newRecord(qname, Type.A, DClass.IN); + void queryAndTsigApplyMisbehaves() throws IOException { + Record rec = Record.newRecord(Name.fromString("www.example.com."), Type.A, DClass.IN); OPTRecord opt = new OPTRecord(SimpleResolver.DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0); Message msg = Message.newQuery(rec); msg.addRecord(opt, Section.ADDITIONAL); assertFalse(msg.isSigned()); + assertFalse(msg.isVerified()); TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); key.apply(msg, null); // additional RR count, lower byte @@ -113,11 +124,7 @@ void TSIG_queryAndTsigApplyMisbehaves() throws IOException { } @Test - void TSIG_queryIsLastResolver() throws IOException { - Name qname = Name.fromString("www.example.com."); - Record rec = Record.newRecord(qname, Type.A, DClass.IN); - Message msg = Message.newQuery(rec); - + void tsigInQueryIsLastViaResolver() throws IOException { TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); SimpleResolver res = new SimpleResolver("127.0.0.1") { @@ -134,26 +141,56 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e } }; res.setTSIGKey(key); - Message parsed = res.send(msg); - List additionalSection = parsed.getSection(Section.ADDITIONAL); + Name qname = Name.fromString("www.example.com."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message query = Message.newQuery(question); + Message response = res.send(query); + + List additionalSection = response.getSection(Section.ADDITIONAL); assertEquals(Type.string(Type.OPT), Type.string(additionalSection.get(0).getType())); assertEquals(Type.string(Type.TSIG), Type.string(additionalSection.get(1).getType())); - int result = key.verify(parsed, parsed.toWire(), null); + int result = key.verify(response, response.toWire(), null); assertEquals(Rcode.NOERROR, result); - assertTrue(parsed.isSigned()); + assertTrue(response.isSigned()); + assertTrue(response.isVerified()); } @Test - void TSIG_response() throws IOException { + void unsignedQuerySignedResponse() throws IOException { TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); - query.setTSIG(key, Rcode.NOERROR, null); - byte[] qbytes = query.toWire(); + + Message response = new Message(query.getHeader().getID()); + response.setTSIG(key, Rcode.NOERROR, null); + response.getHeader().setFlag(Flags.QR); + response.addRecord(question, Section.QUESTION); + Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); + response.addRecord(answer, Section.ANSWER); + byte[] rbytes = response.toWire(Message.MAXLENGTH); + + Message rparsed = new Message(rbytes); + int result = key.verify(rparsed, rbytes, null); + assertEquals(Rcode.NOERROR, result); + assertTrue(rparsed.isSigned()); + assertTrue(rparsed.isVerified()); + } + + @Test + void signedQuerySignedResponse() throws IOException { + TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); + + Name qname = Name.fromString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message query = Message.newQuery(question); + query.setTSIG(key); + byte[] qbytes = query.toWire(Message.MAXLENGTH); Message qparsed = new Message(qbytes); + assertNotNull(query.getGeneratedTSIG()); + assertEquals(query.getGeneratedTSIG(), qparsed.getTSIG()); Message response = new Message(query.getHeader().getID()); response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG()); @@ -161,23 +198,67 @@ void TSIG_response() throws IOException { response.addRecord(question, Section.QUESTION); Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); response.addRecord(answer, Section.ANSWER); - byte[] bytes = response.toWire(512); + byte[] rbytes = response.toWire(Message.MAXLENGTH); - Message parsed = new Message(bytes); - int result = key.verify(parsed, bytes, qparsed.getTSIG()); + Message rparsed = new Message(rbytes); + int result = key.verify(rparsed, rbytes, query.getGeneratedTSIG()); assertEquals(Rcode.NOERROR, result); - assertTrue(parsed.isSigned()); + assertTrue(rparsed.isSigned()); + assertTrue(rparsed.isVerified()); + } + + @Test + void signedQuerySignedResponseViaResolver() throws IOException { + TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); + + Name qname = Name.fromString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message query = Message.newQuery(question); + + try (MockedStatic udpClient = Mockito.mockStatic(NioUdpClient.class)) { + udpClient + .when( + () -> + NioUdpClient.sendrecv( + any(), + any(InetSocketAddress.class), + any(byte[].class), + anyInt(), + any(Duration.class))) + .thenAnswer( + a -> { + Message qparsed = new Message(a.getArgument(2, byte[].class)); + + Message response = new Message(qparsed.getHeader().getID()); + response.setTSIG(key, Rcode.NOERROR, qparsed.getTSIG()); + response.getHeader().setFlag(Flags.QR); + response.addRecord(question, Section.QUESTION); + Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); + response.addRecord(answer, Section.ANSWER); + byte[] rbytes = response.toWire(Message.MAXLENGTH); + + CompletableFuture f = new CompletableFuture<>(); + f.complete(rbytes); + return f; + }); + SimpleResolver res = new SimpleResolver("127.0.0.1"); + res.setTSIGKey(key); + + Message responseFromResolver = res.send(query); + assertTrue(responseFromResolver.isSigned()); + assertTrue(responseFromResolver.isVerified()); + } } @Test - void TSIG_truncated() throws IOException { + void truncated() throws IOException { TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); Name qname = Name.fromString("www.example."); Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); query.setTSIG(key, Rcode.NOERROR, null); - byte[] qbytes = query.toWire(); + byte[] qbytes = query.toWire(512); Message qparsed = new Message(qbytes); Message response = new Message(query.getHeader().getID()); @@ -188,13 +269,14 @@ void TSIG_truncated() throws IOException { Record answer = Record.fromString(qname, Type.TXT, DClass.IN, 300, "foo" + i, null); response.addRecord(answer, Section.ANSWER); } - byte[] bytes = response.toWire(512); + byte[] rbytes = response.toWire(512); - Message parsed = new Message(bytes); - assertTrue(parsed.getHeader().getFlag(Flags.TC)); - int result = key.verify(parsed, bytes, qparsed.getTSIG()); + Message rparsed = new Message(rbytes); + assertTrue(rparsed.getHeader().getFlag(Flags.TC)); + int result = key.verify(rparsed, rbytes, qparsed.getTSIG()); assertEquals(Rcode.NOERROR, result); - assertTrue(parsed.isSigned()); + assertTrue(rparsed.isSigned()); + assertTrue(rparsed.isVerified()); } @Test @@ -205,4 +287,41 @@ void rdataFromString() { () -> new TSIGRecord().rdataFromString(new Tokenizer(" "), null)); assertTrue(thrown.getMessage().contains("no text format defined for TSIG")); } + + @Test + void testTSIGMessageClone() throws IOException { + TSIG key = new TSIG(TSIG.HMAC_SHA256, "example.", "12345678"); + TSIGRecord old = + new TSIGRecord( + Name.fromConstantString("example."), + DClass.IN, + 0, + TSIG.HMAC_SHA256, + Instant.ofEpochSecond(1647025759), + Duration.ofSeconds(300), + base64.fromString("zcHnvVwo0Zlsj0WckOO/ctRD2Znh+BjIWnSvTQdvj94="), + 32, + Rcode.NOERROR, + null); + + Name qname = Name.fromConstantString("www.example."); + Record question = Record.newRecord(qname, Type.A, DClass.IN); + Message response = new Message(); + response.getHeader().setFlag(Flags.QR); + response.addRecord(question, Section.QUESTION); + response.addRecord( + new ARecord(qname, DClass.IN, 0, InetAddress.getByName("127.0.0.1")), Section.ANSWER); + response.setTSIG(key, Rcode.NOERROR, old); + byte[] responseBytes = response.toWire(Message.MAXLENGTH); + assertNotNull(responseBytes); + assertNotEquals(0, responseBytes.length); + + Message clone = response.clone(); + assertEquals(response.getQuestion(), clone.getQuestion()); + assertEquals(response.getSection(Section.ANSWER), clone.getSection(Section.ANSWER)); + assertEquals(response.getGeneratedTSIG(), clone.getGeneratedTSIG()); + byte[] cloneBytes = clone.toWire(Message.MAXLENGTH); + assertNotNull(cloneBytes); + assertNotEquals(0, cloneBytes.length); + } }