From 98a264b8a0e6725bcb220617596f66a7bdf529f3 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 25 Feb 2026 16:00:51 +0000 Subject: [PATCH 1/4] ByteBufBsonDocument & ByteBufBsonArray refactorings (#1874) * ByteBufBsonDocument & ByteBufBsonArray refactoring * Now implement `Closeable` to track and manage lifecycle with try-with-resources * `ByteBufBsonDocument`: Added resource tracking, OP_MSG parsing, caching strategy * `ByteBufBsonArray`: Added resource tracking and cleanup CommandMessage Changes: * `getCommandDocument()` returns `ByteBufBsonDocument` (was `BsonDocument`) * Delegates document composition to `ByteBufBsonDocument` * Simplified `OP_MSG` document sequence parsing JAVA-6010 * Nit fixes and usability improvements If the document is hydrated allow continued use after resource closing. * PR updates * PR updates - ensure iterators track open resources and normalize the tests --- .../internal/connection/ByteBufBsonArray.java | 115 +- .../connection/ByteBufBsonDocument.java | 1104 +++++++++++++---- .../connection/ByteBufBsonHelper.java | 13 +- .../internal/connection/CommandMessage.java | 76 +- .../connection/InternalStreamConnection.java | 189 +-- .../micrometer/TracingManager.java | 18 +- .../connection/ByteBufBsonArrayTest.java | 277 +++-- .../ByteBufBsonDocumentSpecification.groovy | 313 ----- .../connection/ByteBufBsonDocumentTest.java | 795 ++++++++++++ .../CommandMessageSpecification.groovy | 365 ------ .../connection/CommandMessageTest.java | 472 ++++++- ...gingCommandEventSenderSpecification.groovy | 24 +- 12 files changed, 2564 insertions(+), 1197 deletions(-) delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy create mode 100644 driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy diff --git 
a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java index e02cee12629..00d28d1d5d1 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java @@ -16,6 +16,9 @@ package com.mongodb.internal.connection; +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.internal.diagnostics.logging.Logger; +import com.mongodb.internal.diagnostics.logging.Loggers; import org.bson.BsonArray; import org.bson.BsonBinaryReader; import org.bson.BsonType; @@ -23,6 +26,7 @@ import org.bson.ByteBuf; import org.bson.io.ByteBufferBsonInput; +import java.io.Closeable; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; @@ -33,20 +37,33 @@ import static com.mongodb.internal.connection.ByteBufBsonHelper.readBsonValue; -final class ByteBufBsonArray extends BsonArray { +@NotThreadSafe +final class ByteBufBsonArray extends BsonArray implements Closeable { + private static final Logger LOGGER = Loggers.getLogger("connection"); private final ByteBuf byteBuf; + /** + * List of resources that need to be closed when this array is closed. + * Tracks the main ByteBuf and iterator duplicates. Iterator buffers are automatically + * removed and released when iteration completes normally to prevent memory accumulation. 
+ */ + private final List trackedResources = new ArrayList<>(); + private boolean closed; + ByteBufBsonArray(final ByteBuf byteBuf) { this.byteBuf = byteBuf; + trackedResources.add(byteBuf::release); } @Override public Iterator iterator() { + ensureOpen(); return new ByteBufBsonArrayIterator(); } @Override public List getValues() { + ensureOpen(); List values = new ArrayList<>(); for (BsonValue cur: this) { //noinspection UseBulkOperation @@ -59,6 +76,7 @@ public List getValues() { @Override public int size() { + ensureOpen(); int size = 0; for (BsonValue ignored : this) { size++; @@ -68,11 +86,13 @@ public int size() { @Override public boolean isEmpty() { + ensureOpen(); return !iterator().hasNext(); } @Override public boolean equals(final Object o) { + ensureOpen(); if (o == this) { return true; } @@ -91,6 +111,7 @@ public boolean equals(final Object o) { @Override public int hashCode() { + ensureOpen(); int hashCode = 1; for (BsonValue cur : this) { hashCode = 31 * hashCode + (cur == null ? 0 : cur.hashCode()); @@ -100,6 +121,7 @@ public int hashCode() { @Override public boolean contains(final Object o) { + ensureOpen(); for (BsonValue cur : this) { if (Objects.equals(o, cur)) { return true; @@ -111,6 +133,7 @@ public boolean contains(final Object o) { @Override public Object[] toArray() { + ensureOpen(); Object[] retVal = new Object[size()]; Iterator it = iterator(); for (int i = 0; i < retVal.length; i++) { @@ -122,6 +145,7 @@ public Object[] toArray() { @Override @SuppressWarnings("unchecked") public T[] toArray(final T[] a) { + ensureOpen(); int size = size(); T[] retVal = a.length >= size ? 
a : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); Iterator it = iterator(); @@ -133,6 +157,7 @@ public T[] toArray(final T[] a) { @Override public boolean containsAll(final Collection c) { + ensureOpen(); for (Object e : c) { if (!contains(e)) { return false; @@ -143,6 +168,7 @@ public boolean containsAll(final Collection c) { @Override public BsonValue get(final int index) { + ensureOpen(); if (index < 0) { throw new IndexOutOfBoundsException("Index out of range: " + index); } @@ -159,6 +185,7 @@ public BsonValue get(final int index) { @Override public int indexOf(final Object o) { + ensureOpen(); int i = 0; for (BsonValue cur : this) { if (Objects.equals(o, cur)) { @@ -172,6 +199,7 @@ public int indexOf(final Object o) { @Override public int lastIndexOf(final Object o) { + ensureOpen(); ListIterator listIterator = listIterator(size()); while (listIterator.hasPrevious()) { if (Objects.equals(o, listIterator.previous())) { @@ -183,17 +211,20 @@ public int lastIndexOf(final Object o) { @Override public ListIterator listIterator() { + ensureOpen(); return listIterator(0); } @Override public ListIterator listIterator(final int index) { + ensureOpen(); // Not the most efficient way to do this, but unlikely anyone will notice in practice return new ArrayList<>(this).listIterator(index); } @Override public List subList(final int fromIndex, final int toIndex) { + ensureOpen(); if (fromIndex < 0) { throw new IndexOutOfBoundsException("fromIndex = " + fromIndex); } @@ -234,6 +265,7 @@ public boolean addAll(final Collection c) { @Override public boolean addAll(final int index, final Collection c) { + ensureOpen(); throw new UnsupportedOperationException(READ_ONLY_MESSAGE); } @@ -267,23 +299,73 @@ public BsonValue remove(final int index) { throw new UnsupportedOperationException(READ_ONLY_MESSAGE); } + @Override + public void close(){ + if (!closed) { + for (Closeable closeable : trackedResources) { + try { + closeable.close(); + } catch 
(Exception e) { + // Log and continue closing other resources + LOGGER.error("Error closing resource", e); + } + } + trackedResources.clear(); + closed = true; + } + } + + private void ensureOpen() { + if (closed) { + throw new IllegalStateException("The BsonArray resources have been released."); + } + } + private class ByteBufBsonArrayIterator implements Iterator { - private final ByteBuf duplicatedByteBuf = byteBuf.duplicate(); - private final BsonBinaryReader bsonReader; + private ByteBuf duplicatedByteBuf; + private BsonBinaryReader reader; + private Closeable resourceHandle; + private boolean finished; { - bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); + ensureOpen(); + // Create duplicate buffer for iteration and track it temporarily + duplicatedByteBuf = byteBuf.duplicate(); + resourceHandle = () -> { + if (duplicatedByteBuf != null) { + try { + if (reader != null) { + reader.close(); + } + } catch (Exception e) { + // Ignore + } + duplicatedByteBuf.release(); + duplicatedByteBuf = null; + reader = null; + } + }; + trackedResources.add(resourceHandle); + reader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); // While one might expect that this would be a call to BsonReader#readStartArray that doesn't work because BsonBinaryReader // expects to be positioned at the start at the beginning of a document, not an array. Fortunately, a BSON array has exactly // the same structure as a BSON document (the keys are just the array indices converted to a strings). So it works fine to // call BsonReader#readStartDocument here, and just skip all the names via BsonReader#skipName below. 
- bsonReader.readStartDocument(); - bsonReader.readBsonType(); + reader.readStartDocument(); + reader.readBsonType(); } @Override public boolean hasNext() { - return bsonReader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + if (finished) { + return false; + } + ensureOpen(); + boolean hasNext = reader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + if (!hasNext) { + cleanup(); + } + return hasNext; } @Override @@ -291,10 +373,23 @@ public BsonValue next() { if (!hasNext()) { throw new NoSuchElementException(); } - bsonReader.skipName(); - BsonValue value = readBsonValue(duplicatedByteBuf, bsonReader); - bsonReader.readBsonType(); + reader.skipName(); + BsonValue value = readBsonValue(duplicatedByteBuf, reader, trackedResources); + reader.readBsonType(); return value; } + + private void cleanup() { + if (!finished) { + finished = true; + // Remove from tracked resources since we're cleaning up immediately + trackedResources.remove(resourceHandle); + try { + resourceHandle.close(); + } catch (Exception e) { + // Ignore + } + } + } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java index 70ed10a75a8..4d4ebcc1169 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java @@ -16,139 +16,535 @@ package com.mongodb.internal.connection; +import com.mongodb.MongoInternalException; +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.diagnostics.logging.Logger; +import com.mongodb.internal.diagnostics.logging.Loggers; import com.mongodb.lang.Nullable; +import org.bson.BsonArray; import org.bson.BsonBinaryReader; import org.bson.BsonDocument; +import org.bson.BsonReader; import org.bson.BsonType; import org.bson.BsonValue; import org.bson.ByteBuf; -import 
org.bson.RawBsonDocument; import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.DecoderContext; import org.bson.io.ByteBufferBsonInput; import org.bson.json.JsonMode; -import org.bson.json.JsonWriter; import org.bson.json.JsonWriterSettings; +import java.io.ByteArrayOutputStream; +import java.io.Closeable; import java.io.InvalidObjectException; import java.io.ObjectInputStream; -import java.io.StringWriter; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.util.AbstractCollection; import java.util.AbstractMap; import java.util.AbstractSet; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; +import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PACKAGE; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; import static com.mongodb.internal.connection.ByteBufBsonHelper.readBsonValue; +import static java.util.Collections.emptyMap; -final class ByteBufBsonDocument extends BsonDocument { +/** + * A memory-efficient, read-only {@link BsonDocument} implementation backed by a {@link ByteBuf}. + * + *

Overview

+ *

This class provides lazy access to BSON document fields without fully deserializing the document + * into memory. It reads field values directly from the underlying byte buffer on demand, which is + * particularly useful for large documents where only a few fields need to be accessed.

+ * + *

Data Sources

+ *

A {@code ByteBufBsonDocument} can contain data from two sources:

+ *
    + *
  • Body fields: Standard BSON document fields stored in {@link #bodyByteBuf}. These are + * read lazily using a {@link BsonBinaryReader}.
  • + *
  • Sequence fields: MongoDB OP_MSG Type 1 payload sequences stored in {@link #sequenceFields}. + * These are used when parsing command messages that contain document sequences (e.g., bulk inserts). + * Each sequence field appears as an array of documents when accessed.
  • + *
+ * + *

OP_MSG Command Message Support

+ *

The {@link #createCommandMessage(CompositeByteBuf)} factory method parses MongoDB OP_MSG format, + * which consists of:

+ *
    + *
  1. A body section (Type 0): The main command document
  2. + *
  3. Zero or more document sequence sections (Type 1): Arrays of documents identified by field name
  4. + *
+ *

For example, an insert command might have the body containing {@code {insert: "collection", $db: "test"}} + * and a sequence section with field name "documents" containing the documents to insert.

+ * + *

Resource Management

+ *

This class implements {@link Closeable} and manages several types of resources:

+ *
    + *
  • ByteBuf instances: The body buffer and any duplicated buffers created during iteration + * or value access are tracked in {@link #trackedResources} and released on {@link #close()}.
  • + *
  • Nested ByteBufBsonDocument/ByteBufBsonArray: When accessing nested documents or arrays, + * new {@code ByteBufBsonDocument} or {@link ByteBufBsonArray} instances are created. These are + * registered as closeables and closed recursively when the parent is closed.
  • + *
  • Sequence field documents: Documents within sequence fields are also {@code ByteBufBsonDocument} + * instances that are tracked and closed with the parent.
  • + *
+ * + *

Important: Always close this document when done to prevent memory leaks. After closing, + * any operation will throw {@link IllegalStateException}.

+ * + *

Caching Strategy

+ *

The class uses lazy caching to optimize repeated access:

+ *
    + *
  • {@link #cachedDocument}: Once {@link #toBsonDocument()} is called, the fully hydrated document + * is cached and all subsequent operations use this cache. At this point, the underlying buffers + * are released since they're no longer needed.
  • + *
  • {@link #cachedFirstKey}: The first key is cached after the first call to {@link #getFirstKey()}.
  • + *
  • Sequence field arrays are cached within {@link SequenceField} after first access.
  • + *
+ * + *

Immutability

+ *

This class is read-only. All mutation methods ({@link #put}, {@link #remove}, {@link #clear}, etc.) + * throw {@link UnsupportedOperationException}.

+ * + *

Thread Safety

+ *

This class is not thread-safe. Concurrent access from multiple threads requires external synchronization.

+ * + *

Serialization

+ *

Java serialization is supported via {@link #writeReplace()}, which converts this document to a + * regular {@link BsonDocument} before serialization.

+ * + * @see ByteBufBsonArray + * @see ByteBufBsonHelper + */ +@NotThreadSafe +public final class ByteBufBsonDocument extends BsonDocument implements Closeable { + private static final Logger LOGGER = Loggers.getLogger("connection"); private static final long serialVersionUID = 2L; - private final transient ByteBuf byteBuf; + /** + * The underlying byte buffer containing the BSON document body. + * This is the main document data, excluding any OP_MSG sequence sections. + * Set to null after {@link #releaseResources()} is called. + */ + private transient ByteBuf bodyByteBuf; + + /** + * Map of sequence field names to their corresponding {@link SequenceField} instances. + * These represent OP_MSG Type 1 payload sections. Each sequence field appears as an + * array when accessed via {@link #get(Object)}. + * Empty for simple documents not created from OP_MSG. + */ + private transient Map sequenceFields; + + /** + * List of resources that need to be closed/released when this document is closed. + * + *

Memory Management Strategy:

+ *
    + *
  • Always tracked: The main bodyByteBuf and any nested ByteBufBsonDocument/ByteBufBsonArray + * instances returned to callers are permanently tracked until this document is closed or + * {@link #toBsonDocument()} caches and releases them.
  • + *
  • Temporarily tracked: Iterator duplicate buffers are tracked during iteration + * but automatically removed and released when iteration completes. This prevents memory accumulation + * from completed iterations while ensuring cleanup if the parent document is closed mid-iteration.
  • + *
  • Not tracked: Short-lived duplicate buffers used in query methods + * (e.g., {@link #findKeyInBody}, {@link #containsKey}) are released immediately in finally blocks + * and never added to this list. Temporary nested documents created during value comparison + * use separate tracking lists.
  • + *
+ */ + private final transient List trackedResources; /** - * Create a list of ByteBufBsonDocument from a buffer positioned at the start of the first document of an OP_MSG Section - * of type Document Sequence (Kind 1). - *

- * The provided buffer will be positioned at the end of the section upon normal completion of the method + * Cached fully-hydrated BsonDocument. Once populated via {@link #toBsonDocument()}, + * all subsequent read operations use this cache instead of reading from the byte buffer. */ - static List createList(final ByteBuf outputByteBuf) { - List documents = new ArrayList<>(); - while (outputByteBuf.hasRemaining()) { - ByteBufBsonDocument curDocument = createOne(outputByteBuf); - documents.add(curDocument); + private transient BsonDocument cachedDocument; + + /** + * Cached first key of the document. Populated on first call to {@link #getFirstKey()}. + */ + private transient String cachedFirstKey; + + /** + * Flag indicating whether this document has been closed. + * Once closed, all operations throw {@link IllegalStateException}. + */ + private transient boolean closed; + + + /** + * Creates a {@code ByteBufBsonDocument} from an OP_MSG command message. + * + *

This factory method parses the MongoDB OP_MSG wire protocol format, which consists of:

+ *
    + *
  1. Body section (Type 0): A single BSON document containing the command
  2. + *
  3. Document sequence sections (Type 1): Zero or more sections, each containing + * a field identifier and a sequence of BSON documents
  4. + *
+ * + *

The sequence sections are stored in {@link #sequenceFields} and appear as array fields + * when the document is accessed. For example, an insert command's "documents" sequence + * will appear as an array when calling {@code get("documents")}.

+ * + *

Wire Format Parsed

+ *
+     * [body document bytes]
+     * [section type: 1 byte] [section size: 4 bytes] [identifier: cstring] [document bytes...]
+     * ... (more sections)
+     * 
+ * + * @param commandMessageByteBuf The composite buffer positioned at the start of the body document. + * Position will be advanced past all parsed sections. + * @return A new {@code ByteBufBsonDocument} representing the command with any sequence fields. + */ + @VisibleForTesting(otherwise = PRIVATE) + public static ByteBufBsonDocument createCommandMessage(final CompositeByteBuf commandMessageByteBuf) { + // Parse body document: read size, create a view of just the body bytes + int bodyStart = commandMessageByteBuf.position(); + int bodySizeInBytes = commandMessageByteBuf.getInt(); + int bodyEnd = bodyStart + bodySizeInBytes; + ByteBuf bodyByteBuf = commandMessageByteBuf.duplicate().position(bodyStart).limit(bodyEnd); + + List trackedResources = new ArrayList<>(); + commandMessageByteBuf.position(bodyEnd); + + // Parse any Type 1 (document sequence) sections that follow the body + Map sequences = new LinkedHashMap<>(); + while (commandMessageByteBuf.hasRemaining()) { + // Skip section type byte (we only support Type 1 here) + commandMessageByteBuf.position(commandMessageByteBuf.position() + 1); + + // Read section size and calculate bounds + int sequenceStart = commandMessageByteBuf.position(); + int sequenceSizeInBytes = commandMessageByteBuf.getInt(); + int sectionEnd = sequenceStart + sequenceSizeInBytes; + + // Read the field identifier (null-terminated string) + String fieldName = readCString(commandMessageByteBuf); + assertFalse(fieldName.contains(".")); + + // Create a view of just the document sequence bytes (after the identifier) + ByteBuf sequenceByteBuf = commandMessageByteBuf.duplicate(); + sequenceByteBuf.position(commandMessageByteBuf.position()).limit(sectionEnd); + sequences.put(fieldName, new SequenceField(sequenceByteBuf, trackedResources)); + commandMessageByteBuf.position(sectionEnd); } - return documents; + return new ByteBufBsonDocument(bodyByteBuf, trackedResources, sequences); } /** - * Create a ByteBufBsonDocument from a buffer positioned 
at the start of a BSON document. - * The provided buffer will be positioned at the end of the document upon normal completion of the method + * Creates a simple {@code ByteBufBsonDocument} from a byte buffer containing a single BSON document. + * + *

Use this constructor for standard BSON documents. For OP_MSG command messages with + * document sequences, use {@link #createCommandMessage(CompositeByteBuf)} instead.

+ * + * @param byteBuf The buffer containing the BSON document. The buffer should be positioned + * at the start of the document and contain the complete document bytes. + */ + @VisibleForTesting(otherwise = PACKAGE) + public ByteBufBsonDocument(final ByteBuf byteBuf) { + this(byteBuf, new ArrayList<>(), new HashMap<>()); + } + + /** + * Private constructor used by factory methods. + * + * @param bodyByteBuf The buffer containing the body document bytes + * @param trackedResources Mutable list for tracking resources to close + * @param sequenceFields Map of sequence field names to their data (empty for simple documents) */ - static ByteBufBsonDocument createOne(final ByteBuf outputByteBuf) { - int documentStart = outputByteBuf.position(); - int documentSizeInBytes = outputByteBuf.getInt(); - int documentEnd = documentStart + documentSizeInBytes; - ByteBuf slice = outputByteBuf.duplicate().position(documentStart).limit(documentEnd); - outputByteBuf.position(documentEnd); - return new ByteBufBsonDocument(slice); + private ByteBufBsonDocument(final ByteBuf bodyByteBuf, final List trackedResources, + final Map sequenceFields) { + this.bodyByteBuf = bodyByteBuf; + this.trackedResources = trackedResources; + this.sequenceFields = sequenceFields; + trackedResources.add(bodyByteBuf::release); } + // ==================== Size and Empty Checks ==================== + @Override - public String toJson() { - return toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()); + public int size() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.size(); + } + // Total size = body fields + sequence fields + return countBodyFields() + sequenceFields.size(); } @Override - public String toJson(final JsonWriterSettings settings) { - StringWriter stringWriter = new StringWriter(); - JsonWriter jsonWriter = new JsonWriter(stringWriter, settings); - ByteBuf duplicate = byteBuf.duplicate(); - try (BsonBinaryReader reader = new BsonBinaryReader(new 
ByteBufferBsonInput(duplicate))) { - jsonWriter.pipe(reader); - return stringWriter.toString(); - } finally { - duplicate.release(); + public boolean isEmpty() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.isEmpty(); } + return !hasBodyFields() && sequenceFields.isEmpty(); } + // ==================== Key/Value Lookups ==================== + @Override - public BsonBinaryReader asBsonReader() { - return new BsonBinaryReader(new ByteBufferBsonInput(byteBuf.duplicate())); + public boolean containsKey(final Object key) { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.containsKey(key); + } + if (key == null) { + throw new IllegalArgumentException("key can not be null"); + } + // Check sequence fields first (fast HashMap lookup), then scan body + if (sequenceFields.containsKey(key)) { + return true; + } + return findKeyInBody((String) key); } - @SuppressWarnings("MethodDoesntCallSuperMethod") @Override - public BsonDocument clone() { - byte[] clonedBytes = new byte[byteBuf.remaining()]; - byteBuf.get(byteBuf.position(), clonedBytes); - return new RawBsonDocument(clonedBytes); + public boolean containsValue(final Object value) { + ensureOpen(); + if (!(value instanceof BsonValue)) { + return false; + } + + if (cachedDocument != null) { + return cachedDocument.containsValue(value); + } + + // Search body fields first, then sequence fields + if (findValueInBody((BsonValue) value)) { + return true; + } + for (SequenceField field : sequenceFields.values()) { + if (field.containsValue(value)) { + return true; + } + } + return false; } + /** + * {@inheritDoc} + * + *

For sequence fields (OP_MSG document sequences), returns a {@link BsonArray} containing + * {@code ByteBufBsonDocument} instances for each document in the sequence.

+ */ @Nullable - T findInDocument(final Finder finder) { - ByteBuf duplicateByteBuf = byteBuf.duplicate(); - try (BsonBinaryReader bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicateByteBuf))) { - bsonReader.readStartDocument(); - while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) { - T found = finder.find(duplicateByteBuf, bsonReader); - if (found != null) { - return found; - } + @Override + public BsonValue get(final Object key) { + ensureOpen(); + notNull("key", key); + + if (!(key instanceof String)) { + return null; + } + if (cachedDocument != null) { + return cachedDocument.get(key); + } + + // Check sequence fields first, then body + if (sequenceFields.containsKey(key)) { + return sequenceFields.get(key).asArray(); + } + return getValueFromBody((String) key); + } + + @Override + public String getFirstKey() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.getFirstKey(); + } + if (cachedFirstKey != null) { + return cachedFirstKey; + } + cachedFirstKey = getFirstKeyFromBody(); + return assertNotNull(cachedFirstKey); + } + + // ==================== Collection Views ==================== + // These return lazy views that iterate over both body and sequence fields + + @Override + public Set> entrySet() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.entrySet(); + } + return new AbstractSet>() { + @Override + public Iterator> iterator() { + // Combine body entries with sequence entries + return new CombinedIterator<>(createBodyIterator(IteratorMode.ENTRIES), createSequenceEntryIterator()); } - bsonReader.readEndDocument(); - } finally { - duplicateByteBuf.release(); + + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; + } + + @Override + public Collection values() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.values(); } + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new 
CombinedIterator<>(createBodyIterator(IteratorMode.VALUES), createSequenceValueIterator()); + } - return finder.notFound(); + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; } - BsonDocument toBaseBsonDocument() { - ByteBuf duplicateByteBuf = byteBuf.duplicate(); - try (BsonBinaryReader bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicateByteBuf))) { - return new BsonDocumentCodec().decode(bsonReader, DecoderContext.builder().build()); - } finally { - duplicateByteBuf.release(); + @Override + public Set keySet() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.keySet(); + } + return new AbstractSet() { + @Override + public Iterator iterator() { + return new CombinedIterator<>(createBodyIterator(IteratorMode.KEYS), sequenceFields.keySet().iterator()); + } + + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; + } + + // ==================== Conversion Methods ==================== + + @Override + public BsonReader asBsonReader() { + ensureOpen(); + // Must hydrate first since we need to include sequence fields + return toBsonDocument().asBsonReader(); + } + + /** + * Converts this document to a regular {@link BsonDocument}, fully deserializing all data. + * + *

After this method is called:

+ *
    + *
  • The result is cached for future calls
  • + *
  • All underlying byte buffers are released
  • + *
  • Sequence field documents are hydrated to regular {@code BsonDocument} instances
  • + *
  • All subsequent read operations use the cached document
  • + *
+ * + * @return A fully materialized {@link BsonDocument} containing all fields + */ + @Override + public BsonDocument toBsonDocument() { + ensureOpen(); + if (cachedDocument == null) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + // Decode body document + BsonDocument doc = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); + // Add hydrated sequence fields + for (Map.Entry entry : sequenceFields.entrySet()) { + doc.put(entry.getKey(), entry.getValue().toHydratedArray()); + } + cachedDocument = doc; + closed = true; + // Release buffers since we no longer need them + releaseResources(); + } finally { + dup.release(); + } } + return cachedDocument; } - ByteBufBsonDocument(final ByteBuf byteBuf) { - this.byteBuf = byteBuf; + @Override + public String toJson() { + return toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()); } @Override - public void clear() { - throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); + public String toJson(final JsonWriterSettings settings) { + ensureOpen(); + return toBsonDocument().toJson(settings); } + @Override + public String toString() { + ensureOpen(); + return toBsonDocument().toString(); + } + + @SuppressWarnings("MethodDoesntCallSuperMethod") + @Override + public BsonDocument clone() { + ensureOpen(); + return toBsonDocument().clone(); + } + + @SuppressWarnings("EqualsDoesntCheckParameterClass") + @Override + public boolean equals(final Object o) { + ensureOpen(); + return toBsonDocument().equals(o); + } + + @Override + public int hashCode() { + ensureOpen(); + return toBsonDocument().hashCode(); + } + + // ==================== Resource Management ==================== + + /** + * Releases all resources held by this document. + * + *

This includes:

+ *
    + *
  • Releasing all tracked {@link ByteBuf} instances
  • + *
  • Closing all nested {@code ByteBufBsonDocument} and {@link ByteBufBsonArray} instances
  • + *
  • Clearing internal references
  • + *
+ * + *

After calling this method, any operation on this document will throw + * {@link IllegalStateException}. This method is idempotent.

+ */ + @Override + public void close() { + if (!closed) { + closed = true; + releaseResources(); + } + } + + // ==================== Mutation Methods (Unsupported) ==================== + @Override public BsonValue put(final String key, final BsonValue value) { throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); @@ -170,260 +566,480 @@ public BsonValue remove(final Object key) { } @Override - public boolean isEmpty() { - return assertNotNull(findInDocument(new Finder() { - @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - return false; - } - - @Override - public Boolean notFound() { - return true; - } - })); + public void clear() { + throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); } - @Override - public int size() { - return assertNotNull(findInDocument(new Finder() { - private int size; + // ==================== Private Body Field Operations ==================== + // These methods read from bodyByteBuf using a temporary duplicate buffer - @Override - @Nullable - public Integer find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - size++; - bsonReader.readName(); - bsonReader.skipValue(); - return null; + /** + * Searches the body for a field with the given key. + * Uses a duplicated buffer to avoid modifying the original position. + */ + private boolean findKeyInBody(final String key) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + if (reader.readName().equals(key)) { + return true; + } + reader.skipValue(); } + return false; + } finally { + dup.release(); + } + } - @Override - public Integer notFound() { - return size; + /** + * Searches the body for a field with the given value. 
+ * Creates ByteBufBsonDocument/ByteBufBsonArray for nested structures during comparison or vanilla BsonValues. + * Uses temporary tracking list to avoid polluting the main trackedResources with short-lived objects. + */ + private boolean findValueInBody(final BsonValue targetValue) { + ByteBuf dup = bodyByteBuf.duplicate(); + List tempTrackedResources = new ArrayList<>(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + reader.skipName(); + if (readBsonValue(dup, reader, tempTrackedResources).equals(targetValue)) { + return true; + } } - })); + return false; + } finally { + // Release temporary resources created during comparison + for (Closeable resource : tempTrackedResources) { + try { + resource.close(); + } catch (Exception e) { + // Continue closing other resources + } + } + dup.release(); + } } - @Override - public Set> entrySet() { - return new ByteBufBsonDocumentEntrySet(); + /** + * Retrieves a value from the body by key. + * Returns null if the key is not found in the body. + */ + @Nullable + private BsonValue getValueFromBody(final String key) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + if (reader.readName().equals(key)) { + return readBsonValue(dup, reader, trackedResources); + } + reader.skipValue(); + } + return null; + } finally { + dup.release(); + } } - @Override - public Collection values() { - return new ByteBufBsonDocumentValuesCollection(); + /** + * Gets the first key from the body, or from sequence fields if body is empty. + * Throws NoSuchElementException if the document is completely empty. 
+ */ + private String getFirstKeyFromBody() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + if (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + return reader.readName(); + } + // Body is empty, try sequence fields + if (!sequenceFields.isEmpty()) { + return sequenceFields.keySet().iterator().next(); + } + throw new NoSuchElementException(); + } finally { + dup.release(); + } } - @Override - public Set keySet() { - return new ByteBufBsonDocumentKeySet(); + /** + * Checks if the body contains at least one field. + */ + private boolean hasBodyFields() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + return reader.readBsonType() != BsonType.END_OF_DOCUMENT; + } finally { + dup.release(); + } } - @Override - public boolean containsKey(final Object key) { - if (key == null) { - throw new IllegalArgumentException("key can not be null"); + /** + * Counts the number of fields in the body document. + */ + private int countBodyFields() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + int count = 0; + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + count++; + reader.skipName(); + reader.skipValue(); + } + return count; + } finally { + dup.release(); } + } + + // ==================== Iterator Support ==================== + + /** + * Mode for the body iterator, determining what type of elements it produces. + */ + private enum IteratorMode { ENTRIES, KEYS, VALUES } + + /** + * Creates an iterator over the body document fields. + * + *

The iterator creates a duplicated ByteBuf that is temporarily tracked for safety. + * When iteration completes normally, the buffer is released immediately and removed from tracking. + * This prevents accumulation of finished iterator buffers while ensuring cleanup if the parent + * document is closed before iteration completes.

+ * + * @param mode Determines whether to return entries, keys, or values + * @return An iterator of the appropriate type + */ + @SuppressWarnings("unchecked") + private Iterator createBodyIterator(final IteratorMode mode) { + return new Iterator() { + private final Closeable resourceHandle; + private ByteBuf duplicatedByteBuf; + private BsonBinaryReader reader; + private boolean started; + private boolean finished; + + { + // Create duplicate buffer for iteration and track it temporarily + duplicatedByteBuf = bodyByteBuf.duplicate(); + resourceHandle = () -> { + if (duplicatedByteBuf != null) { + try { + if (reader != null) { + reader.close(); + } + } catch (Exception e) { + // Ignore + } + duplicatedByteBuf.release(); + duplicatedByteBuf = null; + reader = null; + } + }; + trackedResources.add(resourceHandle); + reader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); + } - Boolean containsKey = findInDocument(new Finder() { @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - if (bsonReader.readName().equals(key)) { - return true; + public boolean hasNext() { + if (finished) { + return false; } - bsonReader.skipValue(); - return null; + ensureOpen(); + if (!started) { + reader.readStartDocument(); + reader.readBsonType(); + started = true; + } + boolean hasNext = reader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + if (!hasNext) { + cleanup(); + } + return hasNext; } @Override - public Boolean notFound() { - return false; + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ensureOpen(); + String key = reader.readName(); + BsonValue value = readBsonValue(duplicatedByteBuf, reader, trackedResources); + reader.readBsonType(); + + switch (mode) { + case ENTRIES: + return (T) new AbstractMap.SimpleImmutableEntry<>(key, value); + case KEYS: + return (T) key; + case VALUES: + return (T) value; + default: + throw new IllegalStateException("Unknown iterator mode: " + mode); + 
} } - }); - return containsKey != null ? containsKey : false; + + private void cleanup() { + if (!finished) { + finished = true; + // Remove from tracked resources since we're cleaning up immediately + trackedResources.remove(resourceHandle); + try { + resourceHandle.close(); + } catch (Exception e) { + // Ignore + } + } + } + }; } - @Override - public boolean containsValue(final Object value) { - Boolean containsValue = findInDocument(new Finder() { + /** + * Creates an iterator over sequence fields as map entries. + * Each entry contains the field name and its array value. + */ + private Iterator> createSequenceEntryIterator() { + Iterator> iter = sequenceFields.entrySet().iterator(); + return new Iterator>() { @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - bsonReader.skipName(); - if (readBsonValue(byteBuf, bsonReader).equals(value)) { - return true; - } - return null; + public boolean hasNext() { + return iter.hasNext(); } @Override - public Boolean notFound() { - return false; + public Entry next() { + Entry entry = iter.next(); + return new AbstractMap.SimpleImmutableEntry<>(entry.getKey(), entry.getValue().asArray()); } - }); - return containsValue != null ? containsValue : false; + }; } - @Nullable - @Override - public BsonValue get(final Object key) { - notNull("key", key); - return findInDocument(new Finder() { + /** + * Creates an iterator over sequence field values (arrays). 
+ */ + private Iterator createSequenceValueIterator() { + Iterator iter = sequenceFields.values().iterator(); + return new Iterator() { @Override - public BsonValue find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - if (bsonReader.readName().equals(key)) { - return readBsonValue(byteBuf, bsonReader); - } - bsonReader.skipValue(); - return null; + public boolean hasNext() { + return iter.hasNext(); } - @Nullable @Override - public BsonValue notFound() { - return null; + public BsonValue next() { + return iter.next().asArray(); } - }); + }; } + // ==================== Resource Management Helpers ==================== + /** - * Gets the first key in this document. + * Releases all tracked resources and clears internal state. * - * @return the first key in this document - * @throws java.util.NoSuchElementException if the document is empty + *

Called by {@link #close()} and after {@link #toBsonDocument()} caches the result. + * Resources include ByteBuf instances and nested ByteBufBsonDocument/ByteBufBsonArray.

*/ - public String getFirstKey() { - return assertNotNull(findInDocument(new Finder() { - @Override - public String find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - return bsonReader.readName(); + private void releaseResources() { + for (Closeable resource : trackedResources) { + try { + resource.close(); + } catch (Exception e) { + // Log and continue closing other resources + LOGGER.error("Error closing resource", e); } + } - @Override - public String notFound() { - throw new NoSuchElementException(); - } - })); + assertTrue(bodyByteBuf == null || bodyByteBuf.getReferenceCount() == 0, "Failed to release all `bodyByteBuf` resources"); + assertTrue(sequenceFields.values().stream().allMatch(b -> b.sequenceByteBuf.getReferenceCount() == 0), + "Failed to release all `sequenceField` resources"); + + trackedResources.clear(); + sequenceFields = emptyMap(); + bodyByteBuf = null; + cachedFirstKey = null; } - private interface Finder { - @Nullable - T find(ByteBuf byteBuf, BsonBinaryReader bsonReader); - @Nullable - T notFound(); + /** + * Throws IllegalStateException if this document has been closed and there is no cached document. + */ + private void ensureOpen() { + if (closed && cachedDocument == null) { + throw new IllegalStateException("The underlying BsonDocument resources have been released and the data is no longer " + + "accessible. Use `ByteBufBsonDocument.toBsonDocument()` to create a fully hydrated BsonDocument."); + } + } + + // ==================== Utility Methods ==================== + + /** + * Reads a null-terminated C-string from the buffer. + * Used for parsing OP_MSG sequence identifiers. 
+ */ + private static String readCString(final ByteBuf byteBuf) { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + byte b = byteBuf.get(); + while (b != 0) { + bytes.write(b); + b = byteBuf.get(); + } + try { + return bytes.toString(StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new MongoInternalException("Unexpected exception", e); + } } - // see https://docs.oracle.com/javase/6/docs/platform/serialization/spec/output.html + /** + * Serialization support: converts to a regular BsonDocument before serialization. + */ private Object writeReplace() { - return toBaseBsonDocument(); + ensureOpen(); + return toBsonDocument(); } - // see https://docs.oracle.com/javase/6/docs/platform/serialization/spec/input.html private void readObject(final ObjectInputStream stream) throws InvalidObjectException { throw new InvalidObjectException("Proxy required"); } - private class ByteBufBsonDocumentEntrySet extends AbstractSet> { - @Override - public Iterator> iterator() { - return new Iterator>() { - private final ByteBuf duplicatedByteBuf = byteBuf.duplicate(); - private final BsonBinaryReader bsonReader; - - { - bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); - bsonReader.readStartDocument(); - bsonReader.readBsonType(); - } - - @Override - public boolean hasNext() { - return bsonReader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; - } + // ==================== Inner Classes ==================== - @Override - public Entry next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - String key = bsonReader.readName(); - BsonValue value = readBsonValue(duplicatedByteBuf, bsonReader); - bsonReader.readBsonType(); - return new AbstractMap.SimpleEntry<>(key, value); - } + /** + * Represents an OP_MSG Type 1 document sequence section. + * + *

A sequence field contains a contiguous series of BSON documents in the buffer. + * When accessed via {@link #asArray()}, it returns a {@link BsonArray} containing + * {@link ByteBufBsonDocument} instances for each document.

+ * + *

The documents are lazily parsed on first access and cached for subsequent calls.

+ */ + private static final class SequenceField { + /** Buffer containing the sequence of BSON documents */ + private final ByteBuf sequenceByteBuf; - }; - } + /** Reference to parent's tracked resources for registering created documents */ + private final List trackedResources; - @Override - public boolean isEmpty() { - return !iterator().hasNext(); - } + /** Cached list of parsed documents, populated on first access */ + private List documents; - @Override - public int size() { - return ByteBufBsonDocument.this.size(); + SequenceField(final ByteBuf sequenceByteBuf, final List trackedResources) { + this.sequenceByteBuf = sequenceByteBuf; + this.trackedResources = trackedResources; + trackedResources.add(sequenceByteBuf::release); } - } - private class ByteBufBsonDocumentKeySet extends AbstractSet { - @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") - private final Set> entrySet = new ByteBufBsonDocumentEntrySet(); - - @Override - public Iterator iterator() { - final Iterator> entrySetIterator = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return entrySetIterator.hasNext(); - } + /** + * Returns this sequence as a BsonArray of ByteBufBsonDocument instances. + * + *

On first call, parses the buffer to create ByteBufBsonDocument for each + * document and registers them with the parent's tracked resources.

+ * + * @return A BsonArray containing the sequence documents + */ + BsonValue asArray() { + if (documents == null) { + documents = new ArrayList<>(); + ByteBuf dup = sequenceByteBuf.duplicate(); + try { + while (dup.hasRemaining()) { + // Read document size to determine bounds + int docStart = dup.position(); + int docSize = dup.getInt(); + int docEnd = docStart + docSize; - @Override - public String next() { - return entrySetIterator.next().getKey(); + // Create a view of just this document's bytes + ByteBuf docBuf = sequenceByteBuf.duplicate().position(docStart).limit(docEnd); + ByteBufBsonDocument doc = new ByteBufBsonDocument(docBuf); + // Track for cleanup when parent is closed + trackedResources.add(doc); + documents.add(doc); + dup.position(docEnd); + } + } finally { + dup.release(); } - }; + } + // Return a new array each time to prevent external modification of cached list + return new BsonArray(new ArrayList<>(documents)); } - @Override - public boolean isEmpty() { - return entrySet.isEmpty(); + /** + * Checks if this sequence contains the given value. + */ + boolean containsValue(final Object value) { + return value instanceof BsonValue && asArray().asArray().contains(value); } - @Override - public int size() { - return entrySet.size(); + /** + * Converts this sequence to a BsonArray of regular BsonDocument instances. + * + *

Used by {@link ByteBufBsonDocument#toBsonDocument()} to fully hydrate the document. + * Unlike {@link #asArray()}, this creates regular BsonDocument instances, not + * ByteBufBsonDocument wrappers.

+ * + * @return A BsonArray containing fully deserialized BsonDocument instances + */ + BsonArray toHydratedArray() { + ByteBuf dup = sequenceByteBuf.duplicate(); + try { + List hydratedDocs = new ArrayList<>(); + while (dup.hasRemaining()) { + int docStart = dup.position(); + int docSize = dup.getInt(); + int docEnd = docStart + docSize; + ByteBuf docBuf = dup.duplicate().position(docStart).limit(docEnd); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(docBuf))) { + hydratedDocs.add(new BsonDocumentCodec().decode(reader, DecoderContext.builder().build())); + } finally { + docBuf.release(); + } + dup.position(docEnd); + } + return new BsonArray(hydratedDocs); + } finally { + dup.release(); + } } } - private class ByteBufBsonDocumentValuesCollection extends AbstractCollection { - @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") - private final Set> entrySet = new ByteBufBsonDocumentEntrySet(); - - @Override - public Iterator iterator() { - final Iterator> entrySetIterator = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return entrySetIterator.hasNext(); - } + /** + * An iterator that combines two iterators sequentially. + * + *

Used to merge body field iteration with sequence field iteration, + * presenting a unified view of all document fields.

+ * + * @param The type of elements returned by the iterator + */ + private static final class CombinedIterator implements Iterator { + private final Iterator primary; + private final Iterator secondary; - @Override - public BsonValue next() { - return entrySetIterator.next().getValue(); - } - }; + CombinedIterator(final Iterator primary, final Iterator secondary) { + this.primary = primary; + this.secondary = secondary; } @Override - public boolean isEmpty() { - return entrySet.isEmpty(); + public boolean hasNext() { + return primary.hasNext() || secondary.hasNext(); } + @Override - public int size() { - return entrySet.size(); + public T next() { + if (primary.hasNext()) { + return primary.next(); + } + if (secondary.hasNext()) { + return secondary.next(); + } + throw new NoSuchElementException(); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java index 55054112bf2..4d4d4846afa 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java @@ -38,18 +38,25 @@ import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.DecoderContext; +import java.io.Closeable; +import java.util.List; + final class ByteBufBsonHelper { - static BsonValue readBsonValue(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { + static BsonValue readBsonValue(final ByteBuf byteBuf, final BsonBinaryReader bsonReader, final List trackedResources) { BsonValue value; switch (bsonReader.getCurrentBsonType()) { case DOCUMENT: ByteBuf documentByteBuf = byteBuf.duplicate(); - value = new ByteBufBsonDocument(documentByteBuf); + ByteBufBsonDocument document = new ByteBufBsonDocument(documentByteBuf); + trackedResources.add(document); + value = document; bsonReader.skipValue(); break; case ARRAY: ByteBuf arrayByteBuf = byteBuf.duplicate(); - value = new 
ByteBufBsonArray(arrayByteBuf); + ByteBufBsonArray array = new ByteBufBsonArray(arrayByteBuf); + trackedResources.add(array); + value = array; bsonReader.skipValue(); break; case INT32: diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java index 348349fd18c..6f300dc226b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java @@ -17,17 +17,16 @@ package com.mongodb.internal.connection; import com.mongodb.MongoClientException; -import com.mongodb.MongoInternalException; import com.mongodb.MongoNamespace; import com.mongodb.ReadPreference; import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.internal.MongoNamespaceHelper; +import com.mongodb.internal.ResourceUtil; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.session.SessionContext; import com.mongodb.lang.Nullable; -import org.bson.BsonArray; import org.bson.BsonBinaryWriter; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -38,9 +37,6 @@ import org.bson.FieldNameValidator; import org.bson.io.BsonOutput; -import java.io.ByteArrayOutputStream; -import java.io.UnsupportedEncodingException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -61,8 +57,6 @@ import static com.mongodb.internal.connection.BsonWriterHelper.encodeUsingRegistry; import static com.mongodb.internal.connection.BsonWriterHelper.writeDocumentsOfDualMessageSequences; import static com.mongodb.internal.connection.BsonWriterHelper.writePayload; -import static com.mongodb.internal.connection.ByteBufBsonDocument.createList; -import static com.mongodb.internal.connection.ByteBufBsonDocument.createOne; import static 
com.mongodb.internal.connection.ReadConcernHelper.getReadConcernDocument; import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION; @@ -143,74 +137,26 @@ public final class CommandMessage extends RequestMessage { } /** - * Create a BsonDocument representing the logical document encoded by an OP_MSG. + * Create a ByteBufBsonDocument representing the logical document encoded by an OP_MSG. *

* The returned document will contain all the fields from the `PAYLOAD_TYPE_0_DOCUMENT` section, as well as all fields represented by * `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` sections. + * + *

Note: This document MUST be closed after use, otherwise when using Netty it could report the leaking of resources when the + * underlying {@code byteBuf's} are garbage collected */ - BsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) { + ByteBufBsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) { List byteBuffers = bsonOutput.getByteBuffers(); try { - CompositeByteBuf byteBuf = new CompositeByteBuf(byteBuffers); + CompositeByteBuf compositeByteBuf = new CompositeByteBuf(byteBuffers); try { - byteBuf.position(firstDocumentPosition); - ByteBufBsonDocument byteBufBsonDocument = createOne(byteBuf); - - // If true, it means there is at least one `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` section in the OP_MSG - if (byteBuf.hasRemaining()) { - BsonDocument commandBsonDocument = byteBufBsonDocument.toBaseBsonDocument(); - - // Each loop iteration processes one Document Sequence - // When there are no more bytes remaining, there are no more Document Sequences - while (byteBuf.hasRemaining()) { - // skip reading the payload type, we know it is `PAYLOAD_TYPE_1` - byteBuf.position(byteBuf.position() + 1); - int sequenceStart = byteBuf.position(); - int sequenceSizeInBytes = byteBuf.getInt(); - int sectionEnd = sequenceStart + sequenceSizeInBytes; - - String fieldName = getSequenceIdentifier(byteBuf); - // If this assertion fires, it means that the driver has started using document sequences for nested fields. If - // so, this method will need to change in order to append the value to the correct nested document. 
- assertFalse(fieldName.contains(".")); - - ByteBuf documentsByteBufSlice = byteBuf.duplicate().limit(sectionEnd); - try { - commandBsonDocument.append(fieldName, new BsonArray(createList(documentsByteBufSlice))); - } finally { - documentsByteBufSlice.release(); - } - byteBuf.position(sectionEnd); - } - return commandBsonDocument; - } else { - return byteBufBsonDocument; - } + compositeByteBuf.position(firstDocumentPosition); + return ByteBufBsonDocument.createCommandMessage(compositeByteBuf); } finally { - byteBuf.release(); + compositeByteBuf.release(); } } finally { - byteBuffers.forEach(ByteBuf::release); - } - } - - /** - * Get the field name from a buffer positioned at the start of the document sequence identifier of an OP_MSG Section of type - * `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE`. - *

- * Upon normal completion of the method, the buffer will be positioned at the start of the first BSON object in the sequence. - */ - private String getSequenceIdentifier(final ByteBuf byteBuf) { - ByteArrayOutputStream sequenceIdentifierBytes = new ByteArrayOutputStream(); - byte curByte = byteBuf.get(); - while (curByte != 0) { - sequenceIdentifierBytes.write(curByte); - curByte = byteBuf.get(); - } - try { - return sequenceIdentifierBytes.toString(StandardCharsets.UTF_8.name()); - } catch (UnsupportedEncodingException e) { - throw new MongoInternalException("Unexpected exception", e); + ResourceUtil.release(byteBuffers); } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 6b20c467191..aeef4e0a6a1 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -53,12 +53,10 @@ import com.mongodb.internal.session.SessionContext; import com.mongodb.internal.time.Timeout; import com.mongodb.lang.Nullable; -import org.bson.BsonBinaryReader; import org.bson.BsonDocument; import org.bson.ByteBuf; import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.Decoder; -import org.bson.io.ByteBufferBsonInput; import java.io.IOException; import java.net.SocketTimeoutException; @@ -444,46 +442,44 @@ private T sendAndReceiveInternal(final CommandMessage message, final Decoder Span tracingSpan; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, operationContext); - tracingSpan = operationContext - .getTracingManager() - .createTracingSpan(message, - operationContext, - () -> message.getCommandDocument(bsonOutput), - cmdName -> SECURITY_SENSITIVE_COMMANDS.contains(cmdName) - || SECURITY_SENSITIVE_HELLO_COMMANDS.contains(cmdName), - () -> 
getDescription().getServerAddress(), - () -> getDescription().getConnectionId() - ); - - boolean isLoggingCommandNeeded = isLoggingCommandNeeded(); - boolean isTracingCommandPayloadNeeded = tracingSpan != null && operationContext.getTracingManager().isCommandPayloadEnabled(); - - // Only hydrate the command document if necessary - BsonDocument commandDocument = null; - if (isLoggingCommandNeeded || isTracingCommandPayloadNeeded) { - commandDocument = message.getCommandDocument(bsonOutput); - } - if (isLoggingCommandNeeded) { - commandEventSender = new LoggingCommandEventSender( - SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, - operationContext, message, commandDocument, - COMMAND_PROTOCOL_LOGGER, loggerSettings); - commandEventSender.sendStartedEvent(); - } else { - commandEventSender = new NoOpCommandEventSender(); - } - if (isTracingCommandPayloadNeeded) { - tracingSpan.tagHighCardinality(QUERY_TEXT.asString(), commandDocument); - } + try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) { + tracingSpan = operationContext + .getTracingManager() + .createTracingSpan(message, + operationContext, + commandDocument, + cmdName -> SECURITY_SENSITIVE_COMMANDS.contains(cmdName) + || SECURITY_SENSITIVE_HELLO_COMMANDS.contains(cmdName), + () -> getDescription().getServerAddress(), + () -> getDescription().getConnectionId() + ); + + boolean isLoggingCommandNeeded = isLoggingCommandNeeded(); + + if (isLoggingCommandNeeded) { + commandEventSender = new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + operationContext, message, commandDocument, + COMMAND_PROTOCOL_LOGGER, loggerSettings); + commandEventSender.sendStartedEvent(); + } else { + commandEventSender = new NoOpCommandEventSender(); + } - try { - sendCommandMessage(message, bsonOutput, operationContext); - } catch (Exception e) { - if (tracingSpan != null) { - 
tracingSpan.error(e); + boolean isTracingCommandPayloadNeeded = tracingSpan != null && operationContext.getTracingManager().isCommandPayloadEnabled(); + if (isTracingCommandPayloadNeeded) { + tracingSpan.tagHighCardinality(QUERY_TEXT.asString(), commandDocument); + } + + try { + sendCommandMessage(commandDocument.getFirstKey(), message, bsonOutput, operationContext); + } catch (Exception e) { + if (tracingSpan != null) { + tracingSpan.error(e); + } + commandEventSender.sendFailedEvent(e); + throw e; } - commandEventSender.sendFailedEvent(e); - throw e; } } @@ -502,7 +498,9 @@ private T sendAndReceiveInternal(final CommandMessage message, final Decoder public void send(final CommandMessage message, final Decoder decoder, final OperationContext operationContext) { try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, operationContext); - sendCommandMessage(message, bsonOutput, operationContext); + try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) { + sendCommandMessage(commandDocument.getFirstKey(), message, bsonOutput, operationContext); + } if (message.isResponseExpected()) { hasMoreToCome = true; } @@ -520,27 +518,41 @@ public boolean hasMoreToCome() { return hasMoreToCome; } - private void sendCommandMessage(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, - final OperationContext operationContext) { + private void sendCommandMessage(final String commandName, final CommandMessage message, + final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { + List messageByteBuffers = getMessageByteBuffers(commandName, message, bsonOutput, operationContext); + try { + Timeout.onExistsAndExpired(operationContext.getTimeoutContext().timeoutIncludingRoundTrip(), () -> { + throw TimeoutContext.createMongoRoundTripTimeoutException(); + }); + sendMessage(messageByteBuffers, message.getId(), operationContext); + } finally { + 
ResourceUtil.release(messageByteBuffers); + } + responseTo = message.getId(); + } + private List getMessageByteBuffers(final String commandName, final CommandMessage message, + final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) { Compressor localSendCompressor = sendCompressor; - if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(message.getCommandDocument(bsonOutput).getFirstKey())) { - trySendMessage(message, bsonOutput, operationContext); + List messageByteBuffers; + // Check if compressed + if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(commandName)) { + messageByteBuffers = bsonOutput.getByteBuffers(); } else { - ByteBufferBsonOutput compressedBsonOutput; List byteBuffers = bsonOutput.getByteBuffers(); try { CompressedMessage compressedMessage = new CompressedMessage(message.getOpCode(), byteBuffers, localSendCompressor, getMessageSettings(description, initialServerDescription)); - compressedBsonOutput = new ByteBufferBsonOutput(this); - compressedMessage.encode(compressedBsonOutput, operationContext); + try (ByteBufferBsonOutput compressedBsonOutput = new ByteBufferBsonOutput(this)) { + compressedMessage.encode(compressedBsonOutput, operationContext); + messageByteBuffers = compressedBsonOutput.getByteBuffers(); + } } finally { ResourceUtil.release(byteBuffers); - bsonOutput.close(); } - trySendMessage(message, compressedBsonOutput, operationContext); } - responseTo = message.getId(); + return messageByteBuffers; } private void trySendMessage(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, @@ -598,60 +610,54 @@ private T receiveCommandMessageResponse(final Decoder decoder, final Comm } private void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder decoder, - final OperationContext operationContext, final SingleResultCallback callback) { + final OperationContext operationContext, final SingleResultCallback callback) { if (isClosed()) { 
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); return; } + // Async try with resources release after the write ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this); - ByteBufferBsonOutput compressedBsonOutput = new ByteBufferBsonOutput(this); - try { message.encode(bsonOutput, operationContext); + String commandName; CommandEventSender commandEventSender; - if (isLoggingCommandNeeded()) { - BsonDocument commandDocument = message.getCommandDocument(bsonOutput); - commandEventSender = new LoggingCommandEventSender( - SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, - operationContext, message, commandDocument, - COMMAND_PROTOCOL_LOGGER, loggerSettings); - } else { - commandEventSender = new NoOpCommandEventSender(); - } - - commandEventSender.sendStartedEvent(); - Compressor localSendCompressor = sendCompressor; - if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(message.getCommandDocument(bsonOutput).getFirstKey())) { - sendCommandMessageAsync(message.getId(), decoder, operationContext, callback, bsonOutput, commandEventSender, - message.isResponseExpected()); - } else { - List byteBuffers = bsonOutput.getByteBuffers(); - try { - CompressedMessage compressedMessage = new CompressedMessage(message.getOpCode(), byteBuffers, localSendCompressor, - getMessageSettings(description, initialServerDescription)); - compressedMessage.encode(compressedBsonOutput, operationContext); - } finally { - ResourceUtil.release(byteBuffers); - bsonOutput.close(); + try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) { + commandName = commandDocument.getFirstKey(); + if (isLoggingCommandNeeded()) { + commandEventSender = new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + operationContext, message, commandDocument, + 
COMMAND_PROTOCOL_LOGGER, loggerSettings); + } else { + commandEventSender = new NoOpCommandEventSender(); } - sendCommandMessageAsync(message.getId(), decoder, operationContext, callback, compressedBsonOutput, commandEventSender, - message.isResponseExpected()); + commandEventSender.sendStartedEvent(); } + + List messageByteBuffers = getMessageByteBuffers(commandName, message, bsonOutput, operationContext); + sendCommandMessageAsync(messageByteBuffers, message.getId(), decoder, operationContext, + commandEventSender, message.isResponseExpected(), (r, t) -> { + ResourceUtil.release(messageByteBuffers); + bsonOutput.close(); // Close AFTER async write completes + if (t != null) { + callback.onResult(null, t); + } else { + callback.onResult(r, null); + } + }); } catch (Throwable t) { bsonOutput.close(); - compressedBsonOutput.close(); callback.onResult(null, t); } } - private void sendCommandMessageAsync(final int messageId, final Decoder decoder, final OperationContext operationContext, - final SingleResultCallback callback, final ByteBufferBsonOutput bsonOutput, - final CommandEventSender commandEventSender, final boolean responseExpected) { + private void sendCommandMessageAsync(final List messageByteBuffers, final int messageId, final Decoder decoder, + final OperationContext operationContext, final CommandEventSender commandEventSender, + final boolean responseExpected, final SingleResultCallback callback) { boolean[] shouldReturn = {false}; Timeout.onExistsAndExpired(operationContext.getTimeoutContext().timeoutIncludingRoundTrip(), () -> { - bsonOutput.close(); MongoOperationTimeoutException operationTimeoutException = TimeoutContext.createMongoRoundTripTimeoutException(); commandEventSender.sendFailedEvent(operationTimeoutException); callback.onResult(null, operationTimeoutException); @@ -661,10 +667,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d return; } - List byteBuffers = bsonOutput.getByteBuffers(); - 
sendMessageAsync(byteBuffers, messageId, operationContext, (result, t) -> { - ResourceUtil.release(byteBuffers); - bsonOutput.close(); + sendMessageAsync(messageByteBuffers, messageId, operationContext, (result, t) -> { if (t != null) { commandEventSender.sendFailedEvent(t); callback.onResult(null, t); @@ -682,18 +685,16 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d T commandResult; try { updateSessionContext(operationContext.getSessionContext(), responseBuffers); - boolean commandOk = - isCommandOk(new BsonBinaryReader(new ByteBufferBsonInput(responseBuffers.getBodyByteBuffer()))); - responseBuffers.reset(); - if (!commandOk) { + + if (!isCommandOk(responseBuffers)) { MongoException commandFailureException = getCommandFailureException( responseBuffers.getResponseDocument(messageId, new BsonDocumentCodec()), description.getServerAddress(), operationContext.getTimeoutContext()); commandEventSender.sendFailedEvent(commandFailureException); throw commandFailureException; } - commandEventSender.sendSucceededEvent(responseBuffers); + commandEventSender.sendSucceededEvent(responseBuffers); commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext()); } catch (Throwable localThrowable) { callback.onResult(null, localThrowable); diff --git a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java index 2dadd11efec..b26cb396e7b 100644 --- a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java +++ b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java @@ -37,6 +37,7 @@ import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CLIENT_CONNECTION_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COLLECTION; import 
static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COMMAND_NAME; +import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CURSOR_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NAMESPACE; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NETWORK_TRANSPORT; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.QUERY_SUMMARY; @@ -46,7 +47,6 @@ import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SESSION_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SYSTEM; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.TRANSACTION_NUMBER; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CURSOR_ID; import static java.lang.System.getenv; /** @@ -178,7 +178,7 @@ public boolean isCommandPayloadEnabled() { * * @param message the command message to trace * @param operationContext the operation context containing tracing and session information - * @param commandDocumentSupplier a supplier that provides the command document when needed + * @param commandDocument the command document; note that this is an internally managed resource * @param isSensitiveCommand a predicate that determines if a command is security-sensitive based on its name * @param serverAddressSupplier a supplier that provides the server address when needed * @param connectionIdSupplier a supplier that provides the connection ID when needed @@ -187,26 +187,26 @@ @Nullable public Span createTracingSpan(final CommandMessage message, final OperationContext operationContext, - final
BsonDocument commandDocument, final Predicate isSensitiveCommand, final Supplier serverAddressSupplier, final Supplier connectionIdSupplier ) { - if (!isEnabled()) { + if (!isEnabled()) { return null; } - BsonDocument command = commandDocumentSupplier.get(); - String commandName = command.getFirstKey(); + + String commandName = commandDocument.getFirstKey(); if (isSensitiveCommand.test(commandName)) { return null; } Span operationSpan = operationContext.getTracingSpan(); - Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null); + Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null); - if (command.containsKey("getMore")) { - long cursorId = command.getInt64("getMore").longValue(); + if (commandDocument.containsKey("getMore")) { + long cursorId = commandDocument.getInt64("getMore").longValue(); span.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId))); if (operationSpan != null) { operationSpan.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId))); diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java index f7cefbf57c0..637f89cb347 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java @@ -46,6 +46,7 @@ import org.bson.io.BasicOutputBuffer; import org.bson.types.Decimal128; import org.bson.types.ObjectId; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.io.ByteArrayOutputStream; @@ -63,145 +64,242 @@ import static org.bson.BsonBoolean.FALSE; import static org.bson.BsonBoolean.TRUE; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import 
static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +@DisplayName("ByteBufBsonArray") class ByteBufBsonArrayTest { + // Basic Operations + @Test + @DisplayName("getValues() returns array values") void testGetValues() { List values = asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2)); - ByteBufBsonArray bsonArray = fromBsonValues(values); - assertEquals(values, bsonArray.getValues()); + try (ByteBufBsonArray bsonArray = fromBsonValues(values)) { + assertEquals(values, bsonArray.getValues()); + } } @Test + @DisplayName("size() returns correct count") void testSize() { - assertEquals(0, fromBsonValues(emptyList()).size()); - assertEquals(1, fromBsonValues(singletonList(TRUE)).size()); - assertEquals(2, fromBsonValues(asList(TRUE, TRUE)).size()); + try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) { + assertEquals(0, bsonArray.size()); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) { + assertEquals(1, bsonArray.size()); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) { + assertEquals(2, bsonArray.size()); + } } @Test + @DisplayName("isEmpty() returns correct result") void testIsEmpty() { - assertTrue(fromBsonValues(emptyList()).isEmpty()); - assertFalse(fromBsonValues(singletonList(TRUE)).isEmpty()); - assertFalse(fromBsonValues(asList(TRUE, TRUE)).isEmpty()); + try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) { + assertTrue(bsonArray.isEmpty()); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) { + assertFalse(bsonArray.isEmpty()); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) { + assertFalse(bsonArray.isEmpty()); + } } @Test + @DisplayName("contains() finds existing values and rejects missing values") void testContains() { - 
assertFalse(fromBsonValues(emptyList()).contains(TRUE)); - assertTrue(fromBsonValues(singletonList(TRUE)).contains(TRUE)); - assertTrue(fromBsonValues(asList(FALSE, TRUE)).contains(TRUE)); - assertFalse(fromBsonValues(singletonList(FALSE)).contains(TRUE)); - assertFalse(fromBsonValues(asList(FALSE, FALSE)).contains(TRUE)); + try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) { + assertFalse(bsonArray.contains(TRUE)); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) { + assertTrue(bsonArray.contains(TRUE)); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, TRUE))) { + assertTrue(bsonArray.contains(TRUE)); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(FALSE))) { + assertFalse(bsonArray.contains(TRUE)); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, FALSE))) { + assertFalse(bsonArray.contains(TRUE)); + } } @Test + @DisplayName("iterator() navigates through all elements") void testIterator() { - Iterator iterator = fromBsonValues(emptyList()).iterator(); - assertFalse(iterator.hasNext()); - assertThrows(NoSuchElementException.class, iterator::next); - - iterator = fromBsonValues(singletonList(TRUE)).iterator(); - assertTrue(iterator.hasNext()); - assertEquals(TRUE, iterator.next()); - assertFalse(iterator.hasNext()); - assertThrows(NoSuchElementException.class, iterator::next); - - iterator = fromBsonValues(asList(TRUE, FALSE)).iterator(); - assertTrue(iterator.hasNext()); - assertEquals(TRUE, iterator.next()); - assertTrue(iterator.hasNext()); - assertEquals(FALSE, iterator.next()); - assertFalse(iterator.hasNext()); - assertThrows(NoSuchElementException.class, iterator::next); + try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) { + Iterator iterator = bsonArray.iterator(); + assertFalse(iterator.hasNext()); + assertThrows(NoSuchElementException.class, iterator::next); + } + + try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) { + 
Iterator iterator = bsonArray.iterator(); + assertTrue(iterator.hasNext()); + assertEquals(TRUE, iterator.next()); + assertFalse(iterator.hasNext()); + assertThrows(NoSuchElementException.class, iterator::next); + } + + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + Iterator iterator = bsonArray.iterator(); + assertTrue(iterator.hasNext()); + assertEquals(TRUE, iterator.next()); + assertTrue(iterator.hasNext()); + assertEquals(FALSE, iterator.next()); + assertFalse(iterator.hasNext()); + assertThrows(NoSuchElementException.class, iterator::next); + } + } + + @Test + @DisplayName("Iterators ensure the resource is still open") + void iteratorsEnsureResourceIsStillOpen() { + ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE)); + Iterator arrayIterator = bsonArray.iterator(); + + assertDoesNotThrow(arrayIterator::hasNext); + + bsonArray.close(); + assertThrows(IllegalStateException.class, arrayIterator::hasNext); } @Test + @DisplayName("toArray() converts array to Object array") void testToArray() { - assertArrayEquals(new BsonValue[]{TRUE, FALSE}, fromBsonValues(asList(TRUE, FALSE)).toArray()); - assertArrayEquals(new BsonValue[]{TRUE, FALSE}, fromBsonValues(asList(TRUE, FALSE)).toArray(new BsonValue[0])); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertArrayEquals(new BsonValue[]{TRUE, FALSE}, bsonArray.toArray()); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertArrayEquals(new BsonValue[]{TRUE, FALSE}, bsonArray.toArray(new BsonValue[0])); + } } @Test + @DisplayName("containsAll() checks if all elements are present") void testContainsAll() { - assertTrue(fromBsonValues(asList(TRUE, FALSE)).containsAll(asList(TRUE, FALSE))); - assertFalse(fromBsonValues(asList(TRUE, TRUE)).containsAll(asList(TRUE, FALSE))); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertTrue(bsonArray.containsAll(asList(TRUE, FALSE))); + } + try 
(ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) { + assertFalse(bsonArray.containsAll(asList(TRUE, FALSE))); + } } @Test + @DisplayName("get() retrieves element at index and throws for out of bounds") void testGet() { - ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE)); - assertEquals(TRUE, bsonArray.get(0)); - assertEquals(FALSE, bsonArray.get(1)); - assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(-1)); - assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(2)); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(TRUE, bsonArray.get(0)); + assertEquals(FALSE, bsonArray.get(1)); + assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(2)); + } } @Test + @DisplayName("indexOf() finds element position or returns -1") void testIndexOf() { - ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE)); - assertEquals(0, bsonArray.indexOf(TRUE)); - assertEquals(1, bsonArray.indexOf(FALSE)); - assertEquals(-1, bsonArray.indexOf(BsonNull.VALUE)); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(0, bsonArray.indexOf(TRUE)); + assertEquals(1, bsonArray.indexOf(FALSE)); + assertEquals(-1, bsonArray.indexOf(BsonNull.VALUE)); + } } @Test + @DisplayName("lastIndexOf() finds last element position or returns -1") void testLastIndexOf() { - ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE, FALSE)); - assertEquals(2, bsonArray.lastIndexOf(TRUE)); - assertEquals(3, bsonArray.lastIndexOf(FALSE)); - assertEquals(-1, bsonArray.lastIndexOf(BsonNull.VALUE)); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE, FALSE))) { + assertEquals(2, bsonArray.lastIndexOf(TRUE)); + assertEquals(3, bsonArray.lastIndexOf(FALSE)); + assertEquals(-1, bsonArray.lastIndexOf(BsonNull.VALUE)); + } } @Test + @DisplayName("listIterator() 
supports bidirectional iteration") void testListIterator() { // implementation is delegated to ArrayList, so not much testing is needed - ListIterator iterator = fromBsonValues(emptyList()).listIterator(); - assertFalse(iterator.hasNext()); - assertFalse(iterator.hasPrevious()); + try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) { + ListIterator iterator = bsonArray.listIterator(); + assertFalse(iterator.hasNext()); + assertFalse(iterator.hasPrevious()); + } } @Test + @DisplayName("subList() returns subset of array elements") void testSubList() { - ByteBufBsonArray bsonArray = fromBsonValues(asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2))); - assertEquals(emptyList(), bsonArray.subList(0, 0)); - assertEquals(singletonList(new BsonInt32(0)), bsonArray.subList(0, 1)); - assertEquals(singletonList(new BsonInt32(2)), bsonArray.subList(2, 3)); - assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(-1, 1)); - assertThrows(IllegalArgumentException.class, () -> bsonArray.subList(3, 2)); - assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(2, 4)); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2)))) { + assertEquals(emptyList(), bsonArray.subList(0, 0)); + assertEquals(singletonList(new BsonInt32(0)), bsonArray.subList(0, 1)); + assertEquals(singletonList(new BsonInt32(2)), bsonArray.subList(2, 3)); + assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(-1, 1)); + assertThrows(IllegalArgumentException.class, () -> bsonArray.subList(3, 2)); + assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(2, 4)); + } } + // Equality and HashCode + @Test + @DisplayName("equals() and hashCode() work correctly") void testEquals() { - assertEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(TRUE, FALSE))); - assertEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(TRUE, FALSE))); + try (ByteBufBsonArray 
bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(bsonArray, new BsonArray(asList(TRUE, FALSE))); + } - assertNotEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(FALSE, TRUE))); - assertNotEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(FALSE, TRUE))); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, TRUE))) { + assertNotEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertNotEquals(bsonArray, new BsonArray(asList(FALSE, TRUE))); + } - assertNotEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(TRUE, FALSE, TRUE))); - assertNotEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(TRUE, FALSE, TRUE))); - assertNotEquals(fromBsonValues(asList(TRUE, FALSE, TRUE)), new BsonArray(asList(TRUE, FALSE))); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE))) { + assertNotEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertNotEquals(bsonArray, new BsonArray(asList(TRUE, FALSE, TRUE))); + } + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE))) { + assertNotEquals(bsonArray, new BsonArray(asList(TRUE, FALSE))); + } } @Test + @DisplayName("hashCode() is consistent with equals()") void testHashCode() { - assertEquals(new BsonArray(asList(TRUE, FALSE)).hashCode(), fromBsonValues(asList(TRUE, FALSE)).hashCode()); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(new BsonArray(asList(TRUE, FALSE)).hashCode(), bsonArray.hashCode()); + } } @Test + @DisplayName("toString() returns equivalent string") void testToString() { - assertEquals(new BsonArray(asList(TRUE, FALSE)).toString(), 
fromBsonValues(asList(TRUE, FALSE)).toString()); + try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) { + assertEquals(new BsonArray(asList(TRUE, FALSE)).toString(), bsonArray.toString()); + } } + // Type Support + @Test + @DisplayName("All BSON types are supported") void testAllBsonTypes() { BsonValue bsonNull = new BsonNull(); BsonValue bsonInt32 = new BsonInt32(42); @@ -225,30 +323,31 @@ void testAllBsonTypes() { BsonValue document = new BsonDocument("a", new BsonInt32(1)); BsonValue dbPointer = new BsonDbPointer("db.coll", new ObjectId()); - ByteBufBsonArray bsonArray = fromBsonValues(asList( + try (ByteBufBsonArray bsonArray = fromBsonValues(asList( bsonNull, bsonInt32, bsonInt64, bsonDecimal128, bsonBoolean, bsonDateTime, bsonDouble, bsonString, minKey, maxKey, - javaScript, objectId, scope, regularExpression, symbol, timestamp, undefined, binary, array, document, dbPointer)); - assertEquals(bsonNull, bsonArray.get(0)); - assertEquals(bsonInt32, bsonArray.get(1)); - assertEquals(bsonInt64, bsonArray.get(2)); - assertEquals(bsonDecimal128, bsonArray.get(3)); - assertEquals(bsonBoolean, bsonArray.get(4)); - assertEquals(bsonDateTime, bsonArray.get(5)); - assertEquals(bsonDouble, bsonArray.get(6)); - assertEquals(bsonString, bsonArray.get(7)); - assertEquals(minKey, bsonArray.get(8)); - assertEquals(maxKey, bsonArray.get(9)); - assertEquals(javaScript, bsonArray.get(10)); - assertEquals(objectId, bsonArray.get(11)); - assertEquals(scope, bsonArray.get(12)); - assertEquals(regularExpression, bsonArray.get(13)); - assertEquals(symbol, bsonArray.get(14)); - assertEquals(timestamp, bsonArray.get(15)); - assertEquals(undefined, bsonArray.get(16)); - assertEquals(binary, bsonArray.get(17)); - assertEquals(array, bsonArray.get(18)); - assertEquals(document, bsonArray.get(19)); - assertEquals(dbPointer, bsonArray.get(20)); + javaScript, objectId, scope, regularExpression, symbol, timestamp, undefined, binary, array, document, dbPointer))) { + 
assertEquals(bsonNull, bsonArray.get(0)); + assertEquals(bsonInt32, bsonArray.get(1)); + assertEquals(bsonInt64, bsonArray.get(2)); + assertEquals(bsonDecimal128, bsonArray.get(3)); + assertEquals(bsonBoolean, bsonArray.get(4)); + assertEquals(bsonDateTime, bsonArray.get(5)); + assertEquals(bsonDouble, bsonArray.get(6)); + assertEquals(bsonString, bsonArray.get(7)); + assertEquals(minKey, bsonArray.get(8)); + assertEquals(maxKey, bsonArray.get(9)); + assertEquals(javaScript, bsonArray.get(10)); + assertEquals(objectId, bsonArray.get(11)); + assertEquals(scope, bsonArray.get(12)); + assertEquals(regularExpression, bsonArray.get(13)); + assertEquals(symbol, bsonArray.get(14)); + assertEquals(timestamp, bsonArray.get(15)); + assertEquals(undefined, bsonArray.get(16)); + assertEquals(binary, bsonArray.get(17)); + assertEquals(array, bsonArray.get(18)); + assertEquals(document, bsonArray.get(19)); + assertEquals(dbPointer, bsonArray.get(20)); + } } static ByteBufBsonArray fromBsonValues(final List values) { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy deleted file mode 100644 index 8dc599706a9..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - -import org.bson.BsonArray -import org.bson.BsonBinaryWriter -import org.bson.BsonBoolean -import org.bson.BsonDocument -import org.bson.BsonInt32 -import org.bson.BsonNull -import org.bson.BsonValue -import org.bson.ByteBuf -import org.bson.ByteBufNIO -import org.bson.codecs.BsonDocumentCodec -import org.bson.codecs.DecoderContext -import org.bson.codecs.EncoderContext -import org.bson.io.BasicOutputBuffer -import org.bson.json.JsonMode -import org.bson.json.JsonWriterSettings -import spock.lang.Specification - -import java.nio.ByteBuffer - -import static java.util.Arrays.asList - -class ByteBufBsonDocumentSpecification extends Specification { - def emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap([5, 0, 0, 0, 0] as byte[])) - ByteBuf documentByteBuf - ByteBufBsonDocument emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf) - def document = new BsonDocument() - .append('a', new BsonInt32(1)) - .append('b', new BsonInt32(2)) - .append('c', new BsonDocument('x', BsonBoolean.TRUE)) - .append('d', new BsonArray(asList(new BsonDocument('y', BsonBoolean.FALSE), new BsonInt32(1)))) - - ByteBufBsonDocument byteBufDocument - - def setup() { - def buffer = new BasicOutputBuffer() - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()) - ByteArrayOutputStream baos = new ByteArrayOutputStream() - buffer.pipe(baos) - documentByteBuf = new ByteBufNIO(ByteBuffer.wrap(baos.toByteArray())) - byteBufDocument = new ByteBufBsonDocument(documentByteBuf) - } - - def 'get should get the value of the given key'() { - expect: - emptyByteBufDocument.get('a') == null - byteBufDocument.get('z') == null - byteBufDocument.get('a') == new BsonInt32(1) - byteBufDocument.get('b') == new BsonInt32(2) - } - - def 'get should throw if the key is null'() { - when: - 
byteBufDocument.get(null) - - then: - thrown(IllegalArgumentException) - documentByteBuf.referenceCount == 1 - } - - def 'containKey should throw if the key name is null'() { - when: - byteBufDocument.containsKey(null) - - then: - thrown(IllegalArgumentException) - documentByteBuf.referenceCount == 1 - } - - def 'containsKey should find an existing key'() { - expect: - byteBufDocument.containsKey('a') - byteBufDocument.containsKey('b') - byteBufDocument.containsKey('c') - byteBufDocument.containsKey('d') - documentByteBuf.referenceCount == 1 - } - - def 'containsKey should not find a non-existing key'() { - expect: - !byteBufDocument.containsKey('e') - !byteBufDocument.containsKey('x') - !byteBufDocument.containsKey('y') - documentByteBuf.referenceCount == 1 - } - - def 'containValue should find an existing value'() { - expect: - byteBufDocument.containsValue(document.get('a')) - byteBufDocument.containsValue(document.get('b')) - byteBufDocument.containsValue(document.get('c')) - byteBufDocument.containsValue(document.get('d')) - documentByteBuf.referenceCount == 1 - } - - def 'containValue should not find a non-existing value'() { - expect: - !byteBufDocument.containsValue(new BsonInt32(3)) - !byteBufDocument.containsValue(new BsonDocument('e', BsonBoolean.FALSE)) - !byteBufDocument.containsValue(new BsonArray(asList(new BsonInt32(2), new BsonInt32(4)))) - documentByteBuf.referenceCount == 1 - } - - def 'isEmpty should return false when the document is not empty'() { - expect: - !byteBufDocument.isEmpty() - documentByteBuf.referenceCount == 1 - } - - def 'isEmpty should return true when the document is empty'() { - expect: - emptyByteBufDocument.isEmpty() - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct size'() { - expect: - emptyByteBufDocument.size() == 0 - byteBufDocument.size() == 4 - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct key set'() { - expect: - 
emptyByteBufDocument.keySet().isEmpty() - byteBufDocument.keySet() == ['a', 'b', 'c', 'd'] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct values set'() { - expect: - emptyByteBufDocument.values().isEmpty() - byteBufDocument.values() as Set == [document.get('a'), document.get('b'), document.get('c'), document.get('d')] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct entry set'() { - expect: - emptyByteBufDocument.entrySet().isEmpty() - byteBufDocument.entrySet() == [new TestEntry('a', document.get('a')), - new TestEntry('b', document.get('b')), - new TestEntry('c', document.get('c')), - new TestEntry('d', document.get('d'))] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'all write methods should throw UnsupportedOperationException'() { - when: - byteBufDocument.clear() - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.put('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.append('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.putAll(new BsonDocument('x', BsonNull.VALUE)) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.remove(BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - } - - def 'should get first key'() { - expect: - byteBufDocument.getFirstKey() == document.keySet().iterator().next() - documentByteBuf.referenceCount == 1 - } - - def 'getFirstKey should throw NoSuchElementException if the document is empty'() { - when: - emptyByteBufDocument.getFirstKey() - - then: - thrown(NoSuchElementException) - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should create BsonReader'() { - when: - def reader = document.asBsonReader() - - then: - new BsonDocumentCodec().decode(reader, 
DecoderContext.builder().build()) == document - - cleanup: - reader.close() - } - - def 'clone should make a deep copy'() { - when: - BsonDocument cloned = byteBufDocument.clone() - - then: - cloned == byteBufDocument - documentByteBuf.referenceCount == 1 - } - - def 'should serialize and deserialize'() { - given: - def baos = new ByteArrayOutputStream() - def oos = new ObjectOutputStream(baos) - - when: - oos.writeObject(byteBufDocument) - def bais = new ByteArrayInputStream(baos.toByteArray()) - def ois = new ObjectInputStream(bais) - def deserializedDocument = ois.readObject() - - then: - byteBufDocument == deserializedDocument - documentByteBuf.referenceCount == 1 - } - - def 'toJson should return equivalent'() { - expect: - document.toJson() == byteBufDocument.toJson() - documentByteBuf.referenceCount == 1 - } - - def 'toJson should be callable multiple times'() { - expect: - byteBufDocument.toJson() - byteBufDocument.toJson() - documentByteBuf.referenceCount == 1 - } - - def 'size should be callable multiple times'() { - expect: - byteBufDocument.size() - byteBufDocument.size() - documentByteBuf.referenceCount == 1 - } - - def 'toJson should respect JsonWriteSettings'() { - given: - def settings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build() - - expect: - document.toJson(settings) == byteBufDocument.toJson(settings) - } - - def 'toJson should return equivalent when a ByteBufBsonDocument is nested in a BsonDocument'() { - given: - def topLevel = new BsonDocument('nested', byteBufDocument) - - expect: - new BsonDocument('nested', document).toJson() == topLevel.toJson() - } - - class TestEntry implements Map.Entry { - - private final String key - private BsonValue value - - TestEntry(String key, BsonValue value) { - this.key = key - this.value = value - } - - @Override - String getKey() { - key - } - - @Override - BsonValue getValue() { - value - } - - @Override - BsonValue setValue(final BsonValue value) { - this.value = value - } - } - -} 
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java new file mode 100644 index 00000000000..1f61f309d14 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java @@ -0,0 +1,795 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import org.bson.BsonArray; +import org.bson.BsonBinaryWriter; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonReader; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.bson.RawBsonDocument; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.DecoderContext; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; +import org.bson.json.JsonMode; +import org.bson.json.JsonWriterSettings; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import 
java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@DisplayName("ByteBufBsonDocument") +class ByteBufBsonDocumentTest { + private ByteBuf documentByteBuf; + private ByteBufBsonDocument emptyByteBufDocument; + private BsonDocument document; + + @BeforeEach + void setUp() { + ByteBuf emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap(new byte[]{5, 0, 0, 0, 0})); + emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf); + + document = new BsonDocument() + .append("a", new BsonInt32(1)) + .append("b", new BsonInt32(2)) + .append("c", new BsonDocument("x", BsonBoolean.TRUE)) + .append("d", new BsonArray(asList( + new BsonDocument("y", BsonBoolean.FALSE), + new BsonInt32(1) + ))); + + RawBsonDocument rawBsonDocument = RawBsonDocument.parse(document.toString()); + documentByteBuf = rawBsonDocument.getByteBuffer(); + } + + @AfterEach + void tearDown() { + if (emptyByteBufDocument != null) { + emptyByteBufDocument.close(); + } + } + + // Basic Operations + + @Test + @DisplayName("get() returns value for existing key, null for missing key") + void getShouldReturnCorrectValue() { + assertNull(emptyByteBufDocument.get("a")); + try 
(ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertNull(byteBufDocument.get("z")); + assertEquals(new BsonInt32(1), byteBufDocument.get("a")); + assertEquals(new BsonInt32(2), byteBufDocument.get("b")); + } + } + + @Test + @DisplayName("get() throws IllegalArgumentException for null key") + void getShouldThrowForNullKey() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertThrows(IllegalArgumentException.class, () -> byteBufDocument.get(null)); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + } + + @Test + @DisplayName("containsKey() finds existing keys and rejects missing keys") + void containsKeyShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertThrows(IllegalArgumentException.class, () -> byteBufDocument.containsKey(null)); + assertTrue(byteBufDocument.containsKey("a")); + assertTrue(byteBufDocument.containsKey("d")); + assertFalse(byteBufDocument.containsKey("z")); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + } + + @Test + @DisplayName("containsValue() finds existing values and rejects missing values") + void containsValueShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertTrue(byteBufDocument.containsValue(document.get("a"))); + assertTrue(byteBufDocument.containsValue(document.get("c"))); + assertFalse(byteBufDocument.containsValue(new BsonInt32(999))); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + } + + @Test + @DisplayName("isEmpty() returns correct result") + void isEmptyShouldWork() { + assertTrue(emptyByteBufDocument.isEmpty()); + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertFalse(byteBufDocument.isEmpty()); + } + } + + @Test + @DisplayName("size() returns correct count") + void sizeShouldWork() { + assertEquals(0, emptyByteBufDocument.size()); + try 
(ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(4, byteBufDocument.size()); + assertEquals(4, byteBufDocument.size()); // Verify caching works + } + } + + @Test + @DisplayName("getFirstKey() returns first key or throws for empty document") + void getFirstKeyShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals("a", byteBufDocument.getFirstKey()); + } + assertThrows(NoSuchElementException.class, () -> emptyByteBufDocument.getFirstKey()); + } + + // Collection Views + + @Test + @DisplayName("keySet() returns all keys") + void keySetShouldWork() { + assertTrue(emptyByteBufDocument.keySet().isEmpty()); + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(new HashSet<>(asList("a", "b", "c", "d")), byteBufDocument.keySet()); + } + } + + @Test + @DisplayName("values() returns all values") + void valuesShouldWork() { + assertTrue(emptyByteBufDocument.values().isEmpty()); + Set<BsonValue> expected = new HashSet<>(asList( + document.get("a"), document.get("b"), document.get("c"), document.get("d") + )); + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(expected, new HashSet<>(byteBufDocument.values())); + } + } + + @Test + @DisplayName("entrySet() returns all entries") + void entrySetShouldWork() { + assertTrue(emptyByteBufDocument.entrySet().isEmpty()); + Set<Map.Entry<String, BsonValue>> expected = new HashSet<>(asList( + new AbstractMap.SimpleImmutableEntry<>("a", document.get("a")), + new AbstractMap.SimpleImmutableEntry<>("b", document.get("b")), + new AbstractMap.SimpleImmutableEntry<>("c", document.get("c")), + new AbstractMap.SimpleImmutableEntry<>("d", document.get("d")) + )); + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(expected, byteBufDocument.entrySet()); + } + } + + // Type-Specific Accessors + + @Test + 
@DisplayName("getDocument() returns nested document") + void getDocumentShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + BsonDocument nested = byteBufDocument.getDocument("c"); + assertNotNull(nested); + assertEquals(BsonBoolean.TRUE, nested.get("x")); + } + } + + @Test + @DisplayName("getArray() returns array") + void getArrayShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + BsonArray array = byteBufDocument.getArray("d"); + assertNotNull(array); + assertEquals(2, array.size()); + } + } + + @Test + @DisplayName("get() with default value works correctly") + void getWithDefaultShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(new BsonInt32(1), byteBufDocument.get("a", new BsonInt32(999))); + assertEquals(new BsonInt32(999), byteBufDocument.get("missing", new BsonInt32(999))); + } + } + + @Test + @DisplayName("Type check methods return correct results") + void typeChecksShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertTrue(byteBufDocument.isNumber("a")); + assertTrue(byteBufDocument.isInt32("a")); + assertTrue(byteBufDocument.isDocument("c")); + assertTrue(byteBufDocument.isArray("d")); + assertFalse(byteBufDocument.isDocument("a")); + } + } + + // Immutability + + @Test + @DisplayName("All write methods throw UnsupportedOperationException") + void writeMethodsShouldThrow() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.clear()); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.put("x", new BsonInt32(1))); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.append("x", new BsonInt32(1))); + assertThrows(UnsupportedOperationException.class, () -> 
byteBufDocument.putAll(new BsonDocument())); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.remove("a")); + } + } + + // Conversion and Serialization + + @Test + @DisplayName("toBsonDocument() returns equivalent document and caches result") + void toBsonDocumentShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(document, byteBufDocument.toBsonDocument()); + BsonDocument first = byteBufDocument.toBsonDocument(); + BsonDocument second = byteBufDocument.toBsonDocument(); + assertEquals(first, second); + } } + + @Test + @DisplayName("asBsonReader() creates valid reader") + void asBsonReaderShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + try (BsonReader reader = byteBufDocument.asBsonReader()) { + BsonDocument decoded = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); + assertEquals(document, decoded); + } + } + } + + @Test + @DisplayName("toJson() returns correct JSON ") + void toJsonShouldWork() { + ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf); + assertEquals(document.toJson(), byteBufDocument.toJson()); + byteBufDocument.close(); + + assertNotNull(byteBufDocument.toJson()); // Verify caching + } + + @Test + @DisplayName("toJson() returns correct JSON with different settings") + void toJsonWithSettingsShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + JsonWriterSettings shellSettings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build(); + assertEquals(document.toJson(shellSettings), byteBufDocument.toJson(shellSettings)); + } + } + + @Test + @DisplayName("toString() returns equivalent string") + void toStringShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(document.toString(), byteBufDocument.toString()); + } + } + + @Test + 
@DisplayName("clone() creates deep copy") + void cloneShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + BsonDocument cloned = byteBufDocument.clone(); + assertEquals(byteBufDocument, cloned); + } + } + + @Test + @DisplayName("Java serialization works correctly") + void serializationShouldWork() throws Exception { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + new ObjectOutputStream(baos).writeObject(byteBufDocument); + Object deserialized = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())).readObject(); + assertEquals(byteBufDocument, deserialized); + } + } + + // Equality and HashCode + + @Test + @DisplayName("equals() and hashCode() work correctly") + void equalsAndHashCodeShouldWork() { + try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { + assertEquals(document, byteBufDocument); + assertEquals(byteBufDocument, document); + assertEquals(document.hashCode(), byteBufDocument.hashCode()); + assertNotEquals(byteBufDocument, new BsonDocument("x", new BsonInt32(99))); + } + } + + // Resource Management + + @Test + @DisplayName("Closed document throws IllegalStateException on all operations") + void closedDocumentShouldThrow() { + ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf); + byteBufDocument.close(); + assertThrows(IllegalStateException.class, () -> byteBufDocument.size()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.isEmpty()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.containsKey("a")); + assertThrows(IllegalStateException.class, () -> byteBufDocument.get("a")); + assertThrows(IllegalStateException.class, () -> byteBufDocument.keySet()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.values()); + assertThrows(IllegalStateException.class, () -> 
byteBufDocument.entrySet()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.getFirstKey()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.toBsonDocument()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.toJson()); + } + + @Test + @DisplayName("close() can be called multiple times safely") + void closeIsIdempotent() { + ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf); + byteBufDocument.close(); + byteBufDocument.close(); // Should not throw + } + + @Test + @DisplayName("Nested documents are closed when parent is closed") + void nestedDocumentsClosedWithParent() { + BsonDocument doc = new BsonDocument("outer", new BsonDocument("inner", new BsonInt32(42))); + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonDocument retrieved = byteBufDoc.getDocument("outer"); + byteBufDoc.close(); + + assertThrows(IllegalStateException.class, byteBufDoc::size); + if (retrieved instanceof ByteBufBsonDocument) { + assertThrows(IllegalStateException.class, retrieved::size); + } + } + + @Test + @DisplayName("Nested arrays are closed when parent is closed") + void nestedArraysClosedWithParent() { + BsonDocument doc = new BsonDocument("arr", new BsonArray(asList( + new BsonInt32(1), new BsonDocument("x", new BsonInt32(2)) + ))); + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonArray retrieved = byteBufDoc.getArray("arr"); + byteBufDoc.close(); + + assertThrows(IllegalStateException.class, byteBufDoc::size); + if (retrieved instanceof ByteBufBsonArray) { + assertThrows(IllegalStateException.class, retrieved::size); + } + } + + @Test + @DisplayName("Deeply nested structures are closed recursively") + void deeplyNestedClosedRecursively() { + BsonDocument doc = new BsonDocument() + .append("level1", new BsonArray(asList( + new BsonDocument("level2", new 
BsonDocument("level3", new BsonInt32(999))), + new BsonInt32(1) + ))) + .append("sibling", new BsonDocument("key", new BsonString("value"))); + + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonArray level1 = byteBufDoc.getArray("level1"); + byteBufDoc.getDocument("sibling"); + + if (level1.get(0).isDocument()) { + BsonDocument level2Doc = level1.get(0).asDocument(); + if (level2Doc.containsKey("level2")) { + assertEquals(new BsonInt32(999), level2Doc.getDocument("level2").get("level3")); + } + } + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, byteBufDoc::size); + } + + @Test + @DisplayName("Iteration tracks resources correctly") + void iterationTracksResources() { + BsonDocument doc = new BsonDocument() + .append("doc1", new BsonDocument("a", new BsonInt32(1))) + .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3)))) + .append("primitive", new BsonString("test")); + + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + int count = 0; + for (Map.Entry<String, BsonValue> entry : byteBufDoc.entrySet()) { + assertNotNull(entry.getKey()); + assertNotNull(entry.getValue()); + count++; + } + assertEquals(3, count); + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, byteBufDoc::size); + } + + @Test + @DisplayName("Iterators ensure the resource is still open") + void iteratorsEnsureResourceIsStillOpen() { + BsonDocument doc = new BsonDocument() + .append("doc1", new BsonDocument("a", new BsonInt32(1))) + .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3)))) + .append("primitive", new BsonString("test")); + + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + Iterator<String> keysIterator = byteBufDoc.keySet().iterator(); + assertDoesNotThrow(keysIterator::hasNext); + + Iterator<String> nestedKeysIterator = 
byteBufDoc.getDocument("doc1").keySet().iterator(); + assertDoesNotThrow(nestedKeysIterator::hasNext); + + Iterator arrayIterator = byteBufDoc.getArray("arr1").iterator(); + assertDoesNotThrow(arrayIterator::hasNext); + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, keysIterator::hasNext); + assertThrows(IllegalStateException.class, nestedKeysIterator::hasNext); + assertThrows(IllegalStateException.class, arrayIterator::hasNext); + } + + @Test + @DisplayName("toBsonDocument() handles nested structures and allows close") + void toBsonDocumentHandlesNestedStructures() { + BsonDocument complexDoc = new BsonDocument() + .append("doc", new BsonDocument("x", new BsonInt32(1))) + .append("arr", new BsonArray(asList(new BsonDocument("y", new BsonInt32(2)), new BsonInt32(3)))); + + ByteBuf buf = createByteBufFromDocument(complexDoc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonDocument hydrated = byteBufDoc.toBsonDocument(); + assertEquals(complexDoc, hydrated); + + byteBufDoc.close(); + } + + @Test + @DisplayName("cachedDocument is usable after close") + void cachedDocumentIsUsableAfterClose() { + BsonDocument complexDoc = new BsonDocument() + .append("doc", new BsonDocument("x", new BsonInt32(1))) + .append("arr", new BsonArray(asList(new BsonDocument("y", new BsonInt32(2)), new BsonInt32(3)))); + + ByteBuf buf = createByteBufFromDocument(complexDoc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + BsonDocument hydrated = byteBufDoc.toBsonDocument(); + + byteBufDoc.close(); + assertEquals(complexDoc, hydrated); + assertEquals(complexDoc.toJson(), hydrated.toJson()); + } + + // Sequence Fields (OP_MSG) + + @Test + @DisplayName("Sequence field is accessible as array of ByteBufBsonDocuments") + void sequenceFieldAccessibleAsArray() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) { + + 
BsonValue documentsValue = commandDoc.get("documents"); + assertNotNull(documentsValue); + assertTrue(documentsValue.isArray()); + + BsonArray documents = documentsValue.asArray(); + assertEquals(3, documents.size()); + + for (int i = 0; i < 3; i++) { + BsonValue doc = documents.get(i); + assertInstanceOf(ByteBufBsonDocument.class, doc); + assertEquals(new BsonInt32(i), doc.asDocument().get("_id")); + assertEquals(new BsonString("doc" + i), doc.asDocument().get("name")); + } + } + } + + @Test + @DisplayName("Sequence field is included in size, keySet, values, and entrySet") + void sequenceFieldIncludedInCollectionViews() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + assertTrue(commandDoc.size() >= 3); + assertTrue(commandDoc.keySet().contains("documents")); + assertTrue(commandDoc.keySet().contains("insert")); + + boolean foundDocumentsArray = false; + for (BsonValue value : commandDoc.values()) { + if (value.isArray() && value.asArray().size() == 2) { + foundDocumentsArray = true; + break; + } + } + assertTrue(foundDocumentsArray); + + boolean foundDocumentsEntry = false; + for (Map.Entry entry : commandDoc.entrySet()) { + if ("documents".equals(entry.getKey())) { + foundDocumentsEntry = true; + assertEquals(2, entry.getValue().asArray().size()); + break; + } + } + assertTrue(foundDocumentsEntry); + } + } + + @Test + @DisplayName("containsKey and containsValue work with sequence fields") + void containsMethodsWorkWithSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) { + + assertTrue(commandDoc.containsKey("documents")); + assertTrue(commandDoc.containsKey("insert")); + assertFalse(commandDoc.containsKey("nonexistent")); + + BsonDocument expectedDoc = new BsonDocument() + .append("_id", new BsonInt32(1)) 
+ .append("name", new BsonString("doc1")); + assertTrue(commandDoc.containsValue(expectedDoc)); + } + } + + @Test + @DisplayName("Sequence field documents are closed when parent is closed") + void sequenceFieldDocumentsClosedWithParent() { + ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2); + + BsonArray documents = commandDoc.getArray("documents"); + List<BsonDocument> docRefs = new ArrayList<>(); + for (BsonValue doc : documents) { + docRefs.add(doc.asDocument()); + } + + commandDoc.close(); + output.close(); + + assertThrows(IllegalStateException.class, commandDoc::size); + for (BsonDocument doc : docRefs) { + if (doc instanceof ByteBufBsonDocument) { + assertThrows(IllegalStateException.class, doc::size); + } + } + } + + @Test + @DisplayName("Sequence field is cached on multiple access") + void sequenceFieldCached() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + BsonArray first = commandDoc.getArray("documents"); + BsonArray second = commandDoc.getArray("documents"); + assertNotNull(first); + assertEquals(first.size(), second.size()); + } + } + + @Test + @DisplayName("toBsonDocument() hydrates sequence fields to regular BsonDocuments") + void toBsonDocumentHydratesSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + BsonDocument hydrated = commandDoc.toBsonDocument(); + assertTrue(hydrated.containsKey("documents")); + + BsonArray documents = hydrated.getArray("documents"); + assertEquals(2, documents.size()); + for (BsonValue doc : documents) { + assertFalse(doc instanceof ByteBufBsonDocument); + } + } + } + + @Test + @DisplayName("Sequence field with nested documents works correctly") + void 
sequenceFieldWithNestedDocuments() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + ByteBufBsonDocument commandDoc = createNestedCommandMessageDocument(output); + + BsonArray documents = commandDoc.getArray("documents"); + assertEquals(2, documents.size()); + + BsonDocument firstDoc = documents.get(0).asDocument(); + BsonDocument nested = firstDoc.getDocument("nested"); + assertEquals(new BsonInt32(0), nested.get("inner")); + + BsonArray array = firstDoc.getArray("array"); + assertEquals(2, array.size()); + + commandDoc.close(); + } + } + + @Test + @DisplayName("Empty sequence field returns empty array") + void emptySequenceField() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 0)) { + + assertTrue(commandDoc.containsKey("insert")); + assertTrue(commandDoc.containsKey("documents")); + assertTrue(commandDoc.getArray("documents").isEmpty()); + } + } + + @Test + @DisplayName("getFirstKey() returns body field, not sequence field") + void getFirstKeyReturnsBodyField() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + assertEquals("insert", commandDoc.getFirstKey()); + } + } + + @Test + @DisplayName("toJson() includes sequence fields") + void toJsonIncludesSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + String json = commandDoc.toJson(); + assertTrue(json.contains("documents")); + assertTrue(json.contains("_id")); + } + } + + @Test + @DisplayName("equals() and hashCode() include sequence fields") + void equalsAndHashCodeIncludeSequenceFields() { + try (ByteBufferBsonOutput output1 = new ByteBufferBsonOutput(new SimpleBufferProvider()); + 
ByteBufBsonDocument commandDoc1 = createCommandMessageDocument(output1, 2); + ByteBufferBsonOutput output2 = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc2 = createCommandMessageDocument(output2, 2)) { + + assertEquals(commandDoc1.toBsonDocument(), commandDoc2.toBsonDocument()); + assertEquals(commandDoc1.hashCode(), commandDoc2.hashCode()); + } + } + + // --- Helper Methods --- + + private ByteBufBsonDocument createCommandMessageDocument(final ByteBufferBsonOutput output, final int numDocuments) { + BsonDocument bodyDoc = new BsonDocument() + .append("insert", new BsonString("test")) + .append("$db", new BsonString("db")); + + byte[] bodyBytes = encodeBsonDocument(bodyDoc); + List<byte[]> sequenceDocBytes = new ArrayList<>(); + for (int i = 0; i < numDocuments; i++) { + BsonDocument seqDoc = new BsonDocument() + .append("_id", new BsonInt32(i)) + .append("name", new BsonString("doc" + i)); + sequenceDocBytes.add(encodeBsonDocument(seqDoc)); + } + + writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes); + + List<ByteBuf> buffers = output.getByteBuffers(); + return ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(buffers)); + } + + private ByteBufBsonDocument createNestedCommandMessageDocument(final ByteBufferBsonOutput output) { + BsonDocument bodyDoc = new BsonDocument() + .append("insert", new BsonString("test")) + .append("$db", new BsonString("db")); + + byte[] bodyBytes = encodeBsonDocument(bodyDoc); + List<byte[]> sequenceDocBytes = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + BsonDocument seqDoc = new BsonDocument() + .append("_id", new BsonInt32(i)) + .append("nested", new BsonDocument("inner", new BsonInt32(i * 10))) + .append("array", new BsonArray(asList( + new BsonInt32(i), + new BsonDocument("arrayNested", new BsonString("value" + i)) + ))); + sequenceDocBytes.add(encodeBsonDocument(seqDoc)); + } + + writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes); + return 
ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(output.getByteBuffers())); + } + + private void writeOpMsgFormat(final ByteBufferBsonOutput output, final byte[] bodyBytes, + final String sequenceIdentifier, final List<byte[]> sequenceDocBytes) { + output.writeBytes(bodyBytes, 0, bodyBytes.length); + + int sequencePayloadSize = sequenceDocBytes.stream().mapToInt(b -> b.length).sum(); + int sequenceSectionSize = 4 + sequenceIdentifier.length() + 1 + sequencePayloadSize; + + output.writeByte(1); + output.writeInt32(sequenceSectionSize); + output.writeCString(sequenceIdentifier); + for (byte[] docBytes : sequenceDocBytes) { + output.writeBytes(docBytes, 0, docBytes.length); + } + } + + private static byte[] encodeBsonDocument(final BsonDocument doc) { + try { + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), doc, EncoderContext.builder().build()); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + buffer.pipe(baos); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static ByteBuf createByteBufFromDocument(final BsonDocument doc) { + return new ByteBufNIO(ByteBuffer.wrap(encodeBsonDocument(doc))); + } + + private static class SimpleBufferProvider implements BufferProvider { + @NotNull + @Override + public ByteBuf getBuffer(final int size) { + return new ByteBufNIO(ByteBuffer.allocate(size).order(ByteOrder.LITTLE_ENDIAN)); + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy deleted file mode 100644 index 77bdd5e2045..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy +++ /dev/null @@ -1,365 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - - -import com.mongodb.MongoNamespace -import com.mongodb.ReadConcern -import com.mongodb.ReadPreference -import com.mongodb.connection.ClusterConnectionMode -import com.mongodb.connection.ServerType -import com.mongodb.internal.IgnorableRequestContext -import com.mongodb.internal.TimeoutContext -import com.mongodb.internal.bulk.InsertRequest -import com.mongodb.internal.bulk.WriteRequestWithIndex -import com.mongodb.internal.session.SessionContext -import com.mongodb.internal.validator.NoOpFieldNameValidator -import org.bson.BsonArray -import org.bson.BsonBinary -import org.bson.BsonDocument -import org.bson.BsonInt32 -import org.bson.BsonMaximumSizeExceededException -import org.bson.BsonString -import org.bson.BsonTimestamp -import org.bson.ByteBuf -import org.bson.ByteBufNIO -import org.bson.codecs.BsonDocumentCodec -import spock.lang.Specification - -import java.nio.ByteBuffer - -import static com.mongodb.internal.connection.SplittablePayload.Type.INSERT -import static com.mongodb.internal.operation.ServerVersionHelper.LATEST_WIRE_VERSION - -/** - * New tests must be added to {@link CommandMessageTest}. 
- */ -class CommandMessageSpecification extends Specification { - - def namespace = new MongoNamespace('db.test') - def command = new BsonDocument('find', new BsonString(namespace.collectionName)) - def fieldNameValidator = NoOpFieldNameValidator.INSTANCE - - def 'should encode command message with OP_MSG when server version is >= 3.6'() { - given: - def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, readPreference, - MessageSettings.builder() - .maxWireVersion(LATEST_WIRE_VERSION) - .serverType(serverType as ServerType) - .sessionSupported(true) - .build(), - responseExpected, MessageSequences.EmptyMessageSequences.INSTANCE, clusterConnectionMode, null) - def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) - - when: - message.encode(output, operationContext) - - then: - def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - def messageHeader = new MessageHeader(byteBuf, 512) - def replyHeader = new ReplyHeader(byteBuf, messageHeader) - messageHeader.opCode == OpCode.OP_MSG.value - replyHeader.requestId < RequestMessage.currentGlobalId - replyHeader.responseTo == 0 - replyHeader.hasMoreToCome() != responseExpected - - def expectedCommandDocument = command.clone() - .append('$db', new BsonString(namespace.databaseName)) - - if (operationContext.getSessionContext().clusterTime != null) { - expectedCommandDocument.append('$clusterTime', operationContext.getSessionContext().clusterTime) - } - if (operationContext.getSessionContext().hasSession() && responseExpected) { - expectedCommandDocument.append('lsid', operationContext.getSessionContext().sessionId) - } - - if (readPreference != ReadPreference.primary()) { - expectedCommandDocument.append('$readPreference', readPreference.toDocument()) - } else if (clusterConnectionMode == ClusterConnectionMode.SINGLE && serverType != ServerType.SHARD_ROUTER) { - expectedCommandDocument.append('$readPreference', ReadPreference.primaryPreferred().toDocument()) - 
} - getCommandDocument(byteBuf, replyHeader) == expectedCommandDocument - - cleanup: - output.close() - - where: - [readPreference, serverType, clusterConnectionMode, operationContext, responseExpected, isCryptd] << [ - [ReadPreference.primary(), ReadPreference.secondary()], - [ServerType.REPLICA_SET_PRIMARY, ServerType.SHARD_ROUTER], - [ClusterConnectionMode.SINGLE, ClusterConnectionMode.MULTIPLE], - [ - new OperationContext( - IgnorableRequestContext.INSTANCE, - Stub(SessionContext) { - hasSession() >> false - getClusterTime() >> null - getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[])) - getReadConcern() >> ReadConcern.DEFAULT - }, Stub(TimeoutContext), null), - new OperationContext( - IgnorableRequestContext.INSTANCE, - Stub(SessionContext) { - hasSession() >> false - getClusterTime() >> new BsonDocument('clusterTime', new BsonTimestamp(42, 1)) - getReadConcern() >> ReadConcern.DEFAULT - }, Stub(TimeoutContext), null), - new OperationContext( - IgnorableRequestContext.INSTANCE, - Stub(SessionContext) { - hasSession() >> true - getClusterTime() >> null - getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[])) - getReadConcern() >> ReadConcern.DEFAULT - }, Stub(TimeoutContext), null), - new OperationContext( - IgnorableRequestContext.INSTANCE, - Stub(SessionContext) { - hasSession() >> true - getClusterTime() >> new BsonDocument('clusterTime', new BsonTimestamp(42, 1)) - getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[])) - getReadConcern() >> ReadConcern.DEFAULT - }, Stub(TimeoutContext), null) - ], - [true, false], - [true, false] - ].combinations() - } - - String getString(final ByteBuf byteBuf) { - def byteArrayOutputStream = new ByteArrayOutputStream() - def cur = byteBuf.get() - while (cur != 0) { - byteArrayOutputStream.write(cur) - cur = byteBuf.get() - } - new String(byteArrayOutputStream.toByteArray(), 'UTF-8') - } - - def 'should get command document'() { - given: - def message 
= new CommandMessage(namespace.getDatabaseName(), originalCommandDocument, fieldNameValidator, - ReadPreference.primary(), MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true, - payload == null ? MessageSequences.EmptyMessageSequences.INSTANCE : payload, - ClusterConnectionMode.MULTIPLE, null) - def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, - Stub(TimeoutContext), null)) - - when: - def commandDocument = message.getCommandDocument(output) - - def expectedCommandDocument = new BsonDocument('insert', new BsonString('coll')).append('documents', - new BsonArray([new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))])) - expectedCommandDocument.append('$db', new BsonString(namespace.getDatabaseName())) - then: - commandDocument == expectedCommandDocument - - - where: - [maxWireVersion, originalCommandDocument, payload] << [ - [ - LATEST_WIRE_VERSION, - new BsonDocument('insert', new BsonString('coll')), - new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)), - new BsonDocument('_id', new BsonInt32(2))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, - true, NoOpFieldNameValidator.INSTANCE), - ], - [ - LATEST_WIRE_VERSION, - new BsonDocument('insert', new BsonString('coll')).append('documents', - new BsonArray([new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))])), - null - ] - ] - } - - def 'should respect the max message size'() { - given: - def maxMessageSize = 1024 - def messageSettings = MessageSettings.builder().maxMessageSize(maxMessageSize).maxWireVersion(LATEST_WIRE_VERSION).build() - def insertCommand = new BsonDocument('insert', new BsonString(namespace.collectionName)) - def payload = new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)).append('a', new BsonBinary(new 
byte[913])), - new BsonDocument('_id', new BsonInt32(2)).append('b', new BsonBinary(new byte[441])), - new BsonDocument('_id', new BsonInt32(3)).append('c', new BsonBinary(new byte[450])), - new BsonDocument('_id', new BsonInt32(4)).append('b', new BsonBinary(new byte[441])), - new BsonDocument('_id', new BsonInt32(5)).append('c', new BsonBinary(new byte[451]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) - def message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def sessionContext = Stub(SessionContext) { - getReadConcern() >> ReadConcern.DEFAULT - } - - when: - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, - Stub(TimeoutContext), null)) - def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - def messageHeader = new MessageHeader(byteBuf, maxMessageSize) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - messageHeader.messageLength == 1024 - byteBuf.getInt() == 0 - payload.getPosition() == 1 - payload.hasAnotherSplit() - - when: - payload = payload.getNextSplit() - message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - output.truncateToPosition(0) - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) - byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - messageHeader = new MessageHeader(byteBuf, maxMessageSize) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < 
RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - messageHeader.messageLength == 1024 - byteBuf.getInt() == 0 - payload.getPosition() == 2 - payload.hasAnotherSplit() - - when: - payload = payload.getNextSplit() - message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - output.truncateToPosition(0) - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) - byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - messageHeader = new MessageHeader(byteBuf, maxMessageSize) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - messageHeader.messageLength == 552 - byteBuf.getInt() == 0 - payload.getPosition() == 1 - payload.hasAnotherSplit() - - when: - payload = payload.getNextSplit() - message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - output.truncateToPosition(0) - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, - sessionContext, - Stub(TimeoutContext), - null)) - byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - messageHeader = new MessageHeader(byteBuf, maxMessageSize) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - messageHeader.messageLength == 562 - byteBuf.getInt() == 1 << 1 - payload.getPosition() == 1 - !payload.hasAnotherSplit() - - cleanup: - output.close() - } - - def 'should respect the max batch count'() { - given: - def messageSettings = MessageSettings.builder().maxBatchCount(2).maxWireVersion(LATEST_WIRE_VERSION).build() - def payload = new 
SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900])), - new BsonDocument('b', new BsonBinary(new byte[450])), - new BsonDocument('c', new BsonBinary(new byte[450]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) - def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def sessionContext = Stub(SessionContext) { - getReadConcern() >> ReadConcern.DEFAULT - } - - when: - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, - Stub(TimeoutContext), - null)) - def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - def messageHeader = new MessageHeader(byteBuf, 2048) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - messageHeader.messageLength == 1497 - byteBuf.getInt() == 0 - payload.getPosition() == 2 - payload.hasAnotherSplit() - - when: - payload = payload.getNextSplit() - message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, payload, ClusterConnectionMode.MULTIPLE, null) - output.truncateToPosition(0) - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, - Stub(TimeoutContext), null)) - byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) - messageHeader = new MessageHeader(byteBuf, 1024) - - then: - messageHeader.opCode == OpCode.OP_MSG.value - messageHeader.requestId < RequestMessage.currentGlobalId - messageHeader.responseTo == 0 - byteBuf.getInt() == 1 << 1 - payload.getPosition() == 1 - !payload.hasAnotherSplit() - - cleanup: - output.close() - } - - def 'should throw if 
payload document bigger than max document size'() { - given: - def messageSettings = MessageSettings.builder().maxDocumentSize(900) - .maxWireVersion(LATEST_WIRE_VERSION).build() - def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) - def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(), - messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null) - def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) - def sessionContext = Stub(SessionContext) { - getReadConcern() >> ReadConcern.DEFAULT - } - - when: - message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, - Stub(TimeoutContext), null)) - - then: - thrown(BsonMaximumSizeExceededException) - - cleanup: - output.close() - } - - private static BsonDocument getCommandDocument(ByteBufNIO byteBuf, ReplyHeader replyHeader) { - new ReplyMessage(new ResponseBuffers(replyHeader, byteBuf), new BsonDocumentCodec(), 0).document - } -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java index 091518c715c..e5eab18869b 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java @@ -27,6 +27,8 @@ import com.mongodb.internal.IgnorableRequestContext; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import com.mongodb.internal.bulk.InsertRequest; +import com.mongodb.internal.bulk.WriteRequestWithIndex; import com.mongodb.internal.client.model.bulk.ConcreteClientBulkWriteOptions; import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import 
com.mongodb.internal.operation.ClientBulkWriteOperation; @@ -34,11 +36,14 @@ import com.mongodb.internal.session.SessionContext; import com.mongodb.internal.validator.NoOpFieldNameValidator; import org.bson.BsonArray; +import org.bson.BsonBinary; import org.bson.BsonBoolean; import org.bson.BsonDocument; import org.bson.BsonInt32; +import org.bson.BsonMaximumSizeExceededException; import org.bson.BsonString; import org.bson.BsonTimestamp; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.List; @@ -53,17 +58,20 @@ import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +@DisplayName("CommandMessage") class CommandMessageTest { private static final MongoNamespace NAMESPACE = new MongoNamespace("db.test"); private static final BsonDocument COMMAND = new BsonDocument("find", new BsonString(NAMESPACE.getCollectionName())); @Test + @DisplayName("encode should throw timeout exception when timeout context is called") void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() { //given CommandMessage commandMessage = new CommandMessage(NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), @@ -91,6 +99,7 @@ void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() { } @Test + @DisplayName("encode should not add extra elements from timeout context when connected to mongocryptd") void encodeShouldNotAddExtraElementsFromTimeoutContextWhenConnectedToMongoCrypt() { //given CommandMessage commandMessage = new CommandMessage(NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), 
@@ -126,6 +135,7 @@ void encodeShouldNotAddExtraElementsFromTimeoutContextWhenConnectedToMongoCrypt( } @Test + @DisplayName("get command document from client bulk write operation") void getCommandDocumentFromClientBulkWrite() { MongoNamespace ns = new MongoNamespace("db", "test"); boolean retryWrites = false; @@ -164,8 +174,466 @@ void getCommandDocumentFromClientBulkWrite() { new OperationContext( IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, new TimeoutContext(TimeoutSettings.DEFAULT), null)); - BsonDocument actualCommandDocument = commandMessage.getCommandDocument(output); - assertEquals(expectedCommandDocument, actualCommandDocument); + + try (ByteBufBsonDocument actualCommandDocument = commandMessage.getCommandDocument(output)) { + assertEquals(expectedCommandDocument, actualCommandDocument); + } + } + } + + @Test + @DisplayName("get command document with payload containing documents") + void getCommandDocumentWithPayload() { + // given + BsonDocument originalCommandDocument = new BsonDocument("insert", new BsonString("coll")); + List documents = asList( + new BsonDocument("_id", new BsonInt32(1)), + new BsonDocument("_id", new BsonInt32(2)) + ); + List requestsFromDocs = IntStream.range(0, documents.size()) + .mapToObj(i -> new WriteRequestWithIndex(new InsertRequest(documents.get(i)), i)) + .collect(Collectors.toList()); + + SplittablePayload payload = new SplittablePayload( + SplittablePayload.Type.INSERT, + requestsFromDocs, + true, + NoOpFieldNameValidator.INSTANCE + ); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), originalCommandDocument, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), true, + payload, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode( + output, + new OperationContext(IgnorableRequestContext.INSTANCE, 
NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null) + ); + + // when + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + // then + assertEquals("coll", commandDoc.getString("insert").getValue()); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + BsonArray docsArray = commandDoc.getArray("documents"); + assertEquals(2, docsArray.size()); + } + } + } + + @Test + @DisplayName("get command document with pre-encoded documents") + void getCommandDocumentWithPreEncodedDocuments() { + // given + BsonDocument originalCommandDocument = new BsonDocument("insert", new BsonString("coll")) + .append("documents", new BsonArray(asList( + new BsonDocument("_id", new BsonInt32(1)), + new BsonDocument("_id", new BsonInt32(2)) + ))); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), originalCommandDocument, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), true, + EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode( + output, + new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null) + ); + + // when + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + // then + assertEquals("coll", commandDoc.getString("insert").getValue()); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + BsonArray docsArray = commandDoc.getArray("documents"); + assertEquals(2, docsArray.size()); + } + } + } + + @Test + @DisplayName("encode respects max message size constraint") + void encodeShouldRespectMaxMessageSize() { + // given + int maxMessageSize = 1024; + MessageSettings messageSettings = MessageSettings.builder() + 
.maxMessageSize(maxMessageSize) + .maxWireVersion(LATEST_WIRE_VERSION) + .build(); + BsonDocument insertCommand = new BsonDocument("insert", new BsonString(NAMESPACE.getCollectionName())); + + List requests = asList( + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(1)).append("a", new BsonBinary(new byte[913]))), + 0), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(2)).append("b", new BsonBinary(new byte[441]))), + 1), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(3)).append("c", new BsonBinary(new byte[450]))), + 2), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(4)).append("b", new BsonBinary(new byte[441]))), + 3), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(5)).append("c", new BsonBinary(new byte[451]))), + 4) + ); + + SplittablePayload payload = new SplittablePayload( + SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE + ); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), insertCommand, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + // when - encode first batch + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + // then - first batch respects size constraint + assertTrue(output.size() <= maxMessageSize, "Output size " + output.size() + " should not exceed max " + maxMessageSize); + assertEquals(1, payload.getPosition()); + + // Verify multiple splits were created + assertTrue(payload.hasAnotherSplit()); + } + } + + @Test + @DisplayName("encode respects max batch count constraint") + void 
encodeShouldRespectMaxBatchCount() { + // given + MessageSettings messageSettings = MessageSettings.builder() + .maxBatchCount(2) + .maxWireVersion(LATEST_WIRE_VERSION) + .build(); + + List requests = asList( + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("a", new BsonBinary(new byte[900]))), + 0), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("b", new BsonBinary(new byte[450]))), + 1), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("c", new BsonBinary(new byte[450]))), + 2) + ); + + SplittablePayload payload = new SplittablePayload( + SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE + ); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + // when - encode first batch with max 2 documents + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + // then - first batch has 2 documents + assertEquals(2, payload.getPosition()); + assertTrue(payload.hasAnotherSplit()); } } + + @Test + @DisplayName("encode throws exception when payload document exceeds max document size") + void encodeShouldThrowWhenPayloadDocumentExceedsMaxSize() { + // given + MessageSettings messageSettings = MessageSettings.builder() + .maxDocumentSize(900) + .maxWireVersion(LATEST_WIRE_VERSION) + .build(); + + List requests = singletonList( + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("a", new BsonBinary(new byte[900]))), + 0) + ); + + SplittablePayload payload = new SplittablePayload( + SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE + ); + + CommandMessage message = new 
CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + // when & then + assertThrows(BsonMaximumSizeExceededException.class, () -> + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )) + ); + } + } + + @Test + @DisplayName("encode message with cluster time encodes successfully") + void encodeWithClusterTime() { + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertTrue(output.size() > 0, "Output should contain encoded message"); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + } + } + } + + @Test + @DisplayName("encode message with active session encodes successfully") + void encodeWithActiveSession() { + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + 
message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertTrue(output.size() > 0, "Output should contain encoded message"); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + } + } + } + + @Test + @DisplayName("encode message with secondary read preference encodes successfully") + void encodeWithSecondaryReadPreference() { + ReadPreference secondary = ReadPreference.secondary(); + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + secondary, + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertTrue(output.size() > 0, "Output should contain encoded message"); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + } + } + } + + @Test + @DisplayName("encode message in single cluster mode encodes successfully") + void encodeInSingleClusterMode() { + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder() + .maxWireVersion(LATEST_WIRE_VERSION) + .serverType(ServerType.REPLICA_SET_PRIMARY) + .build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.SINGLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + 
message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertTrue(output.size() > 0, "Output should contain encoded message"); + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + } + } + } + + @Test + @DisplayName("encode includes database name in command document") + void encodeIncludesDatabaseName() { + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue()); + } + } + } + + @Test + @DisplayName("command document can be accessed multiple times") + void commandDocumentCanBeAccessedMultipleTimes() { + BsonDocument originalCommand = new BsonDocument("find", new BsonString("coll")) + .append("filter", new BsonDocument("_id", new BsonInt32(1))); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), originalCommand, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new 
OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + // Access same fields multiple times + assertEquals("coll", commandDoc.getString("find").getValue()); + assertEquals("coll", commandDoc.getString("find").getValue()); + BsonDocument filter = commandDoc.getDocument("filter"); + BsonDocument filter2 = commandDoc.getDocument("filter"); + assertEquals(filter, filter2); + } + } + } + + @Test + @DisplayName("encode with multiple document sequences creates proper arrays") + void encodeWithMultipleDocumentsInSequence() { + BsonDocument insertCommand = new BsonDocument("insert", new BsonString("coll")); + List requests = asList( + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(1)).append("name", new BsonString("doc1"))), + 0), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(2)).append("name", new BsonString("doc2"))), + 1), + new WriteRequestWithIndex( + new InsertRequest(new BsonDocument("_id", new BsonInt32(3)).append("name", new BsonString("doc3"))), + 2) + ); + + SplittablePayload payload = new SplittablePayload( + SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE + ); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), insertCommand, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, payload, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + 
BsonArray documents = commandDoc.getArray("documents"); + assertEquals(3, documents.size()); + assertEquals(1, documents.get(0).asDocument().getInt32("_id").getValue()); + assertEquals(2, documents.get(1).asDocument().getInt32("_id").getValue()); + assertEquals(3, documents.get(2).asDocument().getInt32("_id").getValue()); + } + } + } + + @Test + @DisplayName("encode with response not expected sets continuation flag") + void encodeWithResponseNotExpected() { + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + false, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + // Verify encoded message has continuation flag (0x02) + assertTrue(output.size() > 0, "Output should contain encoded message"); + } + } + + @Test + @DisplayName("encode preserves original command structure") + void encodePreservesCommandStructure() { + BsonDocument complexCommand = new BsonDocument("aggregate", new BsonString("coll")) + .append("pipeline", new BsonArray(asList( + new BsonDocument("$match", new BsonDocument("status", new BsonString("active"))), + new BsonDocument("$group", new BsonDocument("_id", new BsonString("$category"))) + ))) + .append("cursor", new BsonDocument("batchSize", new BsonInt32(100))); + + CommandMessage message = new CommandMessage( + NAMESPACE.getDatabaseName(), complexCommand, NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), + true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null + ); + + try 
(ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + message.encode(output, new OperationContext( + IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(TimeoutSettings.DEFAULT), null + )); + + try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) { + assertEquals("coll", commandDoc.getString("aggregate").getValue()); + BsonArray pipeline = commandDoc.getArray("pipeline"); + assertEquals(2, pipeline.size()); + BsonDocument cursor = commandDoc.getDocument("cursor"); + assertEquals(100, cursor.getInt32("batchSize").getValue()); + } + } + } + } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy index e6f6afb02e0..8e7a7b9d78d 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy @@ -64,8 +64,10 @@ class LoggingCommandEventSenderSpecification extends Specification { isDebugEnabled() >> debugLoggingEnabled } def operationContext = OPERATION_CONTEXT + def commandMessageDocument = message.getCommandDocument(bsonOutput) + def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener, - operationContext, message, message.getCommandDocument(bsonOutput), + operationContext, message, commandMessageDocument, new StructuredLogger(logger), LoggerSettings.builder().build()) when: @@ -87,6 +89,9 @@ class LoggingCommandEventSenderSpecification extends Specification { database, commandDocument.getFirstKey(), 1, failureException) ]) + cleanup: + commandMessageDocument?.close() + where: debugLoggingEnabled << [true, false] } @@ -110,8 +115,10 @@ class LoggingCommandEventSenderSpecification extends Specification { 
isDebugEnabled() >> true } def operationContext = OPERATION_CONTEXT + def commandMessageDocument = message.getCommandDocument(bsonOutput) + def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener, - operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger), + operationContext, message, commandMessageDocument, new StructuredLogger(logger), LoggerSettings.builder().build()) when: sender.sendStartedEvent() @@ -146,6 +153,9 @@ class LoggingCommandEventSenderSpecification extends Specification { "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}.") }, failureException) + cleanup: + commandMessageDocument?.close() + where: commandListener << [null, Stub(CommandListener)] } @@ -167,6 +177,7 @@ class LoggingCommandEventSenderSpecification extends Specification { isDebugEnabled() >> true } def operationContext = OPERATION_CONTEXT + def commandMessageDocument = message.getCommandDocument(bsonOutput) def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, null, operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger), LoggerSettings.builder().build()) @@ -182,6 +193,9 @@ class LoggingCommandEventSenderSpecification extends Specification { "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. " + "Command: {\"fake\": {\"\$binary\": {\"base64\": \"${'A' * 967} ..." 
} + + cleanup: + commandMessageDocument?.close() } def 'should log redacted command with ellipses'() { @@ -201,8 +215,9 @@ class LoggingCommandEventSenderSpecification extends Specification { isDebugEnabled() >> true } def operationContext = OPERATION_CONTEXT + def commandMessageDocument = message.getCommandDocument(bsonOutput) def sender = new LoggingCommandEventSender(['createUser'] as Set, [] as Set, connectionDescription, null, - operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger), + operationContext, message, commandMessageDocument, new StructuredLogger(logger), LoggerSettings.builder().build()) when: @@ -215,5 +230,8 @@ class LoggingCommandEventSenderSpecification extends Specification { "${connectionDescription.connectionId.serverValue} to 127.0.0.1:27017. The " + "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. Command: {}" } + + cleanup: + commandMessageDocument?.close() } } From 945bdd70ad6c7803be838d00e526fe9049dc2692 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 10 Mar 2026 13:21:34 +0000 Subject: [PATCH 2/4] `ByteBufBsonDocument` and `RawBsonDocument` simplifications (#1902) ### Rationale `ByteBufBsonDocument#clone` used to return a `RawBsonDocument`. The recent changes returned a normal `BsonDocument`, which is potentially expensive depending on its usage. The `ByteBufBsonDocument` changes also added complex iterator logic, when `RawBsonDocument` deferred to `BsonDocument` iterators. As iteration is essentially a hydrating mechanism, there is opportunity for improvements for both implementations. By changing the `RawBsonDocument` iterators to be more efficient, `ByteBufBsonDocument` can now utilize these efficiency gains by proxy, relying on the `cachedDocument` iterators. This change both reduces the complexity of `ByteBufBsonDocument` and relies on an improved `RawBsonDocument` implementation. 
### Summary of changes * **`ByteBufBsonDocument`**: * Simplify by returning `RawBsonDocument` from `toBsonDocument`, avoiding full BSON deserialization. When there are no sequence fields, the body bytes are cloned directly. When sequence fields exist, `BsonBinaryWriter.pipe()` merges the body with sequence arrays efficiently. * Use `toBsonDocument` for iterators. This eliminates the need for custom iterators (`IteratorMode`, `CombinedIterator`, `createBodyIterator`, and sequence iterators) since `entrySet`/`values`/`keySet` now delegate to the cached `RawBsonDocument`. * **`RawBsonDocument`**: * Renamed `toBaseBsonDocument` to override the default `toBsonDocument` implementation. * Implemented the iterators so that they don't need to fully convert the document to a `BsonDocument`. * **Tests**: * Updated `ByteBufBsonDocumentTest` iteration tests. * Updated `ByteBufBsonArrayTest#fromValues` as `entrySet` now returns `RawBsonDocument` instances. JAVA-6010 --------- Co-authored-by: Claude Opus 4.6 --- bson/src/main/org/bson/RawBsonDocument.java | 66 ++- .../bson/RawBsonDocumentSpecification.groovy | 494 ------------------ .../unit/org/bson/RawBsonDocumentTest.java | 420 +++++++++++++++ .../connection/ByteBufBsonDocument.java | 330 ++---------- .../connection/ByteBufBsonArrayTest.java | 24 +- .../connection/ByteBufBsonDocumentTest.java | 58 +- 6 files changed, 543 insertions(+), 849 deletions(-) delete mode 100644 bson/src/test/unit/org/bson/RawBsonDocumentSpecification.groovy create mode 100644 bson/src/test/unit/org/bson/RawBsonDocumentTest.java diff --git a/bson/src/main/org/bson/RawBsonDocument.java b/bson/src/main/org/bson/RawBsonDocument.java index eb672bcef8d..00e4508f10f 100644 --- a/bson/src/main/org/bson/RawBsonDocument.java +++ b/bson/src/main/org/bson/RawBsonDocument.java @@ -35,7 +35,11 @@ import java.io.StringWriter; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.AbstractMap; +import java.util.ArrayList; import 
java.util.Collection; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; @@ -123,12 +127,13 @@ public RawBsonDocument(final byte[] bytes, final int offset, final int length) { public RawBsonDocument(final T document, final Codec codec) { notNull("document", document); notNull("codec", codec); - BasicOutputBuffer buffer = new BasicOutputBuffer(); - try (BsonBinaryWriter writer = new BsonBinaryWriter(buffer)) { - codec.encode(writer, document, EncoderContext.builder().build()); - this.bytes = buffer.getInternalBuffer(); - this.offset = 0; - this.length = buffer.getPosition(); + try (BasicOutputBuffer buffer = new BasicOutputBuffer()) { + try (BsonBinaryWriter writer = new BsonBinaryWriter(buffer)) { + codec.encode(writer, document, EncoderContext.builder().build()); + this.bytes = buffer.getInternalBuffer(); + this.offset = 0; + this.length = buffer.getPosition(); + } } } @@ -225,17 +230,42 @@ public int size() { @Override public Set> entrySet() { - return toBaseBsonDocument().entrySet(); + List> entries = new ArrayList<>(); + try (BsonBinaryReader bsonReader = createReader()) { + bsonReader.readStartDocument(); + while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) { + String key = bsonReader.readName(); + BsonValue value = RawBsonValueHelper.decode(bytes, bsonReader); + entries.add(new AbstractMap.SimpleImmutableEntry<>(key, value)); + } + } + return new LinkedHashSet<>(entries); } @Override public Collection values() { - return toBaseBsonDocument().values(); + List values = new ArrayList<>(); + try (BsonBinaryReader bsonReader = createReader()) { + bsonReader.readStartDocument(); + while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) { + bsonReader.skipName(); + values.add(RawBsonValueHelper.decode(bytes, bsonReader)); + } + } + return values; } @Override public Set keySet() { - return toBaseBsonDocument().keySet(); + List keys = new ArrayList<>(); + try 
(BsonBinaryReader bsonReader = createReader()) { + bsonReader.readStartDocument(); + while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) { + keys.add(bsonReader.readName()); + bsonReader.skipValue(); + } + } + return new LinkedHashSet<>(keys); } @Override @@ -318,12 +348,19 @@ public String toJson(final JsonWriterSettings settings) { @Override public boolean equals(final Object o) { - return toBaseBsonDocument().equals(o); + return toBsonDocument().equals(o); } @Override public int hashCode() { - return toBaseBsonDocument().hashCode(); + return toBsonDocument().hashCode(); + } + + @Override + public BsonDocument toBsonDocument() { + try (BsonBinaryReader bsonReader = createReader()) { + return new BsonDocumentCodec().decode(bsonReader, DecoderContext.builder().build()); + } } @Override @@ -335,13 +372,6 @@ private BsonBinaryReader createReader() { return new BsonBinaryReader(new ByteBufferBsonInput(getByteBuffer())); } - // Transform to an org.bson.BsonDocument instance - private BsonDocument toBaseBsonDocument() { - try (BsonBinaryReader bsonReader = createReader()) { - return new BsonDocumentCodec().decode(bsonReader, DecoderContext.builder().build()); - } - } - /** * Write the replacement object. * diff --git a/bson/src/test/unit/org/bson/RawBsonDocumentSpecification.groovy b/bson/src/test/unit/org/bson/RawBsonDocumentSpecification.groovy deleted file mode 100644 index a23ec06dedb..00000000000 --- a/bson/src/test/unit/org/bson/RawBsonDocumentSpecification.groovy +++ /dev/null @@ -1,494 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.bson - -import org.bson.codecs.BsonDocumentCodec -import org.bson.codecs.DecoderContext -import org.bson.codecs.DocumentCodec -import org.bson.codecs.EncoderContext -import org.bson.codecs.RawBsonDocumentCodec -import org.bson.io.BasicOutputBuffer -import org.bson.json.JsonMode -import org.bson.json.JsonReader -import org.bson.json.JsonWriter -import org.bson.json.JsonWriterSettings -import spock.lang.Specification - -import java.nio.ByteOrder - -import static java.util.Arrays.asList -import static util.GroovyHelpers.areEqual - -class RawBsonDocumentSpecification extends Specification { - - static emptyDocument = new BsonDocument() - static emptyRawDocument = new RawBsonDocument(emptyDocument, new BsonDocumentCodec()) - static document = new BsonDocument() - .append('a', new BsonInt32(1)) - .append('b', new BsonInt32(2)) - .append('c', new BsonDocument('x', BsonBoolean.TRUE)) - .append('d', new BsonArray(asList(new BsonDocument('y', BsonBoolean.FALSE), new BsonArray(asList(new BsonInt32(1)))))) - - def 'constructors should throw if parameters are invalid'() { - when: - new RawBsonDocument(null) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(null, 0, 5) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(new byte[5], -1, 5) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(new byte[5], 5, 5) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(new byte[5], 0, 0) - - then: - thrown(IllegalArgumentException) - - when: - new 
RawBsonDocument(new byte[10], 6, 5) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(null, new DocumentCodec()) - - then: - thrown(IllegalArgumentException) - - when: - new RawBsonDocument(new Document(), null) - - then: - thrown(IllegalArgumentException) - } - - def 'byteBuffer should contain the correct bytes'() { - when: - def byteBuf = rawDocument.getByteBuffer() - - then: - rawDocument == document - byteBuf.asNIO().order() == ByteOrder.LITTLE_ENDIAN - byteBuf.remaining() == 66 - - when: - def actualBytes = new byte[66] - byteBuf.get(actualBytes) - - then: - actualBytes == getBytesFromDocument() - - where: - rawDocument << createRawDocumentVariants() - } - - def 'parse should through if parameter is invalid'() { - when: - RawBsonDocument.parse(null) - - then: - thrown(IllegalArgumentException) - } - - def 'should parse json'() { - expect: - RawBsonDocument.parse('{a : 1}') == new BsonDocument('a', new BsonInt32(1)) - } - - def 'containKey should throw if the key name is null'() { - when: - rawDocument.containsKey(null) - - then: - thrown(IllegalArgumentException) - - where: - rawDocument << createRawDocumentVariants() - } - - def 'containsKey should find an existing key'() { - expect: - rawDocument.containsKey('a') - rawDocument.containsKey('b') - rawDocument.containsKey('c') - rawDocument.containsKey('d') - - where: - rawDocument << createRawDocumentVariants() - } - - def 'containsKey should not find a non-existing key'() { - expect: - !rawDocument.containsKey('e') - !rawDocument.containsKey('x') - !rawDocument.containsKey('y') - rawDocument.get('e') == null - rawDocument.get('x') == null - rawDocument.get('y') == null - - where: - rawDocument << createRawDocumentVariants() - } - - def 'should return RawBsonDocument for sub documents and RawBsonArray for arrays'() { - expect: - rawDocument.get('a') instanceof BsonInt32 - rawDocument.get('b') instanceof BsonInt32 - rawDocument.get('c') instanceof RawBsonDocument - 
rawDocument.get('d') instanceof RawBsonArray - rawDocument.get('d').asArray().get(0) instanceof RawBsonDocument - rawDocument.get('d').asArray().get(1) instanceof RawBsonArray - - and: - rawDocument.getDocument('c').getBoolean('x').value - !rawDocument.get('d').asArray().get(0).asDocument().getBoolean('y').value - rawDocument.get('d').asArray().get(1).asArray().get(0).asInt32().value == 1 - - where: - rawDocument << createRawDocumentVariants() - } - - def 'containValue should find an existing value'() { - expect: - rawDocument.containsValue(document.get('a')) - rawDocument.containsValue(document.get('b')) - rawDocument.containsValue(document.get('c')) - rawDocument.containsValue(document.get('d')) - - where: - rawDocument << createRawDocumentVariants() - } - - def 'containValue should not find a non-existing value'() { - expect: - !rawDocument.containsValue(new BsonInt32(3)) - !rawDocument.containsValue(new BsonDocument('e', BsonBoolean.FALSE)) - !rawDocument.containsValue(new BsonArray(asList(new BsonInt32(2), new BsonInt32(4)))) - - where: - rawDocument << createRawDocumentVariants() - } - - def 'isEmpty should return false when the document is not empty'() { - expect: - !rawDocument.isEmpty() - - where: - rawDocument << createRawDocumentVariants() - } - - def 'isEmpty should return true when the document is empty'() { - expect: - emptyRawDocument.isEmpty() - } - - def 'should get correct size when the document is empty'() { - expect: - emptyRawDocument.size() == 0 - } - - def 'should get correct key set when the document is empty'() { - expect: - emptyRawDocument.keySet().isEmpty() - } - - def 'should get correct values set when the document is empty'() { - expect: - emptyRawDocument.values().isEmpty() - } - - def 'should get correct entry set when the document is empty'() { - expect: - emptyRawDocument.entrySet().isEmpty() - } - - def 'should get correct size'() { - expect: - createRawDocumenFromDocument().size() == 4 - - where: - rawDocument << 
createRawDocumentVariants() - } - - def 'should get correct key set'() { - expect: - rawDocument.keySet() == ['a', 'b', 'c', 'd'] as Set - - where: - rawDocument << createRawDocumentVariants() - } - - def 'should get correct values set'() { - expect: - rawDocument.values() as Set == [document.get('a'), document.get('b'), document.get('c'), document.get('d')] as Set - - where: - rawDocument << createRawDocumentVariants() - } - - def 'should get correct entry set'() { - expect: - rawDocument.entrySet() == [new TestEntry('a', document.get('a')), - new TestEntry('b', document.get('b')), - new TestEntry('c', document.get('c')), - new TestEntry('d', document.get('d'))] as Set - - where: - rawDocument << createRawDocumentVariants() - } - - def 'should get first key'() { - expect: - document.getFirstKey() == 'a' - - where: - rawDocument << createRawDocumentVariants() - } - - def 'getFirstKey should throw NoSuchElementException if the document is empty'() { - when: - emptyRawDocument.getFirstKey() - - then: - thrown(NoSuchElementException) - } - - def 'should create BsonReader'() { - when: - def reader = document.asBsonReader() - - then: - new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()) == document - - cleanup: - reader.close() - } - - def 'toJson should return equivalent JSON'() { - expect: - new RawBsonDocumentCodec().decode(new JsonReader(rawDocument.toJson()), DecoderContext.builder().build()) == document - - where: - rawDocument << createRawDocumentVariants() - } - - def 'toJson should respect default JsonWriterSettings'() { - given: - def writer = new StringWriter() - - when: - new BsonDocumentCodec().encode(new JsonWriter(writer), document, EncoderContext.builder().build()) - - then: - writer.toString() == rawDocument.toJson() - - where: - rawDocument << createRawDocumentVariants() - } - - def 'toJson should respect JsonWriterSettings'() { - given: - def jsonWriterSettings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build() - 
def writer = new StringWriter() - - when: - new RawBsonDocumentCodec().encode(new JsonWriter(writer, jsonWriterSettings), rawDocument, EncoderContext.builder().build()) - - then: - writer.toString() == rawDocument.toJson(jsonWriterSettings) - - where: - rawDocument << createRawDocumentVariants() - } - - def 'all write methods should throw UnsupportedOperationException'() { - given: - def rawDocument = createRawDocumenFromDocument() - - when: - rawDocument.clear() - - then: - thrown(UnsupportedOperationException) - - when: - rawDocument.put('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - rawDocument.append('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - rawDocument.putAll(new BsonDocument('x', BsonNull.VALUE)) - - then: - thrown(UnsupportedOperationException) - - when: - rawDocument.remove(BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - } - - def 'should decode'() { - rawDocument.decode(new BsonDocumentCodec()) == document - - where: - rawDocument << createRawDocumentVariants() - } - - def 'hashCode should equal hash code of identical BsonDocument'() { - expect: - rawDocument.hashCode() == document.hashCode() - - where: - rawDocument << createRawDocumentVariants() - } - - def 'equals should equal identical BsonDocument'() { - expect: - areEqual(rawDocument, document) - areEqual(document, rawDocument) - areEqual(rawDocument, rawDocument) - !areEqual(rawDocument, emptyRawDocument) - - where: - rawDocument << createRawDocumentVariants() - } - - def 'clone should make a deep copy'() { - when: - RawBsonDocument cloned = rawDocument.clone() - - then: - !cloned.getByteBuffer().array().is(createRawDocumenFromDocument().getByteBuffer().array()) - cloned.getByteBuffer().remaining() == rawDocument.getByteBuffer().remaining() - cloned == createRawDocumenFromDocument() - - where: - rawDocument << [ - createRawDocumenFromDocument(), - createRawDocumentFromByteArray(), - 
createRawDocumentFromByteArrayOffsetLength() - ] - } - - def 'should serialize and deserialize'() { - given: - def baos = new ByteArrayOutputStream() - def oos = new ObjectOutputStream(baos) - - when: - oos.writeObject(localRawDocument) - def bais = new ByteArrayInputStream(baos.toByteArray()) - def ois = new ObjectInputStream(bais) - def deserializedDocument = ois.readObject() - - then: - document == deserializedDocument - - where: - localRawDocument << createRawDocumentVariants() - } - - private static List createRawDocumentVariants() { - [ - createRawDocumenFromDocument(), - createRawDocumentFromByteArray(), - createRawDocumentFromByteArrayOffsetLength() - ] - } - - private static RawBsonDocument createRawDocumenFromDocument() { - new RawBsonDocument(document, new BsonDocumentCodec()) - } - - private static RawBsonDocument createRawDocumentFromByteArray() { - byte[] strippedBytes = getBytesFromDocument() - new RawBsonDocument(strippedBytes) - } - - private static byte[] getBytesFromDocument() { - def (int size, byte[] bytes) = getBytesFromOutputBuffer() - def strippedBytes = new byte[size] - System.arraycopy(bytes, 0, strippedBytes, 0, size) - strippedBytes - } - - private static List getBytesFromOutputBuffer() { - def outputBuffer = new BasicOutputBuffer(1024) - new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), document, EncoderContext.builder().build()) - def bytes = outputBuffer.getInternalBuffer() - [outputBuffer.position, bytes] - } - - private static RawBsonDocument createRawDocumentFromByteArrayOffsetLength() { - def (int size, byte[] bytes) = getBytesFromOutputBuffer() - def unstrippedBytes = new byte[size + 2] - System.arraycopy(bytes, 0, unstrippedBytes, 1, size) - new RawBsonDocument(unstrippedBytes, 1, size) - } - - class TestEntry implements Map.Entry { - - private final String key - private BsonValue value - - TestEntry(String key, BsonValue value) { - this.key = key - this.value = value - } - - @Override - String getKey() { - key 
- } - - @Override - BsonValue getValue() { - value - } - - @Override - BsonValue setValue(final BsonValue value) { - this.value = value - } - } -} diff --git a/bson/src/test/unit/org/bson/RawBsonDocumentTest.java b/bson/src/test/unit/org/bson/RawBsonDocumentTest.java new file mode 100644 index 00000000000..6ebb716e91f --- /dev/null +++ b/bson/src/test/unit/org/bson/RawBsonDocumentTest.java @@ -0,0 +1,420 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson; + +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.DecoderContext; +import org.bson.codecs.DocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.codecs.RawBsonDocumentCodec; +import org.bson.io.BasicOutputBuffer; +import org.bson.json.JsonMode; +import org.bson.json.JsonReader; +import org.bson.json.JsonWriter; +import org.bson.json.JsonWriterSettings; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.StringWriter; +import java.nio.ByteOrder; +import java.util.AbstractMap; +import java.util.HashSet; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DisplayName("RawBsonDocument") +class RawBsonDocumentTest { + + private static final BsonDocument EMPTY_DOCUMENT = new BsonDocument(); + private static final RawBsonDocument EMPTY_RAW_DOCUMENT = new RawBsonDocument(EMPTY_DOCUMENT, new BsonDocumentCodec()); + private static final BsonDocument DOCUMENT = new BsonDocument() + .append("a", new BsonInt32(1)) + .append("b", new BsonInt32(1)) + .append("c", new BsonDocument("x", BsonBoolean.TRUE)) + 
.append("d", new BsonArray(asList(new BsonDocument("y", BsonBoolean.FALSE), new BsonArray(asList(new BsonInt32(1)))))); + + // Constructor Validation + + @Test + @DisplayName("constructors should throw if parameters are invalid") + void constructorsShouldThrowForInvalidParameters() { + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument((byte[]) null)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(null, 0, 5)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(new byte[5], -1, 5)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(new byte[5], 5, 5)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(new byte[5], 0, 0)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(new byte[10], 6, 5)); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(null, new DocumentCodec())); + assertThrows(IllegalArgumentException.class, () -> new RawBsonDocument(new Document(), null)); + } + + // Byte Buffer + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("byteBuffer should contain the correct bytes") + void byteBufferShouldContainCorrectBytes(final RawBsonDocument rawDocument) { + ByteBuf byteBuf = rawDocument.getByteBuffer(); + + assertEquals(DOCUMENT, rawDocument); + assertEquals(ByteOrder.LITTLE_ENDIAN, byteBuf.asNIO().order()); + assertEquals(66, byteBuf.remaining()); + + byte[] actualBytes = new byte[66]; + byteBuf.get(actualBytes); + assertArrayEquals(getBytesFromDocument(), actualBytes); + } + + // Parse + + @Test + @DisplayName("parse() should throw if parameter is invalid") + void parseShouldThrowForInvalidParameter() { + assertThrows(IllegalArgumentException.class, () -> RawBsonDocument.parse(null)); + } + + @Test + @DisplayName("parse() should parse JSON") + void parseShouldParseJson() { + assertEquals(new BsonDocument("a", new BsonInt32(1)), RawBsonDocument.parse("{a : 1}")); + } + + // 
Basic Operations + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("containsKey() throws IllegalArgumentException for null key") + void containsKeyShouldThrowForNullKey(final RawBsonDocument rawDocument) { + assertThrows(IllegalArgumentException.class, () -> rawDocument.containsKey(null)); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("containsKey() finds existing keys") + void containsKeyShouldFindExistingKeys(final RawBsonDocument rawDocument) { + assertTrue(rawDocument.containsKey("a")); + assertTrue(rawDocument.containsKey("b")); + assertTrue(rawDocument.containsKey("c")); + assertTrue(rawDocument.containsKey("d")); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("containsKey() does not find non-existing keys") + void containsKeyShouldNotFindNonExistingKeys(final RawBsonDocument rawDocument) { + assertFalse(rawDocument.containsKey("e")); + assertFalse(rawDocument.containsKey("x")); + assertFalse(rawDocument.containsKey("y")); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("get() returns null for non-existing keys") + void getShouldReturnNullForNonExistingKeys(final RawBsonDocument rawDocument) { + assertEquals(null, rawDocument.get("e")); + assertEquals(null, rawDocument.get("x")); + assertEquals(null, rawDocument.get("y")); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("get() returns RawBsonDocument for sub documents and RawBsonArray for arrays") + void getShouldReturnCorrectTypes(final RawBsonDocument rawDocument) { + assertInstanceOf(BsonInt32.class, rawDocument.get("a")); + assertInstanceOf(BsonInt32.class, rawDocument.get("b")); + assertInstanceOf(RawBsonDocument.class, rawDocument.get("c")); + assertInstanceOf(RawBsonArray.class, rawDocument.get("d")); + assertInstanceOf(RawBsonDocument.class, rawDocument.get("d").asArray().get(0)); + assertInstanceOf(RawBsonArray.class, 
rawDocument.get("d").asArray().get(1)); + + assertTrue(rawDocument.getDocument("c").getBoolean("x").getValue()); + assertFalse(rawDocument.get("d").asArray().get(0).asDocument().getBoolean("y").getValue()); + assertEquals(1, rawDocument.get("d").asArray().get(1).asArray().get(0).asInt32().getValue()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("containsValue() finds existing values") + void containsValueShouldFindExistingValues(final RawBsonDocument rawDocument) { + assertTrue(rawDocument.containsValue(DOCUMENT.get("a"))); + assertTrue(rawDocument.containsValue(DOCUMENT.get("b"))); + assertTrue(rawDocument.containsValue(DOCUMENT.get("c"))); + assertTrue(rawDocument.containsValue(DOCUMENT.get("d"))); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("containsValue() does not find non-existing values") + void containsValueShouldNotFindNonExistingValues(final RawBsonDocument rawDocument) { + assertFalse(rawDocument.containsValue(new BsonInt32(3))); + assertFalse(rawDocument.containsValue(new BsonDocument("e", BsonBoolean.FALSE))); + assertFalse(rawDocument.containsValue(new BsonArray(asList(new BsonInt32(2), new BsonInt32(4))))); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("isEmpty() returns false when the document is not empty") + void isEmptyShouldReturnFalseForNonEmptyDocument(final RawBsonDocument rawDocument) { + assertFalse(rawDocument.isEmpty()); + } + + @Test + @DisplayName("isEmpty() returns true when the document is empty") + void isEmptyShouldReturnTrueForEmptyDocument() { + assertTrue(EMPTY_RAW_DOCUMENT.isEmpty()); + } + + @Test + @DisplayName("size() returns 0 for empty document") + void sizeShouldReturnZeroForEmptyDocument() { + assertEquals(0, EMPTY_RAW_DOCUMENT.size()); + } + + @Test + @DisplayName("keySet() is empty for empty document") + void keySetShouldBeEmptyForEmptyDocument() { + assertTrue(EMPTY_RAW_DOCUMENT.keySet().isEmpty()); + } + + 
@Test + @DisplayName("values() is empty for empty document") + void valuesShouldBeEmptyForEmptyDocument() { + assertTrue(EMPTY_RAW_DOCUMENT.values().isEmpty()); + } + + @Test + @DisplayName("entrySet() is empty for empty document") + void entrySetShouldBeEmptyForEmptyDocument() { + assertTrue(EMPTY_RAW_DOCUMENT.entrySet().isEmpty()); + } + + // Collection Views + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("size() returns correct count") + void sizeShouldReturnCorrectCount(final RawBsonDocument rawDocument) { + assertEquals(4, rawDocument.size()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("keySet() returns all keys") + void keySetShouldReturnAllKeys(final RawBsonDocument rawDocument) { + assertEquals(new HashSet<>(asList("a", "b", "c", "d")), rawDocument.keySet()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("values() returns all values") + void valuesShouldReturnAllValues(final RawBsonDocument rawDocument) { + assertEquals( + asList(DOCUMENT.get("a"), DOCUMENT.get("b"), DOCUMENT.get("c"), DOCUMENT.get("d")), + rawDocument.values()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("entrySet() returns all entries") + void entrySetShouldReturnAllEntries(final RawBsonDocument rawDocument) { + Set> expected = new HashSet<>(asList( + new AbstractMap.SimpleImmutableEntry<>("a", DOCUMENT.get("a")), + new AbstractMap.SimpleImmutableEntry<>("b", DOCUMENT.get("b")), + new AbstractMap.SimpleImmutableEntry<>("c", DOCUMENT.get("c")), + new AbstractMap.SimpleImmutableEntry<>("d", DOCUMENT.get("d")) + )); + assertEquals(expected, rawDocument.entrySet()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("getFirstKey() returns first key") + void getFirstKeyShouldReturnFirstKey(final RawBsonDocument rawDocument) { + assertEquals("a", rawDocument.getFirstKey()); + } + + @Test + @DisplayName("getFirstKey() throws 
NoSuchElementException for empty document") + void getFirstKeyShouldThrowForEmptyDocument() { + assertThrows(NoSuchElementException.class, () -> EMPTY_RAW_DOCUMENT.getFirstKey()); + } + + // Conversion and Serialization + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("asBsonReader() creates valid reader") + void asBsonReaderShouldWork(final RawBsonDocument rawDocument) { + try (BsonReader reader = rawDocument.asBsonReader()) { + BsonDocument decoded = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); + assertEquals(DOCUMENT, decoded); + } + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("toJson() returns equivalent JSON") + void toJsonShouldReturnEquivalentJson(final RawBsonDocument rawDocument) { + RawBsonDocument reparsed = new RawBsonDocumentCodec().decode( + new JsonReader(rawDocument.toJson()), DecoderContext.builder().build()); + assertEquals(DOCUMENT, reparsed); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("toJson() respects default JsonWriterSettings") + void toJsonShouldRespectDefaultSettings(final RawBsonDocument rawDocument) { + StringWriter writer = new StringWriter(); + new BsonDocumentCodec().encode(new JsonWriter(writer), DOCUMENT, EncoderContext.builder().build()); + assertEquals(writer.toString(), rawDocument.toJson()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("toJson() respects JsonWriterSettings") + void toJsonShouldRespectCustomSettings(final RawBsonDocument rawDocument) { + JsonWriterSettings settings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build(); + StringWriter writer = new StringWriter(); + new RawBsonDocumentCodec().encode(new JsonWriter(writer, settings), rawDocument, EncoderContext.builder().build()); + assertEquals(writer.toString(), rawDocument.toJson(settings)); + } + + // Immutability + + @Test + @DisplayName("All write methods throw 
UnsupportedOperationException") + void writeMethodsShouldThrow() { + RawBsonDocument rawDocument = createRawDocumentFromDocument(); + assertThrows(UnsupportedOperationException.class, () -> rawDocument.clear()); + assertThrows(UnsupportedOperationException.class, () -> rawDocument.put("x", BsonNull.VALUE)); + assertThrows(UnsupportedOperationException.class, () -> rawDocument.append("x", BsonNull.VALUE)); + assertThrows(UnsupportedOperationException.class, () -> rawDocument.putAll(new BsonDocument("x", BsonNull.VALUE))); + assertThrows(UnsupportedOperationException.class, () -> rawDocument.remove(BsonNull.VALUE)); + } + + // Decode + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("decode() returns equivalent document") + void decodeShouldWork(final RawBsonDocument rawDocument) { + assertEquals(DOCUMENT, rawDocument.decode(new BsonDocumentCodec())); + } + + // Equality and HashCode + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("hashCode() equals hash code of identical BsonDocument") + void hashCodeShouldEqualBsonDocumentHashCode(final RawBsonDocument rawDocument) { + assertEquals(DOCUMENT.hashCode(), rawDocument.hashCode()); + } + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("equals() works correctly") + void equalsShouldWork(final RawBsonDocument rawDocument) { + assertEquals(DOCUMENT, rawDocument); + assertEquals(DOCUMENT, rawDocument); + assertEquals(rawDocument, rawDocument); + assertNotEquals(EMPTY_RAW_DOCUMENT, rawDocument); + } + + // Clone + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("clone() creates deep copy") + void cloneShouldMakeDeepCopy(final RawBsonDocument rawDocument) { + RawBsonDocument cloned = (RawBsonDocument) rawDocument.clone(); + RawBsonDocument reference = createRawDocumentFromDocument(); + + assertNotSame(cloned.getByteBuffer().array(), reference.getByteBuffer().array()); + 
assertEquals(rawDocument.getByteBuffer().remaining(), cloned.getByteBuffer().remaining()); + assertEquals(reference, cloned); + } + + // Serialization + + @ParameterizedTest + @MethodSource("rawDocumentVariants") + @DisplayName("Java serialization works correctly") + void serializationShouldWork(final RawBsonDocument rawDocument) throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + new ObjectOutputStream(baos).writeObject(rawDocument); + Object deserialized = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())).readObject(); + assertEquals(DOCUMENT, deserialized); + } + + // --- Helper Methods --- + + static Stream rawDocumentVariants() { + return Stream.of( + createRawDocumentFromDocument(), + createRawDocumentFromByteArray(), + createRawDocumentFromByteArrayOffsetLength() + ); + } + + private static RawBsonDocument createRawDocumentFromDocument() { + return new RawBsonDocument(DOCUMENT, new BsonDocumentCodec()); + } + + private static RawBsonDocument createRawDocumentFromByteArray() { + return new RawBsonDocument(getBytesFromDocument()); + } + + private static RawBsonDocument createRawDocumentFromByteArrayOffsetLength() { + BasicOutputBuffer outputBuffer = new BasicOutputBuffer(1024); + new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), DOCUMENT, EncoderContext.builder().build()); + byte[] bytes = outputBuffer.getInternalBuffer(); + int size = outputBuffer.getPosition(); + + byte[] unstrippedBytes = new byte[size + 2]; + System.arraycopy(bytes, 0, unstrippedBytes, 1, size); + return new RawBsonDocument(unstrippedBytes, 1, size); + } + + private static byte[] getBytesFromDocument() { + BasicOutputBuffer outputBuffer = new BasicOutputBuffer(1024); + new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), DOCUMENT, EncoderContext.builder().build()); + byte[] bytes = outputBuffer.getInternalBuffer(); + int size = outputBuffer.getPosition(); + + byte[] strippedBytes = new byte[size]; + 
System.arraycopy(bytes, 0, strippedBytes, 0, size); + return strippedBytes; + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java index 4d4ebcc1169..1953154bf34 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java @@ -24,13 +24,15 @@ import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBinaryReader; +import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; +import org.bson.BsonElement; import org.bson.BsonReader; import org.bson.BsonType; import org.bson.BsonValue; import org.bson.ByteBuf; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.DecoderContext; +import org.bson.RawBsonDocument; +import org.bson.io.BasicOutputBuffer; import org.bson.io.ByteBufferBsonInput; import org.bson.json.JsonMode; import org.bson.json.JsonWriterSettings; @@ -41,13 +43,9 @@ import java.io.ObjectInputStream; import java.io.UnsupportedEncodingException; import java.nio.charset.StandardCharsets; -import java.util.AbstractCollection; -import java.util.AbstractMap; -import java.util.AbstractSet; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -186,7 +184,6 @@ public final class ByteBufBsonDocument extends BsonDocument implements Closeable */ private transient boolean closed; - /** * Creates a {@code ByteBufBsonDocument} from an OP_MSG command message. 
* @@ -302,12 +299,14 @@ public boolean isEmpty() { @Override public boolean containsKey(final Object key) { ensureOpen(); - if (cachedDocument != null) { - return cachedDocument.containsKey(key); - } if (key == null) { throw new IllegalArgumentException("key can not be null"); } + + if (cachedDocument != null) { + return cachedDocument.containsKey(key); + } + // Check sequence fields first (fast HashMap lookup), then scan body if (sequenceFields.containsKey(key)) { return true; @@ -382,67 +381,23 @@ public String getFirstKey() { @Override public Set> entrySet() { - ensureOpen(); - if (cachedDocument != null) { - return cachedDocument.entrySet(); - } - return new AbstractSet>() { - @Override - public Iterator> iterator() { - // Combine body entries with sequence entries - return new CombinedIterator<>(createBodyIterator(IteratorMode.ENTRIES), createSequenceEntryIterator()); - } - - @Override - public int size() { - return ByteBufBsonDocument.this.size(); - } - }; + return toBsonDocument().entrySet(); } @Override public Collection values() { - ensureOpen(); - if (cachedDocument != null) { - return cachedDocument.values(); - } - return new AbstractCollection() { - @Override - public Iterator iterator() { - return new CombinedIterator<>(createBodyIterator(IteratorMode.VALUES), createSequenceValueIterator()); - } - - @Override - public int size() { - return ByteBufBsonDocument.this.size(); - } - }; + return toBsonDocument().values(); } @Override public Set keySet() { - ensureOpen(); - if (cachedDocument != null) { - return cachedDocument.keySet(); - } - return new AbstractSet() { - @Override - public Iterator iterator() { - return new CombinedIterator<>(createBodyIterator(IteratorMode.KEYS), sequenceFields.keySet().iterator()); - } - - @Override - public int size() { - return ByteBufBsonDocument.this.size(); - } - }; + return toBsonDocument().keySet(); } // ==================== Conversion Methods ==================== @Override public BsonReader asBsonReader() { - 
ensureOpen(); // Must hydrate first since we need to include sequence fields return toBsonDocument().asBsonReader(); } @@ -452,9 +407,8 @@ public BsonReader asBsonReader() { * *

     * <p>After this method is called:</p>
     * <ul>
-    *     <li>The result is cached for future calls</li>
+    *     <li>The result is cached as a {@link RawBsonDocument} for future calls</li>
     *     <li>All underlying byte buffers are released</li>
-    *     <li>Sequence field documents are hydrated to regular {@code BsonDocument} instances</li>
     *     <li>All subsequent read operations use the cached document</li>
     * </ul>
     *
* @@ -464,21 +418,30 @@ public BsonReader asBsonReader() { public BsonDocument toBsonDocument() { ensureOpen(); if (cachedDocument == null) { - ByteBuf dup = bodyByteBuf.duplicate(); - try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { - // Decode body document - BsonDocument doc = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); - // Add hydrated sequence fields - for (Map.Entry entry : sequenceFields.entrySet()) { - doc.put(entry.getKey(), entry.getValue().toHydratedArray()); + if (sequenceFields.isEmpty()) { + // No sequence fields: clone body bytes directly + byte[] clonedBytes = new byte[bodyByteBuf.remaining()]; + bodyByteBuf.get(bodyByteBuf.position(), clonedBytes); + cachedDocument = new RawBsonDocument(clonedBytes); + } else { + // With sequence fields: pipe body + extra elements + try (BasicOutputBuffer buffer = new BasicOutputBuffer()) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryWriter writer = new BsonBinaryWriter(buffer); + BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + List extraElements = new ArrayList<>(); + for (Entry entry : sequenceFields.entrySet()) { + extraElements.add(new BsonElement(entry.getKey(), entry.getValue().asArray())); + } + writer.pipe(reader, extraElements); + } finally { + dup.release(); + } + cachedDocument = new RawBsonDocument(buffer.getInternalBuffer(), 0, buffer.getPosition()); } - cachedDocument = doc; - closed = true; - // Release buffers since we no longer need them - releaseResources(); - } finally { - dup.release(); } + closed = true; + releaseResources(); } return cachedDocument; } @@ -490,33 +453,31 @@ public String toJson() { @Override public String toJson(final JsonWriterSettings settings) { - ensureOpen(); return toBsonDocument().toJson(settings); } @Override public String toString() { - ensureOpen(); return toBsonDocument().toString(); } @SuppressWarnings("MethodDoesntCallSuperMethod") @Override public 
BsonDocument clone() { - ensureOpen(); - return toBsonDocument().clone(); + if (cachedDocument != null) { + return cachedDocument.clone(); + } + return toBsonDocument(); } @SuppressWarnings("EqualsDoesntCheckParameterClass") @Override public boolean equals(final Object o) { - ensureOpen(); return toBsonDocument().equals(o); } @Override public int hashCode() { - ensureOpen(); return toBsonDocument().hashCode(); } @@ -572,6 +533,7 @@ public void clear() { // ==================== Private Body Field Operations ==================== // These methods read from bodyByteBuf using a temporary duplicate buffer + // Must be guarded by `ensureOpen` /** * Searches the body for a field with the given key. @@ -645,8 +607,8 @@ private BsonValue getValueFromBody(final String key) { } /** - * Gets the first key from the body, or from sequence fields if body is empty. - * Throws NoSuchElementException if the document is completely empty. + * Gets the first key from the body. + * Throws NoSuchElementException if the body document is completely empty. */ private String getFirstKeyFromBody() { ByteBuf dup = bodyByteBuf.duplicate(); @@ -655,10 +617,6 @@ private String getFirstKeyFromBody() { if (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { return reader.readName(); } - // Body is empty, try sequence fields - if (!sequenceFields.isEmpty()) { - return sequenceFields.keySet().iterator().next(); - } throw new NoSuchElementException(); } finally { dup.release(); @@ -697,147 +655,6 @@ private int countBodyFields() { } } - // ==================== Iterator Support ==================== - - /** - * Mode for the body iterator, determining what type of elements it produces. - */ - private enum IteratorMode { ENTRIES, KEYS, VALUES } - - /** - * Creates an iterator over the body document fields. - * - *

The iterator creates a duplicated ByteBuf that is temporarily tracked for safety. - * When iteration completes normally, the buffer is released immediately and removed from tracking. - * This prevents accumulation of finished iterator buffers while ensuring cleanup if the parent - * document is closed before iteration completes.

- * - * @param mode Determines whether to return entries, keys, or values - * @return An iterator of the appropriate type - */ - @SuppressWarnings("unchecked") - private Iterator createBodyIterator(final IteratorMode mode) { - return new Iterator() { - private final Closeable resourceHandle; - private ByteBuf duplicatedByteBuf; - private BsonBinaryReader reader; - private boolean started; - private boolean finished; - - { - // Create duplicate buffer for iteration and track it temporarily - duplicatedByteBuf = bodyByteBuf.duplicate(); - resourceHandle = () -> { - if (duplicatedByteBuf != null) { - try { - if (reader != null) { - reader.close(); - } - } catch (Exception e) { - // Ignore - } - duplicatedByteBuf.release(); - duplicatedByteBuf = null; - reader = null; - } - }; - trackedResources.add(resourceHandle); - reader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); - } - - @Override - public boolean hasNext() { - if (finished) { - return false; - } - ensureOpen(); - if (!started) { - reader.readStartDocument(); - reader.readBsonType(); - started = true; - } - boolean hasNext = reader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; - if (!hasNext) { - cleanup(); - } - return hasNext; - } - - @Override - public T next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - ensureOpen(); - String key = reader.readName(); - BsonValue value = readBsonValue(duplicatedByteBuf, reader, trackedResources); - reader.readBsonType(); - - switch (mode) { - case ENTRIES: - return (T) new AbstractMap.SimpleImmutableEntry<>(key, value); - case KEYS: - return (T) key; - case VALUES: - return (T) value; - default: - throw new IllegalStateException("Unknown iterator mode: " + mode); - } - } - - private void cleanup() { - if (!finished) { - finished = true; - // Remove from tracked resources since we're cleaning up immediately - trackedResources.remove(resourceHandle); - try { - resourceHandle.close(); - } catch (Exception e) { - // Ignore - } - } - 
} - }; - } - - /** - * Creates an iterator over sequence fields as map entries. - * Each entry contains the field name and its array value. - */ - private Iterator> createSequenceEntryIterator() { - Iterator> iter = sequenceFields.entrySet().iterator(); - return new Iterator>() { - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public Entry next() { - Entry entry = iter.next(); - return new AbstractMap.SimpleImmutableEntry<>(entry.getKey(), entry.getValue().asArray()); - } - }; - } - - /** - * Creates an iterator over sequence field values (arrays). - */ - private Iterator createSequenceValueIterator() { - Iterator iter = sequenceFields.values().iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public BsonValue next() { - return iter.next().asArray(); - } - }; - } - // ==================== Resource Management Helpers ==================== /** @@ -977,69 +794,6 @@ boolean containsValue(final Object value) { return value instanceof BsonValue && asArray().asArray().contains(value); } - /** - * Converts this sequence to a BsonArray of regular BsonDocument instances. - * - *

Used by {@link ByteBufBsonDocument#toBsonDocument()} to fully hydrate the document. - * Unlike {@link #asArray()}, this creates regular BsonDocument instances, not - * ByteBufBsonDocument wrappers.

- * - * @return A BsonArray containing fully deserialized BsonDocument instances - */ - BsonArray toHydratedArray() { - ByteBuf dup = sequenceByteBuf.duplicate(); - try { - List hydratedDocs = new ArrayList<>(); - while (dup.hasRemaining()) { - int docStart = dup.position(); - int docSize = dup.getInt(); - int docEnd = docStart + docSize; - ByteBuf docBuf = dup.duplicate().position(docStart).limit(docEnd); - try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(docBuf))) { - hydratedDocs.add(new BsonDocumentCodec().decode(reader, DecoderContext.builder().build())); - } finally { - docBuf.release(); - } - dup.position(docEnd); - } - return new BsonArray(hydratedDocs); - } finally { - dup.release(); - } - } } - /** - * An iterator that combines two iterators sequentially. - * - *

Used to merge body field iteration with sequence field iteration, - * presenting a unified view of all document fields.

- * - * @param The type of elements returned by the iterator - */ - private static final class CombinedIterator implements Iterator { - private final Iterator primary; - private final Iterator secondary; - - CombinedIterator(final Iterator primary, final Iterator secondary) { - this.primary = primary; - this.secondary = secondary; - } - - @Override - public boolean hasNext() { - return primary.hasNext() || secondary.hasNext(); - } - - @Override - public T next() { - if (primary.hasNext()) { - return primary.next(); - } - if (secondary.hasNext()) { - return secondary.next(); - } - throw new NoSuchElementException(); - } - } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java index 637f89cb347..4b05607a56f 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java @@ -49,8 +49,6 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.Date; import java.util.Iterator; @@ -351,17 +349,19 @@ void testAllBsonTypes() { } static ByteBufBsonArray fromBsonValues(final List values) { - BsonDocument document = new BsonDocument() - .append("a", new BsonArray(values)); + BsonDocument document = new BsonDocument("a", new BsonArray(values)); BasicOutputBuffer buffer = new BasicOutputBuffer(); new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try { - buffer.pipe(baos); - } catch (IOException e) { - throw new RuntimeException("impossible!"); - } - ByteBuf documentByteBuf = new ByteBufNIO(ByteBuffer.wrap(baos.toByteArray())); - return (ByteBufBsonArray) new 
ByteBufBsonDocument(documentByteBuf).entrySet().iterator().next().getValue(); + byte[] bytes = new byte[buffer.getPosition()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, bytes.length); + // Skip past the outer document header to the array value bytes. + // Document format: [4-byte size][type byte (0x04)][field name "a\0"][array bytes...][0x00] + int arrayOffset = 4 + 1 + 2; // doc size + type byte + "a" + null terminator + int arraySize = (bytes[arrayOffset] & 0xFF) + | ((bytes[arrayOffset + 1] & 0xFF) << 8) + | ((bytes[arrayOffset + 2] & 0xFF) << 16) + | ((bytes[arrayOffset + 3] & 0xFF) << 24); + ByteBuf arrayByteBuf = new ByteBufNIO(ByteBuffer.wrap(bytes, arrayOffset, arraySize)); + return new ByteBufBsonArray(arrayByteBuf); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java index 1f61f309d14..f3744057a18 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java @@ -62,6 +62,7 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -321,7 +322,10 @@ void toStringShouldWork() { void cloneShouldWork() { try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) { BsonDocument cloned = byteBufDocument.clone(); + assertNotSame(byteBufDocument, cloned); assertEquals(byteBufDocument, cloned); + + assertNotSame(byteBufDocument.clone(), byteBufDocument.clone()); } } @@ -438,52 
+442,32 @@ void deeplyNestedClosedRecursively() { } @Test - @DisplayName("Iteration tracks resources correctly") - void iterationTracksResources() { + @DisplayName("Iterators work as expected") + void iteratorsWorksAsExpected() { BsonDocument doc = new BsonDocument() .append("doc1", new BsonDocument("a", new BsonInt32(1))) .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3)))) .append("primitive", new BsonString("test")); - ByteBuf buf = createByteBufFromDocument(doc); - ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); - - int count = 0; - for (Map.Entry entry : byteBufDoc.entrySet()) { - assertNotNull(entry.getKey()); - assertNotNull(entry.getValue()); - count++; - } - assertEquals(3, count); + try (ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(createByteBufFromDocument(doc))) { - byteBufDoc.close(); - assertThrows(IllegalStateException.class, byteBufDoc::size); - } - - @Test - @DisplayName("Iterators ensure the resource is still open") - void iteratorsEnsureResourceIsStillOpen() { - BsonDocument doc = new BsonDocument() - .append("doc1", new BsonDocument("a", new BsonInt32(1))) - .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3)))) - .append("primitive", new BsonString("test")); - - ByteBuf buf = createByteBufFromDocument(doc); - ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); - - Iterator keysIterator = byteBufDoc.keySet().iterator(); - assertDoesNotThrow(keysIterator::hasNext); + int count = 0; + for (Map.Entry entry : byteBufDoc.entrySet()) { + assertNotNull(entry.getKey()); + assertNotNull(entry.getValue()); + count++; + } + assertEquals(3, count); - Iterator nestedKeysIterator = byteBufDoc.getDocument("doc1").keySet().iterator(); - assertDoesNotThrow(nestedKeysIterator::hasNext); + Iterator keysIterator = byteBufDoc.keySet().iterator(); + assertDoesNotThrow(keysIterator::hasNext); - Iterator arrayIterator = byteBufDoc.getArray("arr1").iterator(); - 
assertDoesNotThrow(arrayIterator::hasNext); + Iterator nestedKeysIterator = byteBufDoc.getDocument("doc1").keySet().iterator(); + assertDoesNotThrow(nestedKeysIterator::hasNext); - byteBufDoc.close(); - assertThrows(IllegalStateException.class, keysIterator::hasNext); - assertThrows(IllegalStateException.class, nestedKeysIterator::hasNext); - assertThrows(IllegalStateException.class, arrayIterator::hasNext); + Iterator arrayIterator = byteBufDoc.getArray("arr1").iterator(); + assertDoesNotThrow(arrayIterator::hasNext); + } } @Test From f91b10e4316b0e93b981152d9922734563335e0a Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 24 Mar 2026 15:07:24 +0000 Subject: [PATCH 3/4] Merge main into ByteBuf (#1907) * Revert NettyByteBuf.asReadOnly change (#1871) Originally introduced in 057649fd This change had the unintended side effect of leaking netty ByteBufs when logging. JAVA-5982 * Update Netty dependency to the latest version. (#1867) JAVA-5818 * Adjust timeout handling in client-side operations to account for RTT variations (#1793) JAVA-5375 --------- Co-authored-by: Ross Lawley * Update Snappy version for the latest security fixes. (#1868) JAVA-6069 * Bson javadoc improvements (#1883) Fixed no comment warning in BinaryVector Improved BsonBinary asUuid documentation Improved BsonBinarySubType isUuid documentation JAVA-6086 * Make NettyByteBuf share parent reference count. (#1891) JAVA-6107 * JAVA-5907 (#1893) * JAVA-5907 * JAVA-5907 use execute within executor service If we don't use the return value from executor then we should use `execute` instead of `submit` * format * revert error log for netty leak --------- Co-authored-by: Almas Abdrazak * Fix RawBsonDocument encoding performance regression (#1888) Add instanceof check in BsonDocumentCodec to route RawBsonDocument to RawBsonDocumentCodec, restoring efficient byte-copy encoding. Previous BsonType-based lookup led to sub-optimal performance as it could not distinguish RawBsonDocument from BsonDocument. 
JAVA-6101 * Update specifications to latest (#1884) JAVA-6092 * Evergreen atlas search fix (#1894) Update evergreen atlas-deployed-task-group configuration Assume test secrets and follow the driver-evergreen-tools atlas recommended usage: https://github.com/mongodb-labs/drivers-evergreen-tools/tree/master/.evergreen/atlas#usage JAVA-6103 * [JAVA-6028] Add Micrometer/OpenTelemetry tracing support to the reactive-streams (#1898) * Add Micrometer/OpenTelemetry tracing support to the reactive-streams driver https://jira.mongodb.org/browse/JAVA-6028 Port the tracing infrastructure from the sync driver to driver-reactive-streams, reusing the existing driver-core, TracingManager, Span, and TraceContext classes. * Move error handling and span lifecycle (span.error(), span.end()) from Reactor's doOnError/doFinally operators into the async callback, before emitting the result to the subscriber. * Making sure span is properly closed when an exception occurs * Clone command event document before storing to prevent use-after-free. 
(#1901) * Version: bump 5.7.0-beta1 * Version: bump 5.7.0-SNAPSHOT * Remove unneeded variable --------- Co-authored-by: Viacheslav Babanin Co-authored-by: Almas Abdrazak Co-authored-by: Almas Abdrazak Co-authored-by: Nabil Hachicha Co-authored-by: Nabil Hachicha <1793238+nhachicha@users.noreply.github.com> --- .evergreen/.evg.yml | 8 +- bson/src/main/org/bson/BinaryVector.java | 3 + bson/src/main/org/bson/BsonBinary.java | 16 +- bson/src/main/org/bson/BsonBinarySubType.java | 2 +- .../org/bson/codecs/BsonDocumentCodec.java | 6 + .../src/test/unit/util/ThreadTestHelpers.java | 2 +- .../AbstractBsonDocumentBenchmark.java | 2 +- .../benchmark/benchmarks/BenchmarkSuite.java | 3 + .../GridFSMultiFileDownloadBenchmark.java | 6 +- .../GridFSMultiFileUploadBenchmark.java | 2 +- .../benchmarks/MultiFileExportBenchmark.java | 6 +- .../benchmarks/MultiFileImportBenchmark.java | 4 +- .../RawBsonArrayEncodingBenchmark.java | 55 +++++++ .../RawBsonNestedEncodingBenchmark.java | 46 ++++++ .../framework/MongoCryptBenchmarkRunner.java | 2 +- .../connection/DefaultConnectionPool.java | 2 +- .../connection/InternalStreamConnection.java | 54 ++++-- .../connection/netty/NettyByteBuf.java | 4 +- .../connection/netty/NettyStream.java | 3 +- .../async/AsynchronousTlsChannel.java | 28 ++-- .../async/AsynchronousTlsChannelGroup.java | 15 +- .../micrometer/TracingManager.java | 45 +++++ .../CommandHelperSpecification.groovy | 2 +- .../connection/DefaultConnectionPoolTest.java | 4 +- .../mongodb/internal/TimeoutContextTest.java | 7 +- .../kotlin/client/coroutine/ClientSession.kt | 4 + driver-reactive-streams/build.gradle.kts | 15 +- .../reactivestreams/client/ClientSession.java | 11 ++ .../client/internal/ClientSessionHelper.java | 8 +- .../internal/ClientSessionPublisherImpl.java | 39 ++++- .../client/internal/MongoClientImpl.java | 7 +- .../internal/OperationExecutorImpl.java | 59 +++++-- .../client/internal/TimeoutHelper.java | 15 +- .../gridfs/GridFSUploadPublisherImpl.java | 10 +- 
.../ClientSideOperationTimeoutProseTest.java | 133 +++++++++++---- .../observability/MicrometerProseTest.java | 32 ++++ .../client/syncadapter/SyncClientSession.java | 2 +- .../client/unified/MicrometerTracingTest.java | 27 +++ .../client/internal/MongoClientImplTest.java | 3 +- .../client/internal/MongoClusterImpl.java | 62 +------ ...tClientSideOperationsTimeoutProseTest.java | 150 +++++++++-------- .../AbstractMicrometerProseTest.java} | 155 +++++++++++++++--- .../client/AbstractSessionsProseTest.java | 2 +- ...eOperationsEncryptionTimeoutProseTest.java | 27 ++- .../observability/MicrometerProseTest.java | 32 ++++ .../{ => client}/observability/SpanTree.java | 6 +- .../mongodb/client/unified/UnifiedTest.java | 2 +- .../unified/UnifiedTestModifications.java | 19 +++ gradle/libs.versions.toml | 4 +- testing/resources/specifications | 2 +- 50 files changed, 876 insertions(+), 277 deletions(-) create mode 100644 driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonArrayEncodingBenchmark.java create mode 100644 driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonNestedEncodingBenchmark.java create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java rename driver-sync/src/test/functional/com/mongodb/{observability/MicrometerProseTest.java => client/AbstractMicrometerProseTest.java} (57%) create mode 100644 driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java rename driver-sync/src/test/functional/com/mongodb/{ => client}/observability/SpanTree.java (98%) diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index 525861928f3..da9d720de40 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -1939,16 +1939,18 @@ task_groups: setup_group: - func: "fetch-source" - func: "prepare-resources" 
+ - func: "assume-aws-test-secrets-role" - command: subprocess.exec type: "setup" params: working_dir: "src" binary: bash - add_expansions_to_env: true + include_expansions_in_env: [ "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN" ] env: + CLUSTER_PREFIX: "dbx-java" MONGODB_VERSION: "8.0" args: - - ${DRIVERS_TOOLS}/.evergreen/atlas/setup-atlas-cluster.sh + - ${DRIVERS_TOOLS}/.evergreen/atlas/setup.sh - command: expansions.update params: file: src/atlas-expansion.yml @@ -1960,7 +1962,7 @@ task_groups: binary: bash add_expansions_to_env: true args: - - ${DRIVERS_TOOLS}/.evergreen/atlas/teardown-atlas-cluster.sh + - ${DRIVERS_TOOLS}/.evergreen/atlas/teardown.sh tasks: - "atlas-search-index-management-task" - "aws-lambda-deployed-task" diff --git a/bson/src/main/org/bson/BinaryVector.java b/bson/src/main/org/bson/BinaryVector.java index a1914601a9d..f5d57f5b241 100644 --- a/bson/src/main/org/bson/BinaryVector.java +++ b/bson/src/main/org/bson/BinaryVector.java @@ -32,6 +32,9 @@ * @since 5.3 */ public abstract class BinaryVector { + /** + * The BinaryVector logger + */ protected static final Logger LOGGER = Loggers.getLogger("BinaryVector"); private final DataType dataType; diff --git a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index 833a1b5ad29..0ece148eb2d 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -127,9 +127,14 @@ public BsonBinary(final UUID uuid, final UuidRepresentation uuidRepresentation) } /** - * Returns the binary as a UUID. The binary type must be 4. + * Returns the binary as a UUID. + * + *

Note:The BsonBinary subtype must be {@link BsonBinarySubType#UUID_STANDARD}.

* * @return the uuid + * @throws BsonInvalidOperationException if BsonBinary subtype is not {@link BsonBinarySubType#UUID_STANDARD} + * @see #asUuid(UuidRepresentation) + * @see BsonBinarySubType * @since 3.9 */ public UUID asUuid() { @@ -162,8 +167,15 @@ public BinaryVector asVector() { /** * Returns the binary as a UUID. * - * @param uuidRepresentation the UUID representation + *

Note:The BsonBinary subtype must be either {@link BsonBinarySubType#UUID_STANDARD} or + * {@link BsonBinarySubType#UUID_LEGACY}.

+ * + * @param uuidRepresentation the UUID representation, must be {@link UuidRepresentation#STANDARD} or + * {@link UuidRepresentation#JAVA_LEGACY} * @return the uuid + * @throws BsonInvalidOperationException if the BsonBinary subtype is incompatible with the given {@code uuidRepresentation}, or if + * the {@code uuidRepresentation} is not {@link UuidRepresentation#STANDARD} or + * {@link UuidRepresentation#JAVA_LEGACY}. * @since 3.9 */ public UUID asUuid(final UuidRepresentation uuidRepresentation) { diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index 08c29e2ef09..2a6eed1f5de 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -93,7 +93,7 @@ public enum BsonBinarySubType { * Returns true if the given value is a UUID subtype. * * @param value the subtype value as a byte. - * @return true if value is a UUID subtype. + * @return true if value has a {@link #UUID_STANDARD} or {@link #UUID_LEGACY} subtype. 
* @since 3.4 */ public static boolean isUuid(final byte value) { diff --git a/bson/src/main/org/bson/codecs/BsonDocumentCodec.java b/bson/src/main/org/bson/codecs/BsonDocumentCodec.java index 75bd3b7a2b0..172b0c94338 100644 --- a/bson/src/main/org/bson/codecs/BsonDocumentCodec.java +++ b/bson/src/main/org/bson/codecs/BsonDocumentCodec.java @@ -22,6 +22,7 @@ import org.bson.BsonType; import org.bson.BsonValue; import org.bson.BsonWriter; +import org.bson.RawBsonDocument; import org.bson.codecs.configuration.CodecRegistry; import org.bson.types.ObjectId; @@ -40,6 +41,7 @@ public class BsonDocumentCodec implements CollectibleCodec { private static final String ID_FIELD_NAME = "_id"; private static final CodecRegistry DEFAULT_REGISTRY = fromProviders(new BsonValueCodecProvider()); private static final BsonTypeCodecMap DEFAULT_BSON_TYPE_CODEC_MAP = new BsonTypeCodecMap(getBsonTypeClassMap(), DEFAULT_REGISTRY); + private static final RawBsonDocumentCodec RAW_BSON_DOCUMENT_CODEC = new RawBsonDocumentCodec(); private final CodecRegistry codecRegistry; private final BsonTypeCodecMap bsonTypeCodecMap; @@ -101,6 +103,10 @@ protected BsonValue readValue(final BsonReader reader, final DecoderContext deco @Override public void encode(final BsonWriter writer, final BsonDocument value, final EncoderContext encoderContext) { + if (value instanceof RawBsonDocument) { + RAW_BSON_DOCUMENT_CODEC.encode(writer, (RawBsonDocument) value, encoderContext); + return; + } writer.writeStartDocument(); beforeFields(writer, encoderContext, value); diff --git a/bson/src/test/unit/util/ThreadTestHelpers.java b/bson/src/test/unit/util/ThreadTestHelpers.java index e2115da079f..2428ee9074e 100644 --- a/bson/src/test/unit/util/ThreadTestHelpers.java +++ b/bson/src/test/unit/util/ThreadTestHelpers.java @@ -41,7 +41,7 @@ public static void executeAll(final Runnable... 
runnables) { CountDownLatch latch = new CountDownLatch(runnables.length); List failures = Collections.synchronizedList(new ArrayList<>()); for (final Runnable runnable : runnables) { - service.submit(() -> { + service.execute(() -> { try { runnable.run(); } catch (Throwable e) { diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/AbstractBsonDocumentBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/AbstractBsonDocumentBenchmark.java index 89f932f03cd..78e6e37f7f9 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/AbstractBsonDocumentBenchmark.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/AbstractBsonDocumentBenchmark.java @@ -61,7 +61,7 @@ public int getBytesPerRun() { return fileLength * NUM_INTERNAL_ITERATIONS; } - private byte[] getDocumentAsBuffer(final T document) throws IOException { + protected byte[] getDocumentAsBuffer(final T document) throws IOException { BasicOutputBuffer buffer = new BasicOutputBuffer(); codec.encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/BenchmarkSuite.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/BenchmarkSuite.java index 2595568f148..c2a8ed9bafe 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/BenchmarkSuite.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/BenchmarkSuite.java @@ -71,6 +71,9 @@ private static void runBenchmarks() runBenchmark(new BsonDecodingBenchmark<>("Deep", "extended_bson/deep_bson.json", DOCUMENT_CODEC)); runBenchmark(new BsonDecodingBenchmark<>("Full", "extended_bson/full_bson.json", DOCUMENT_CODEC)); + runBenchmark(new RawBsonNestedEncodingBenchmark("Full RawBsonDocument in BsonDocument BSON Encoding", "extended_bson/full_bson.json")); + runBenchmark(new RawBsonArrayEncodingBenchmark("Full RawBsonDocument Array in BsonDocument BSON Encoding", 
"extended_bson/full_bson.json", 10)); + runBenchmark(new RunCommandBenchmark<>(DOCUMENT_CODEC)); runBenchmark(new FindOneBenchmark("single_and_multi_document/tweet.json", BenchmarkSuite.DOCUMENT_CLASS)); diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileDownloadBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileDownloadBenchmark.java index e39c0fb46ba..f8f66fe8b90 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileDownloadBenchmark.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileDownloadBenchmark.java @@ -97,7 +97,7 @@ public void run() throws Exception { CountDownLatch latch = new CountDownLatch(50); for (int i = 0; i < 50; i++) { - gridFSService.submit(exportFile(latch, i)); + gridFSService.execute(exportFile(latch, i)); } latch.await(1, TimeUnit.MINUTES); @@ -107,7 +107,7 @@ private Runnable exportFile(final CountDownLatch latch, final int fileId) { return () -> { UnsafeByteArrayOutputStream outputStream = new UnsafeByteArrayOutputStream(5242880); bucket.downloadToStream(GridFSMultiFileDownloadBenchmark.this.getFileName(fileId), outputStream); - fileService.submit(() -> { + fileService.execute(() -> { try { FileOutputStream fos = new FileOutputStream(new File(tempDirectory, String.format("%02d", fileId) + ".txt")); fos.write(outputStream.getByteArray()); @@ -124,7 +124,7 @@ private void importFiles() throws Exception { CountDownLatch latch = new CountDownLatch(50); for (int i = 0; i < 50; i++) { - fileService.submit(importFile(latch, i)); + fileService.execute(importFile(latch, i)); } latch.await(1, TimeUnit.MINUTES); diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileUploadBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileUploadBenchmark.java index cefdc7eaf1c..e2ee177847d 100644 --- 
a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileUploadBenchmark.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/GridFSMultiFileUploadBenchmark.java @@ -75,7 +75,7 @@ public void run() throws Exception { CountDownLatch latch = new CountDownLatch(50); for (int i = 0; i < 50; i++) { - fileService.submit(importFile(latch, i)); + fileService.execute(importFile(latch, i)); } latch.await(1, TimeUnit.MINUTES); diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileExportBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileExportBenchmark.java index 30c74084419..d57829de45b 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileExportBenchmark.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileExportBenchmark.java @@ -109,7 +109,7 @@ public void run() throws Exception { CountDownLatch latch = new CountDownLatch(100); for (int i = 0; i < 100; i++) { - documentReadingService.submit(exportJsonFile(i, latch)); + documentReadingService.execute(exportJsonFile(i, latch)); } latch.await(1, TimeUnit.MINUTES); @@ -125,7 +125,7 @@ private Runnable exportJsonFile(final int fileId, final CountDownLatch latch) { List documents = collection.find(new BsonDocument("fileId", new BsonInt32(fileId))) .batchSize(5000) .into(new ArrayList<>(5000)); - fileWritingService.submit(writeJsonFile(fileId, documents, latch)); + fileWritingService.execute(writeJsonFile(fileId, documents, latch)); }; } @@ -154,7 +154,7 @@ private void importJsonFiles() throws InterruptedException { for (int i = 0; i < 100; i++) { int fileId = i; - importService.submit(() -> { + importService.execute(() -> { String resourcePath = "parallel/ldjson_multi/ldjson" + String.format("%03d", fileId) + ".txt"; try (BufferedReader reader = new BufferedReader(readFromRelativePath(resourcePath), 1024 * 64)) { String json; diff --git 
a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileImportBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileImportBenchmark.java index 03d1a721bee..d7afc54496d 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileImportBenchmark.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/MultiFileImportBenchmark.java @@ -86,7 +86,7 @@ public void run() throws InterruptedException { CountDownLatch latch = new CountDownLatch(500); for (int i = 0; i < 100; i++) { - fileReadingService.submit(importJsonFile(latch, i)); + fileReadingService.execute(importJsonFile(latch, i)); } latch.await(1, TimeUnit.MINUTES); @@ -104,7 +104,7 @@ private Runnable importJsonFile(final CountDownLatch latch, final int fileId) { documents.add(document); if (documents.size() == 1000) { List documentsToInsert = documents; - documentWritingService.submit(() -> { + documentWritingService.execute(() -> { collection.insertMany(documentsToInsert, new InsertManyOptions().ordered(false)); latch.countDown(); }); diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonArrayEncodingBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonArrayEncodingBenchmark.java new file mode 100644 index 00000000000..0768f4f63c6 --- /dev/null +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonArrayEncodingBenchmark.java @@ -0,0 +1,55 @@ +/* + * Copyright 2016-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.mongodb.benchmark.benchmarks; + +import org.bson.BsonArray;import org.bson.BsonDocument; +import org.bson.RawBsonDocument; +import org.bson.codecs.BsonDocumentCodec; + +import java.io.IOException; + +public class RawBsonArrayEncodingBenchmark extends BsonEncodingBenchmark { + + private final int arraySize; + + public RawBsonArrayEncodingBenchmark(final String name, final String resourcePath, final int arraySize) { + super(name, resourcePath, new BsonDocumentCodec()); + this.arraySize = arraySize; + } + + @Override + public void setUp() throws IOException { + super.setUp(); + RawBsonDocument rawDoc = new RawBsonDocument(document, codec); + + BsonArray array = new BsonArray(); + for (int i = 0; i < arraySize; i++) { + array.add(rawDoc); + } + document = new BsonDocument("results", array); + + // Recalculate documentBytes for accurate throughput reporting + documentBytes = getDocumentAsBuffer(document); + + } + + @Override + public int getBytesPerRun() { + return documentBytes.length * NUM_INTERNAL_ITERATIONS; + } +} \ No newline at end of file diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonNestedEncodingBenchmark.java b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonNestedEncodingBenchmark.java new file mode 100644 index 00000000000..3872c5888d9 --- /dev/null +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/benchmarks/RawBsonNestedEncodingBenchmark.java @@ -0,0 +1,46 @@ +/* + * Copyright 2016-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.mongodb.benchmark.benchmarks; + +import org.bson.BsonDocument; +import org.bson.RawBsonDocument; +import org.bson.codecs.BsonDocumentCodec; + +import java.io.IOException; + +public class RawBsonNestedEncodingBenchmark extends BsonEncodingBenchmark { + + public RawBsonNestedEncodingBenchmark(final String name, final String resourcePath) { + super(name, resourcePath, new BsonDocumentCodec()); + } + + @Override + public void setUp() throws IOException { + super.setUp(); + + RawBsonDocument rawDoc = new RawBsonDocument(document, codec); + document = new BsonDocument("nested", rawDoc); + + documentBytes = getDocumentAsBuffer(document); + } + + @Override + public int getBytesPerRun() { + return documentBytes.length * NUM_INTERNAL_ITERATIONS; + } +} \ No newline at end of file diff --git a/driver-benchmarks/src/main/com/mongodb/benchmark/framework/MongoCryptBenchmarkRunner.java b/driver-benchmarks/src/main/com/mongodb/benchmark/framework/MongoCryptBenchmarkRunner.java index 718ab9f21af..a6c623364db 100644 --- a/driver-benchmarks/src/main/com/mongodb/benchmark/framework/MongoCryptBenchmarkRunner.java +++ b/driver-benchmarks/src/main/com/mongodb/benchmark/framework/MongoCryptBenchmarkRunner.java @@ -177,7 +177,7 @@ public List run() throws InterruptedException { for (int i = 0; i < threadCount; i++) { DecryptTask decryptTask = new DecryptTask(mongoCrypt, encrypted, NUM_SECS, doneSignal); decryptTasks.add(decryptTask); - executorService.submit(decryptTask); + executorService.execute(decryptTask); } // Await completion of all tasks. 
Tasks are expected to complete shortly after NUM_SECS. Time out `await` if time exceeds 2 * NUM_SECS. diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java index 81a0e59e277..2339cf18b86 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java @@ -1321,7 +1321,7 @@ private boolean initUnlessClosed() { boolean result = true; if (state == State.NEW) { worker = Executors.newSingleThreadExecutor(new DaemonThreadFactory("AsyncGetter")); - worker.submit(() -> runAndLogUncaught(this::workerRun)); + worker.execute(() -> runAndLogUncaught(this::workerRun)); state = State.INITIALIZED; } else if (state == State.CLOSED) { result = false; diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index aeef4e0a6a1..f10f471881b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -454,9 +454,7 @@ private T sendAndReceiveInternal(final CommandMessage message, final Decoder () -> getDescription().getConnectionId() ); - boolean isLoggingCommandNeeded = isLoggingCommandNeeded(); - - if (isLoggingCommandNeeded) { + if (isLoggingCommandNeeded()) { commandEventSender = new LoggingCommandEventSender( SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, operationContext, message, commandDocument, @@ -618,37 +616,71 @@ private void sendAndReceiveAsyncInternal(final CommandMessage message, final // Async try with resources release after the write ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this); + Span tracingSpan = null; try { 
message.encode(bsonOutput, operationContext); - - String commandName; CommandEventSender commandEventSender; try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) { - commandName = commandDocument.getFirstKey(); + tracingSpan = operationContext + .getTracingManager() + .createTracingSpan(message, + operationContext, + commandDocument, + cmdName -> SECURITY_SENSITIVE_COMMANDS.contains(cmdName) + || SECURITY_SENSITIVE_HELLO_COMMANDS.contains(cmdName), + () -> getDescription().getServerAddress(), + () -> getDescription().getConnectionId() + ); + if (isLoggingCommandNeeded()) { commandEventSender = new LoggingCommandEventSender( SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, operationContext, message, commandDocument, COMMAND_PROTOCOL_LOGGER, loggerSettings); + commandEventSender.sendStartedEvent(); } else { commandEventSender = new NoOpCommandEventSender(); } - commandEventSender.sendStartedEvent(); - } - List messageByteBuffers = getMessageByteBuffers(commandName, message, bsonOutput, operationContext); + boolean isTracingCommandPayloadNeeded = tracingSpan != null && operationContext.getTracingManager().isCommandPayloadEnabled(); + if (isTracingCommandPayloadNeeded) { + tracingSpan.tagHighCardinality(QUERY_TEXT.asString(), commandDocument); + } + + final Span commandSpan = tracingSpan; + SingleResultCallback tracingCallback = commandSpan == null ? 
callback : (result, t) -> { + try { + if (t != null) { + if (t instanceof MongoCommandException) { + commandSpan.tagLowCardinality( + RESPONSE_STATUS_CODE.withValue(String.valueOf(((MongoCommandException) t).getErrorCode()))); + } + commandSpan.error(t); + } + } finally { + commandSpan.end(); + callback.onResult(result, t); + } + }; + + List messageByteBuffers = getMessageByteBuffers(commandDocument.getFirstKey(), message, bsonOutput, operationContext); sendCommandMessageAsync(messageByteBuffers, message.getId(), decoder, operationContext, commandEventSender, message.isResponseExpected(), (r, t) -> { ResourceUtil.release(messageByteBuffers); bsonOutput.close(); // Close AFTER async write completes if (t != null) { - callback.onResult(null, t); + tracingCallback.onResult(null, t); } else { - callback.onResult(r, null); + tracingCallback.onResult(r, null); } }); + } } catch (Throwable t) { bsonOutput.close(); + if (tracingSpan != null) { + tracingSpan.error(t); + tracingSpan.end(); + } callback.onResult(null, t); } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java index 72235b46760..99233dcc77e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java @@ -251,12 +251,12 @@ public ByteBuf limit(final int newLimit) { @Override public ByteBuf asReadOnly() { - return new NettyByteBuf(proxied.asReadOnly().retain(), false); + return this; } @Override public ByteBuf duplicate() { - return new NettyByteBuf(proxied.retainedDuplicate(), isWriting); + return new NettyByteBuf(proxied.duplicate().retain(), isWriting); } @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java index 76e10653454..e480363fc82 100644 --- 
a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java @@ -307,7 +307,8 @@ private void readAsync(final int numBytes, final AsyncCompletionHandler composite.addComponent(next); iter.remove(); } else { - composite.addComponent(next.readRetainedSlice(bytesNeededFromCurrentBuffer)); + next.retain(); + composite.addComponent(next.readSlice(bytesNeededFromCurrentBuffer)); } composite.writerIndex(composite.writerIndex() + bytesNeededFromCurrentBuffer); bytesNeeded -= bytesNeededFromCurrentBuffer; diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannel.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannel.java index 04114318f92..c1e3f067335 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannel.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannel.java @@ -98,8 +98,8 @@ public void read(ByteBuffer dst, A attach, CompletionHandler group.submit(() -> handler.completed((int) c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed((int) c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } @Override @@ -119,8 +119,8 @@ public void read( new ByteBufferSet(dst), timeout, unit, - c -> group.submit(() -> handler.completed((int) c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed((int) c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } @Override @@ -145,8 +145,8 @@ public void read( bufferSet, timeout, unit, - c -> group.submit(() -> handler.completed(c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed(c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } 
@Override @@ -185,8 +185,8 @@ public void write(ByteBuffer src, A attach, CompletionHandler group.submit(() -> handler.completed((int) c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed((int) c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } @Override @@ -205,8 +205,8 @@ public void write( new ByteBufferSet(src), timeout, unit, - c -> group.submit(() -> handler.completed((int) c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed((int) c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } @Override @@ -228,8 +228,8 @@ public void write( bufferSet, timeout, unit, - c -> group.submit(() -> handler.completed(c, attach)), - e -> group.submit(() -> handler.failed(e, attach))); + c -> group.execute(() -> handler.completed(c, attach)), + e -> group.execute(() -> handler.failed(e, attach))); } @Override @@ -251,11 +251,11 @@ public Future write(ByteBuffer src) { } private void completeWithZeroInt(A attach, CompletionHandler handler) { - group.submit(() -> handler.completed(0, attach)); + group.execute(() -> handler.completed(0, attach)); } private void completeWithZeroLong(A attach, CompletionHandler handler) { - group.submit(() -> handler.completed(0L, attach)); + group.execute(() -> handler.completed(0L, attach)); } /** diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java index d9b1420a6e3..5150149fa6a 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java @@ -43,6 +43,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import 
java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; @@ -65,7 +66,7 @@ * instance of this class is a singleton-like object that manages a thread pool that makes it * possible to run a group of asynchronous channels. */ -public class AsynchronousTlsChannelGroup { +public class AsynchronousTlsChannelGroup implements Executor { private static final Logger LOGGER = Loggers.getLogger("connection.tls"); @@ -224,8 +225,16 @@ public AsynchronousTlsChannelGroup(@Nullable final ExecutorService executorServi selectorThread.start(); } - void submit(final Runnable r) { - executor.submit(r); + + @Override + public void execute(final Runnable r) { + executor.execute(() -> { + try { + r.run(); + } catch (Throwable t) { + LOGGER.error(null, t); + } + }); } RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) { diff --git a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java index b26cb396e7b..4b08dd9a15c 100644 --- a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java +++ b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java @@ -39,6 +39,8 @@ import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COMMAND_NAME; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CURSOR_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NAMESPACE; +import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_NAME; +import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_SUMMARY; 
import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NETWORK_TRANSPORT; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.QUERY_SUMMARY; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SERVER_ADDRESS; @@ -266,4 +268,47 @@ public Span createTracingSpan(final CommandMessage message, return span; } + + /** + * Creates an operation-level tracing span for a database command. + *

+ * The span is named "{commandName} {database}[.{collection}]" and tagged with standard + * low-cardinality attributes (system, namespace, collection, operation name, operation summary). + * The span is also set on the {@link OperationContext} for use by downstream command-level tracing. + * + * @param transactionSpan the active transaction span (for parent context), or null + * @param operationContext the operation context to attach the span to + * @param commandName the name of the command (e.g. "find", "insert") + * @param namespace the MongoDB namespace for the operation + * @return the created span, or null if tracing is disabled + */ + @Nullable + public Span createOperationSpan(@Nullable final TransactionSpan transactionSpan, + final OperationContext operationContext, final String commandName, final MongoNamespace namespace) { + if (!isEnabled()) { + return null; + } + TraceContext parentContext = null; + if (transactionSpan != null) { + parentContext = transactionSpan.getContext(); + } + String name = commandName + " " + namespace.getDatabaseName() + + (MongoNamespaceHelper.COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName()) + ? "" + : "." 
+ namespace.getCollectionName()); + + KeyValues keyValues = KeyValues.of( + SYSTEM.withValue("mongodb"), + NAMESPACE.withValue(namespace.getDatabaseName())); + if (!MongoNamespaceHelper.COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())) { + keyValues = keyValues.and(COLLECTION.withValue(namespace.getCollectionName())); + } + keyValues = keyValues.and(OPERATION_NAME.withValue(commandName), + OPERATION_SUMMARY.withValue(name)); + + Span span = addSpan(name, parentContext, namespace); + span.tagLowCardinality(keyValues); + operationContext.setTracingSpan(span); + return span; + } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy index 2d7dc04d758..f1585f82595 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy @@ -52,6 +52,7 @@ class CommandHelperSpecification extends Specification { } def cleanup() { + InternalStreamConnection.setRecordEverything(false) connection?.close() } @@ -81,5 +82,4 @@ class CommandHelperSpecification extends Specification { !receivedDocument receivedException instanceof MongoCommandException } - } diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java index fc5926b3bad..81e778b4a61 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java @@ -127,7 +127,7 @@ public void shouldThrowOnTimeout() throws InterruptedException { // when TimeoutTrackingConnectionGetter connectionGetter = new 
TimeoutTrackingConnectionGetter(provider, timeoutSettings); - cachedExecutor.submit(connectionGetter); + cachedExecutor.execute(connectionGetter); connectionGetter.getLatch().await(); @@ -152,7 +152,7 @@ public void shouldNotUseMaxAwaitTimeMSWhenTimeoutMsIsSet() throws InterruptedExc // when TimeoutTrackingConnectionGetter connectionGetter = new TimeoutTrackingConnectionGetter(provider, timeoutSettings); - cachedExecutor.submit(connectionGetter); + cachedExecutor.execute(connectionGetter); sleep(70); // wait for more than maxWaitTimeMS but less than timeoutMs. internalConnection.close(); diff --git a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java index be4526aada7..5f736f421c2 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java @@ -331,9 +331,10 @@ static Stream shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS() ); } - @ParameterizedTest - @MethodSource @DisplayName("should choose timeoutMS when timeoutMS is less than connectTimeoutMS") + @ParameterizedTest(name = "should choose timeoutMS when timeoutMS is less than connectTimeoutMS. 
" + + "Parameters: connectTimeoutMS: {0}, timeoutMS: {1}, expected: {2}") + @MethodSource void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTimeoutMS, final Long timeoutMS, final long expected) { @@ -345,7 +346,7 @@ void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTim 0)); long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs(); - assertTrue(expected - calculatedTimeoutMS <= 1); + assertTrue(expected - calculatedTimeoutMS <= 2); } private TimeoutContextTest() { diff --git a/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt b/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt index 6c53a1faf47..cbe308eece0 100644 --- a/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt +++ b/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt @@ -19,6 +19,7 @@ import com.mongodb.ClientSessionOptions import com.mongodb.ServerAddress import com.mongodb.TransactionOptions import com.mongodb.internal.TimeoutContext +import com.mongodb.internal.observability.micrometer.TransactionSpan import com.mongodb.reactivestreams.client.ClientSession as reactiveClientSession import com.mongodb.session.ClientSession as jClientSession import com.mongodb.session.ServerSession @@ -58,6 +59,9 @@ public class ClientSession(public val wrapped: reactiveClientSession) : jClientS */ public fun notifyOperationInitiated(operation: Any): Unit = wrapped.notifyOperationInitiated(operation) + /** Get the transaction span (if started). */ + public fun getTransactionSpan(): TransactionSpan? = wrapped.transactionSpan + /** * Get the server address of the pinned mongos on this session. For internal use only. 
* diff --git a/driver-reactive-streams/build.gradle.kts b/driver-reactive-streams/build.gradle.kts index dab192e2583..b55dd95d683 100644 --- a/driver-reactive-streams/build.gradle.kts +++ b/driver-reactive-streams/build.gradle.kts @@ -15,6 +15,7 @@ */ import ProjectExtensions.configureJarManifest import ProjectExtensions.configureMavenPublication +import project.DEFAULT_JAVA_VERSION plugins { id("project.java") @@ -36,6 +37,9 @@ dependencies { implementation(libs.project.reactor.core) compileOnly(project(path = ":mongodb-crypt", configuration = "default")) + optionalImplementation(platform(libs.micrometer.observation.bom)) + optionalImplementation(libs.micrometer.observation) + testImplementation(libs.project.reactor.test) testImplementation(project(path = ":driver-sync", configuration = "default")) testImplementation(project(path = ":bson", configuration = "testArtifacts")) @@ -45,11 +49,20 @@ dependencies { // Reactive Streams TCK testing testImplementation(libs.reactive.streams.tck) - // Tracing + // Tracing testing testImplementation(platform(libs.micrometer.tracing.integration.test.bom)) testImplementation(libs.micrometer.tracing.integration.test) { exclude(group = "org.junit.jupiter") } } +tasks.withType { + // Needed for MicrometerProseTest to set env variable programmatically (calls + // `field.setAccessible(true)`) + val testJavaVersion: Int = findProperty("javaVersion")?.toString()?.toInt() ?: DEFAULT_JAVA_VERSION + if (testJavaVersion >= DEFAULT_JAVA_VERSION) { + jvmArgs("--add-opens=java.base/java.util=ALL-UNNAMED") + } +} + configureMavenPublication { pom { name.set("The MongoDB Reactive Streams Driver") diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java index 3d9354e9ae9..fe58864fad0 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java +++ 
b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java @@ -18,6 +18,8 @@ package com.mongodb.reactivestreams.client; import com.mongodb.TransactionOptions; +import com.mongodb.internal.observability.micrometer.TransactionSpan; +import com.mongodb.lang.Nullable; import org.reactivestreams.Publisher; /** @@ -94,4 +96,13 @@ public interface ClientSession extends com.mongodb.session.ClientSession { * @mongodb.server.release 4.0 */ Publisher abortTransaction(); + + /** + * Get the transaction span (if started). + * + * @return the transaction span + * @since 5.7 + */ + @Nullable + TransactionSpan getTransactionSpan(); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java index 30714a6a576..b5e94c02975 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java @@ -18,6 +18,7 @@ import com.mongodb.ClientSessionOptions; import com.mongodb.TransactionOptions; +import com.mongodb.internal.observability.micrometer.TracingManager; import com.mongodb.internal.session.ServerSessionPool; import com.mongodb.lang.Nullable; import com.mongodb.reactivestreams.client.ClientSession; @@ -31,10 +32,13 @@ public class ClientSessionHelper { private final MongoClientImpl mongoClient; private final ServerSessionPool serverSessionPool; + private final TracingManager tracingManager; - public ClientSessionHelper(final MongoClientImpl mongoClient, final ServerSessionPool serverSessionPool) { + public ClientSessionHelper(final MongoClientImpl mongoClient, final ServerSessionPool serverSessionPool, + final TracingManager tracingManager) { this.mongoClient = mongoClient; this.serverSessionPool = serverSessionPool; + this.tracingManager = 
tracingManager; } Mono withClientSession(@Nullable final ClientSession clientSessionFromOperation, final OperationExecutor executor) { @@ -62,6 +66,6 @@ ClientSession createClientSession(final ClientSessionOptions options, final Oper .readPreference(mongoClient.getSettings().getReadPreference()) .build())) .build(); - return new ClientSessionPublisherImpl(serverSessionPool, mongoClient, mergedOptions, executor); + return new ClientSessionPublisherImpl(serverSessionPool, mongoClient, mergedOptions, executor, tracingManager); } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java index 5cf0ea103bd..511f9f62c6b 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java @@ -24,6 +24,8 @@ import com.mongodb.TransactionOptions; import com.mongodb.WriteConcern; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.observability.micrometer.TracingManager; +import com.mongodb.internal.observability.micrometer.TransactionSpan; import com.mongodb.internal.operation.AbortTransactionOperation; import com.mongodb.internal.operation.CommitTransactionOperation; import com.mongodb.internal.operation.ReadOperation; @@ -48,17 +50,21 @@ final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements private final MongoClientImpl mongoClient; private final OperationExecutor executor; + private final TracingManager tracingManager; private TransactionState transactionState = TransactionState.NONE; private boolean messageSentInCurrentTransaction; private boolean commitInProgress; private TransactionOptions transactionOptions; + @Nullable + private TransactionSpan transactionSpan; 
ClientSessionPublisherImpl(final ServerSessionPool serverSessionPool, final MongoClientImpl mongoClient, - final ClientSessionOptions options, final OperationExecutor executor) { + final ClientSessionOptions options, final OperationExecutor executor, final TracingManager tracingManager) { super(serverSessionPool, mongoClient, options); this.executor = executor; this.mongoClient = mongoClient; + this.tracingManager = tracingManager; } @Override @@ -128,6 +134,10 @@ public void startTransaction(final TransactionOptions transactionOptions) { if (!writeConcern.isAcknowledged()) { throw new MongoClientException("Transactions do not support unacknowledged write concern"); } + + if (tracingManager.isEnabled()) { + transactionSpan = new TransactionSpan(tracingManager); + } clearTransactionContext(); setTimeoutContext(timeoutContext); } @@ -152,6 +162,9 @@ public Publisher commitTransaction() { } if (!messageSentInCurrentTransaction) { cleanupTransaction(TransactionState.COMMITTED); + if (transactionSpan != null) { + transactionSpan.finalizeTransactionSpan(TransactionState.COMMITTED.name()); + } return Mono.create(MonoSink::success); } else { ReadConcern readConcern = transactionOptions.getReadConcern(); @@ -171,7 +184,17 @@ public Publisher commitTransaction() { commitInProgress = false; transactionState = TransactionState.COMMITTED; }) - .doOnError(MongoException.class, this::clearTransactionContextOnError); + .doOnError(MongoException.class, e -> { + clearTransactionContextOnError(e); + if (transactionSpan != null) { + transactionSpan.handleTransactionSpanError(e); + } + }) + .doOnSuccess(v -> { + if (transactionSpan != null) { + transactionSpan.finalizeTransactionSpan(TransactionState.COMMITTED.name()); + } + }); } }); } @@ -191,6 +214,9 @@ public Publisher abortTransaction() { } if (!messageSentInCurrentTransaction) { cleanupTransaction(TransactionState.ABORTED); + if (transactionSpan != null) { + 
transactionSpan.finalizeTransactionSpan(TransactionState.ABORTED.name()); + } return Mono.create(MonoSink::success); } else { ReadConcern readConcern = transactionOptions.getReadConcern(); @@ -208,6 +234,9 @@ public Publisher abortTransaction() { .doOnTerminate(() -> { clearTransactionContext(); cleanupTransaction(TransactionState.ABORTED); + if (transactionSpan != null) { + transactionSpan.finalizeTransactionSpan(TransactionState.ABORTED.name()); + } }); } }); @@ -219,6 +248,12 @@ private void clearTransactionContextOnError(final MongoException e) { } } + @Override + @Nullable + public TransactionSpan getTransactionSpan() { + return transactionSpan; + } + @Override public void close() { if (transactionState == TransactionState.IN) { diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java index 07a17badcd7..8fda2e9294d 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java @@ -33,6 +33,7 @@ import com.mongodb.internal.connection.Cluster; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; +import com.mongodb.internal.observability.micrometer.TracingManager; import com.mongodb.internal.session.ServerSessionPool; import com.mongodb.lang.Nullable; import com.mongodb.reactivestreams.client.ChangeStreamPublisher; @@ -88,9 +89,10 @@ private MongoClientImpl(final MongoClientSettings settings, final MongoDriverInf notNull("settings", settings); notNull("cluster", cluster); + TracingManager tracingManager = new TracingManager(settings.getObservabilitySettings()); TimeoutSettings timeoutSettings = TimeoutSettings.create(settings); ServerSessionPool serverSessionPool = new ServerSessionPool(cluster, 
timeoutSettings, settings.getServerApi()); - ClientSessionHelper clientSessionHelper = new ClientSessionHelper(this, serverSessionPool); + ClientSessionHelper clientSessionHelper = new ClientSessionHelper(this, serverSessionPool, tracingManager); AutoEncryptionSettings autoEncryptSettings = settings.getAutoEncryptionSettings(); Crypt crypt = autoEncryptSettings != null ? Crypts.createCrypt(settings, autoEncryptSettings) : null; @@ -100,7 +102,8 @@ private MongoClientImpl(final MongoClientSettings settings, final MongoDriverInf + ReactiveContextProvider.class.getName() + " when using the Reactive Streams driver"); } OperationExecutor operationExecutor = executor != null ? executor - : new OperationExecutorImpl(this, clientSessionHelper, timeoutSettings, (ReactiveContextProvider) contextProvider); + : new OperationExecutorImpl(this, clientSessionHelper, timeoutSettings, (ReactiveContextProvider) contextProvider, + tracingManager); MongoOperationPublisher mongoOperationPublisher = new MongoOperationPublisher<>(Document.class, withUuidRepresentation(settings.getCodecRegistry(), settings.getUuidRepresentation()), diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java index ef18c2c6b1f..62a4431cc9a 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java @@ -31,10 +31,11 @@ import com.mongodb.internal.binding.AsyncReadWriteBinding; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; +import com.mongodb.internal.observability.micrometer.Span; +import com.mongodb.internal.observability.micrometer.TracingManager; import 
com.mongodb.internal.operation.OperationHelper; import com.mongodb.internal.operation.ReadOperation; import com.mongodb.internal.operation.WriteOperation; -import com.mongodb.internal.observability.micrometer.TracingManager; import com.mongodb.lang.Nullable; import com.mongodb.reactivestreams.client.ClientSession; import com.mongodb.reactivestreams.client.ReactiveContextProvider; @@ -63,13 +64,16 @@ public class OperationExecutorImpl implements OperationExecutor { @Nullable private final ReactiveContextProvider contextProvider; private final TimeoutSettings timeoutSettings; + private final TracingManager tracingManager; OperationExecutorImpl(final MongoClientImpl mongoClient, final ClientSessionHelper clientSessionHelper, - final TimeoutSettings timeoutSettings, @Nullable final ReactiveContextProvider contextProvider) { + final TimeoutSettings timeoutSettings, @Nullable final ReactiveContextProvider contextProvider, + final TracingManager tracingManager) { this.mongoClient = mongoClient; this.clientSessionHelper = clientSessionHelper; this.timeoutSettings = timeoutSettings; this.contextProvider = contextProvider; + this.tracingManager = tracingManager; } @Override @@ -93,22 +97,37 @@ public Mono execute(final ReadOperation operation, final ReadPrefer OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName()) .withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession, isImplicitSession(session), readConcern)); + Span span = tracingManager.createOperationSpan(actualClientSession.getTransactionSpan(), + operationContext, operation.getCommandName(), operation.getNamespace()); if (session != null && session.hasActiveTransaction() && !binding.getReadPreference().equals(primary())) { binding.release(); - return Mono.error(new MongoClientException("Read preference in a transaction must be primary")); + MongoClientException error = new MongoClientException("Read 
preference in a transaction must be primary"); + if (span != null) { + span.error(error); + span.end(); + } + return Mono.error(error); } else { return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> { try { binding.release(); } finally { + if (t != null) { + Throwable exceptionToHandle = t instanceof MongoException + ? OperationHelper.unwrap((MongoException) t) : t; + labelException(session, exceptionToHandle); + unpinServerAddressOnTransientTransactionError(session, exceptionToHandle); + if (span != null) { + span.error(t); + } + } + if (span != null) { + span.end(); + } sinkToCallback(sink).onResult(result, t); } - })).doOnError((t) -> { - Throwable exceptionToHandle = t instanceof MongoException ? OperationHelper.unwrap((MongoException) t) : t; - labelException(session, exceptionToHandle); - unpinServerAddressOnTransientTransactionError(session, exceptionToHandle); - }); + })); } }).subscribe(subscriber) ); @@ -133,18 +152,28 @@ public Mono execute(final WriteOperation operation, final ReadConcern OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName()) .withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession, isImplicitSession(session), readConcern)); + Span span = tracingManager.createOperationSpan(actualClientSession.getTransactionSpan(), + operationContext, operation.getCommandName(), operation.getNamespace()); return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> { try { binding.release(); } finally { + if (t != null) { + Throwable exceptionToHandle = t instanceof MongoException + ? 
OperationHelper.unwrap((MongoException) t) : t; + labelException(session, exceptionToHandle); + unpinServerAddressOnTransientTransactionError(session, exceptionToHandle); + if (span != null) { + span.error(t); + } + } + if (span != null) { + span.end(); + } sinkToCallback(sink).onResult(result, t); } - })).doOnError((t) -> { - Throwable exceptionToHandle = t instanceof MongoException ? OperationHelper.unwrap((MongoException) t) : t; - labelException(session, exceptionToHandle); - unpinServerAddressOnTransientTransactionError(session, exceptionToHandle); - }); + })); } ).subscribe(subscriber) ); @@ -155,7 +184,7 @@ public OperationExecutor withTimeoutSettings(final TimeoutSettings newTimeoutSet if (Objects.equals(timeoutSettings, newTimeoutSettings)) { return this; } - return new OperationExecutorImpl(mongoClient, clientSessionHelper, newTimeoutSettings, contextProvider); + return new OperationExecutorImpl(mongoClient, clientSessionHelper, newTimeoutSettings, contextProvider, tracingManager); } @Override @@ -214,7 +243,7 @@ private OperationContext getOperationContext(final RequestContext requestContext requestContext, new ReadConcernAwareNoOpSessionContext(readConcern), createTimeoutContext(session, timeoutSettings), - TracingManager.NO_OP, + tracingManager, mongoClient.getSettings().getServerApi(), commandName); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java index bc4da3026a9..cefdf7184d8 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java @@ -55,8 +55,14 @@ public static MongoCollection collectionWithTimeout(final MongoCollection public static Mono> collectionWithTimeoutMono(final MongoCollection collection, @Nullable final Timeout timeout) { + 
return collectionWithTimeoutMono(collection, timeout, DEFAULT_TIMEOUT_MESSAGE); + } + + public static Mono> collectionWithTimeoutMono(final MongoCollection collection, + @Nullable final Timeout timeout, + final String message) { try { - return Mono.just(collectionWithTimeout(collection, timeout)); + return Mono.just(collectionWithTimeout(collection, timeout, message)); } catch (MongoOperationTimeoutException e) { return Mono.error(e); } @@ -64,9 +70,14 @@ public static Mono> collectionWithTimeoutMono(final Mongo public static Mono> collectionWithTimeoutDeferred(final MongoCollection collection, @Nullable final Timeout timeout) { - return Mono.defer(() -> collectionWithTimeoutMono(collection, timeout)); + return collectionWithTimeoutDeferred(collection, timeout, DEFAULT_TIMEOUT_MESSAGE); } + public static Mono> collectionWithTimeoutDeferred(final MongoCollection collection, + @Nullable final Timeout timeout, + final String message) { + return Mono.defer(() -> collectionWithTimeoutMono(collection, timeout, message)); + } public static MongoDatabase databaseWithTimeout(final MongoDatabase database, @Nullable final Timeout timeout) { diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java index 7d9a46cdf3f..50586e92102 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java @@ -54,7 +54,8 @@ */ public final class GridFSUploadPublisherImpl implements GridFSUploadPublisher { - private static final String TIMEOUT_ERROR_MESSAGE = "Saving chunks exceeded the timeout limit."; + private static final String TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING = "Saving chunks exceeded the timeout limit."; + private 
static final String TIMEOUT_ERROR_MESSAGE_UPLOAD_CANCELLATION = "Upload cancellation exceeded the timeout limit."; private static final Document PROJECTION = new Document("_id", 1); private static final Document FILES_INDEX = new Document("filename", 1).append("uploadDate", 1); private static final Document CHUNKS_INDEX = new Document("files_id", 1).append("n", 1); @@ -226,8 +227,8 @@ private Mono createSaveChunksMono(final AtomicBoolean terminated, @Nullabl .append("data", data); Publisher insertOnePublisher = clientSession == null - ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument) - : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE) + ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING).insertOne(chunkDocument) + : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING) .insertOne(clientSession, chunkDocument); return Mono.from(insertOnePublisher).thenReturn(data.length()); @@ -270,7 +271,8 @@ private Mono createSaveFileDataMono(final AtomicBoolean termina } private Mono createCancellationMono(final AtomicBoolean terminated, @Nullable final Timeout timeout) { - Mono> chunksCollectionMono = collectionWithTimeoutDeferred(chunksCollection, timeout); + Mono> chunksCollectionMono = collectionWithTimeoutDeferred(chunksCollection, timeout, + TIMEOUT_ERROR_MESSAGE_UPLOAD_CANCELLATION); if (terminated.compareAndSet(false, true)) { if (clientSession != null) { return chunksCollectionMono.flatMap(collection -> Mono.from(collection diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java index b922ec20b71..90446953fc1 100644 --- 
a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java @@ -16,7 +16,6 @@ package com.mongodb.reactivestreams.client; -import com.mongodb.ClusterFixture; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; import com.mongodb.MongoNamespace; @@ -24,7 +23,6 @@ import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; import com.mongodb.client.AbstractClientSideOperationsTimeoutProseTest; -import com.mongodb.client.model.CreateCollectionOptions; import com.mongodb.client.model.changestream.FullDocument; import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandStartedEvent; @@ -43,6 +41,7 @@ import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.nio.ByteBuffer; @@ -58,12 +57,16 @@ import static com.mongodb.ClusterFixture.TIMEOUT_DURATION; import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet; +import static com.mongodb.ClusterFixture.isStandalone; import static com.mongodb.ClusterFixture.serverVersionAtLeast; import static com.mongodb.ClusterFixture.sleep; +import static com.mongodb.assertions.Assertions.assertTrue; import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeFalse; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -104,7 +107,6 @@ protected boolean isAsync() { @Override public void testGridFSUploadViaOpenUploadStreamTimeout() { 
assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); //given collectionHelper.runAdminCommand("{" @@ -113,12 +115,12 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() { + " data: {" + " failCommands: [\"insert\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 405) + + " blockTimeMS: " + 600 + " }" + "}"); try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder() - .timeout(rtt + 400, TimeUnit.MILLISECONDS))) { + .timeout(600, TimeUnit.MILLISECONDS))) { MongoDatabase database = client.getDatabase(gridFsFileNamespace.getDatabaseName()); GridFSBucket gridFsBucket = createReaciveGridFsBucket(database, GRID_FS_BUCKET_NAME); @@ -158,7 +160,6 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() { @Override public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, InterruptedException, TimeoutException { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); //given CompletableFuture droppedErrorFuture = new CompletableFuture<>(); @@ -170,12 +171,12 @@ public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, I + " data: {" + " failCommands: [\"delete\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 405) + + " blockTimeMS: " + 405 + " }" + "}"); try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder() - .timeout(rtt + 400, TimeUnit.MILLISECONDS))) { + .timeout(400, TimeUnit.MILLISECONDS))) { MongoDatabase database = client.getDatabase(gridFsFileNamespace.getDatabaseName()); GridFSBucket gridFsBucket = createReaciveGridFsBucket(database, GRID_FS_BUCKET_NAME); @@ -198,12 +199,25 @@ public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, I //then Throwable droppedError = droppedErrorFuture.get(TIMEOUT_DURATION.toMillis(), TimeUnit.MILLISECONDS); Throwable commandError = droppedError.getCause(); - assertInstanceOf(MongoOperationTimeoutException.class, commandError); 
assertTrue(deleteMaxTimeMS <= 420 + // some leeway for timing variations, when compression is used it is often less than 300. + // Without it, it is more than 300. + && deleteMaxTimeMS >= 150, + "Expected maxTimeMS for delete command to be between 150ms and 420ms, " + "but was: " + deleteMaxTimeMS + "ms");
testTimeoutMSAppliedToInitialAggregate() { + " data: {" + " failCommands: [\"aggregate\" ]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 201) + + " blockTimeMS: " + 201 + " }" + "}"); @@ -321,13 +333,10 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsNotSet() { //given BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0); - collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); sleep(2000); - - long rtt = ClusterFixture.getPrimaryRTT(); try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder() - .timeout(rtt + 300, TimeUnit.MILLISECONDS))) { + .timeout(500, TimeUnit.MILLISECONDS))) { MongoCollection collection = client.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); @@ -338,7 +347,7 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsNotSet() { + " data: {" + " failCommands: [\"getMore\", \"aggregate\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 200) + + " blockTimeMS: " + 200 + " }" + "}"); @@ -389,12 +398,10 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsSet() { //given BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0); - collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); sleep(2000); - long rtt = ClusterFixture.getPrimaryRTT(); try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder() - .timeout(rtt + 300, TimeUnit.MILLISECONDS))) { + .timeout(500, TimeUnit.MILLISECONDS))) { MongoCollection collection = client.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()) @@ -406,7 +413,7 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsSet() { + " data: {" + " failCommands: [\"aggregate\", \"getMore\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 200) + + " blockTimeMS: " 
+ 200 + " }" + "}"); @@ -449,9 +456,8 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt assumeTrue(isDiscoverableReplicaSet()); //given - long rtt = ClusterFixture.getPrimaryRTT(); try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder() - .timeout(rtt + 2500, TimeUnit.MILLISECONDS))) { + .timeout(2500, TimeUnit.MILLISECONDS))) { MongoCollection collection = client.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); @@ -468,7 +474,78 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt List commandStartedEvents = commandListener.getCommandStartedEvents(); assertCommandStartedEventsInOder(Arrays.asList("aggregate", "getMore", "getMore", "getMore", "killCursors"), commandStartedEvents); - assertOnlyOneCommandTimeoutFailure("getMore"); + + } + } + + @DisplayName("9. End Session. The timeout specified via the MongoClient timeoutMS option") + @Test + @Override + public void test9EndSessionClientTimeout() { + assumeTrue(serverVersionAtLeast(4, 4)); + assumeFalse(isStandalone()); + + collectionHelper.runAdminCommand("{" + + " configureFailPoint: \"failCommand\"," + + " mode: { times: 1 }," + + " data: {" + + " failCommands: [\"abortTransaction\"]," + + " blockConnection: true," + + " blockTimeMS: " + 400 + + " }" + + "}"); + + try (MongoClient mongoClient = createReactiveClient(getMongoClientSettingsBuilder().retryWrites(false) + .timeout(300, TimeUnit.MILLISECONDS))) { + MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .getCollection(namespace.getCollectionName()); + + try (ClientSession session = Mono.from(mongoClient.startSession()).block()) { + session.startTransaction(); + Mono.from(collection.insertOne(session, new Document("x", 1))).block(); + } + + sleep(postSessionCloseSleep()); + CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() -> 
commandListener.getCommandFailedEvent("abortTransaction")); + long elapsedTime = abortTransactionEvent.getElapsedTime(TimeUnit.MILLISECONDS); + assertInstanceOf(MongoOperationTimeoutException.class, abortTransactionEvent.getThrowable()); + assertTrue(elapsedTime <= 400, "Took too long to time out, elapsedMS: " + elapsedTime); + } + } + + @Test + @DisplayName("9. End Session. The timeout specified via the ClientSession defaultTimeoutMS option") + @Override + public void test9EndSessionSessionTimeout() { + assumeTrue(serverVersionAtLeast(4, 4)); + assumeFalse(isStandalone()); + + collectionHelper.runAdminCommand("{" + + " configureFailPoint: \"failCommand\"," + + " mode: { times: 1 }," + + " data: {" + + " failCommands: [\"abortTransaction\"]," + + " blockConnection: true," + + " blockTimeMS: " + 400 + + " }" + + "}"); + + try (MongoClient mongoClient = createReactiveClient(getMongoClientSettingsBuilder())) { + MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .getCollection(namespace.getCollectionName()); + + try (ClientSession session = Mono.from(mongoClient.startSession(com.mongodb.ClientSessionOptions.builder() + .defaultTimeout(300, TimeUnit.MILLISECONDS).build())).block()) { + + session.startTransaction(); + Mono.from(collection.insertOne(session, new Document("x", 1))).block(); + } + + sleep(postSessionCloseSleep()); + CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() -> commandListener.getCommandFailedEvent("abortTransaction")); + long elapsedTime = abortTransactionEvent.getElapsedTime(TimeUnit.MILLISECONDS); + assertInstanceOf(MongoOperationTimeoutException.class, abortTransactionEvent.getThrowable()); + assertTrue(elapsedTime <= 400, "Took too long to time out, elapsedMS: " + elapsedTime); } } @@ -512,6 +589,6 @@ public void tearDown() throws InterruptedException { @Override protected int postSessionCloseSleep() { - return 256; + return 1000; } } diff --git 
a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java new file mode 100644 index 00000000000..c58bb98f2cc --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java @@ -0,0 +1,32 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.reactivestreams.client.observability; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.AbstractMicrometerProseTest; +import com.mongodb.client.MongoClient; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; + +/** + * Reactive Streams driver implementation of the Micrometer prose tests. 
+ */ +public class MicrometerProseTest extends AbstractMicrometerProseTest { + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(settings); + } +} diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java index e1d765150a7..473d57a3878 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java @@ -192,7 +192,7 @@ public TimeoutContext getTimeoutContext() { @Override @Nullable public TransactionSpan getTransactionSpan() { - return null; + return wrapped.getTransactionSpan(); } private static void sleep(final long millis) { diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java new file mode 100644 index 00000000000..bf2e6205ad6 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.reactivestreams.client.unified; + +import org.junit.jupiter.params.provider.Arguments; + +import java.util.Collection; + +final class MicrometerTracingTest extends UnifiedReactiveStreamsTest { + private static Collection data() { + return getTestData("open-telemetry/tests"); + } +} diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java index c192ae17896..0fda131f4ff 100644 --- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java +++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java @@ -25,6 +25,7 @@ import com.mongodb.internal.connection.ClientMetadata; import com.mongodb.internal.connection.Cluster; import com.mongodb.internal.mockito.MongoMockito; +import com.mongodb.internal.observability.micrometer.TracingManager; import com.mongodb.internal.session.ServerSessionPool; import com.mongodb.reactivestreams.client.ChangeStreamPublisher; import com.mongodb.reactivestreams.client.ClientSession; @@ -179,7 +180,7 @@ void testWatch() { @Test void testStartSession() { ServerSessionPool serverSessionPool = mock(ServerSessionPool.class); - ClientSessionHelper clientSessionHelper = new ClientSessionHelper(mongoClient, serverSessionPool); + ClientSessionHelper clientSessionHelper = new ClientSessionHelper(mongoClient, serverSessionPool, TracingManager.NO_OP); assertAll("Start Session Tests", () -> assertAll("check validation", diff --git a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java index 920feb1f986..eb36678761a 100644 --- 
a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java +++ b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java @@ -22,7 +22,6 @@ import com.mongodb.MongoClientException; import com.mongodb.MongoException; import com.mongodb.MongoInternalException; -import com.mongodb.MongoNamespace; import com.mongodb.MongoQueryException; import com.mongodb.MongoSocketException; import com.mongodb.MongoTimeoutException; @@ -53,17 +52,14 @@ import com.mongodb.internal.connection.Cluster; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; +import com.mongodb.internal.observability.micrometer.Span; +import com.mongodb.internal.observability.micrometer.TracingManager; import com.mongodb.internal.operation.OperationHelper; import com.mongodb.internal.operation.Operations; import com.mongodb.internal.operation.ReadOperation; import com.mongodb.internal.operation.WriteOperation; import com.mongodb.internal.session.ServerSessionPool; -import com.mongodb.internal.observability.micrometer.Span; -import com.mongodb.internal.observability.micrometer.TraceContext; -import com.mongodb.internal.observability.micrometer.TracingManager; -import com.mongodb.internal.observability.micrometer.TransactionSpan; import com.mongodb.lang.Nullable; -import io.micrometer.common.KeyValues; import org.bson.BsonDocument; import org.bson.Document; import org.bson.UuidRepresentation; @@ -77,17 +73,11 @@ import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL; import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL; -import static com.mongodb.internal.MongoNamespaceHelper.COMMAND_COLLECTION_NAME; import static com.mongodb.ReadPreference.primary; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.isTrueArgument; import static com.mongodb.assertions.Assertions.notNull; import static 
com.mongodb.internal.TimeoutContext.createTimeoutContext; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COLLECTION; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NAMESPACE; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_NAME; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_SUMMARY; -import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SYSTEM; final class MongoClusterImpl implements MongoCluster { @Nullable @@ -434,7 +424,8 @@ public T execute(final ReadOperation operation, final ReadPreference r boolean implicitSession = isImplicitSession(session); OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName()) .withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, implicitSession)); - Span span = createOperationSpan(actualClientSession, operationContext, operation.getCommandName(), operation.getNamespace()); + Span span = operationContext.getTracingManager().createOperationSpan( + actualClientSession.getTransactionSpan(), operationContext, operation.getCommandName(), operation.getNamespace()); ReadBinding binding = getReadBinding(readPreference, actualClientSession, implicitSession); @@ -469,7 +460,8 @@ public T execute(final WriteOperation operation, final ReadConcern readCo ClientSession actualClientSession = getClientSession(session); OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName()) .withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, isImplicitSession(session))); - Span span = createOperationSpan(actualClientSession, operationContext, 
operation.getCommandName(), operation.getNamespace()); + Span span = operationContext.getTracingManager().createOperationSpan( + actualClientSession.getTransactionSpan(), operationContext, operation.getCommandName(), operation.getNamespace()); WriteBinding binding = getWriteBinding(actualClientSession, isImplicitSession(session)); try { @@ -587,48 +579,6 @@ ClientSession getClientSession(@Nullable final ClientSession clientSessionFromOp return session; } - /** - * Create a tracing span for the given operation, and set it on operation context. - * - * @param actualClientSession the session that the operation is part of - * @param operationContext the operation context for the operation - * @param commandName the name of the command - * @param namespace the namespace of the command - * @return the created span, or null if tracing is not enabled - */ - @Nullable - private Span createOperationSpan(final ClientSession actualClientSession, final OperationContext operationContext, final String commandName, final MongoNamespace namespace) { - TracingManager tracingManager = operationContext.getTracingManager(); - if (tracingManager.isEnabled()) { - TraceContext parentContext = null; - TransactionSpan transactionSpan = actualClientSession.getTransactionSpan(); - if (transactionSpan != null) { - parentContext = transactionSpan.getContext(); - } - String name = commandName + " " + namespace.getDatabaseName() + (COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName()) - ? "" - : "." 
+ namespace.getCollectionName()); - - KeyValues keyValues = KeyValues.of( - SYSTEM.withValue("mongodb"), - NAMESPACE.withValue(namespace.getDatabaseName())); - if (!COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())) { - keyValues = keyValues.and(COLLECTION.withValue(namespace.getCollectionName())); - } - keyValues = keyValues.and(OPERATION_NAME.withValue(commandName), - OPERATION_SUMMARY.withValue(name)); - - Span span = tracingManager.addSpan(name, parentContext, namespace); - - span.tagLowCardinality(keyValues); - - operationContext.setTracingSpan(span); - return span; - - } else { - return null; - } - } } private boolean isImplicitSession(@Nullable final ClientSession session) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java index 9ce58b1654f..7828ecde684 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java @@ -56,11 +56,8 @@ import com.mongodb.internal.connection.TestCommandListener; import com.mongodb.internal.connection.TestConnectionPoolListener; import com.mongodb.test.FlakyTest; -import org.bson.BsonArray; -import org.bson.BsonBoolean; import org.bson.BsonDocument; import org.bson.BsonInt32; -import org.bson.BsonString; import org.bson.BsonTimestamp; import org.bson.Document; import org.bson.codecs.BsonDocumentCodec; @@ -256,7 +253,6 @@ public void testBlockingIterationMethodsChangeStream() { assumeFalse(isAsync()); // Async change stream cursor is non-deterministic for cursor::next BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0); - collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); sleep(2000); 
collectionHelper.insertDocuments(singletonList(BsonDocument.parse("{x: 1}")), WriteConcern.MAJORITY); @@ -298,7 +294,6 @@ public void testBlockingIterationMethodsChangeStream() { @FlakyTest(maxAttempts = 3) public void testGridFSUploadViaOpenUploadStreamTimeout() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); collectionHelper.runAdminCommand("{" + " configureFailPoint: \"failCommand\"," @@ -306,7 +301,7 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() { + " data: {" + " failCommands: [\"insert\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 205) + + " blockTimeMS: " + 205 + " }" + "}"); @@ -314,7 +309,7 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() { filesCollectionHelper.create(); try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() - .timeout(rtt + 200, TimeUnit.MILLISECONDS))) { + .timeout(200, TimeUnit.MILLISECONDS))) { MongoDatabase database = client.getDatabase(namespace.getDatabaseName()); GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME); @@ -329,7 +324,6 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() { @Test public void testAbortingGridFsUploadStreamTimeout() throws Throwable { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); collectionHelper.runAdminCommand("{" + " configureFailPoint: \"failCommand\"," @@ -337,7 +331,7 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable { + " data: {" + " failCommands: [\"delete\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 305) + + " blockTimeMS: " + 320 + " }" + "}"); @@ -345,7 +339,7 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable { filesCollectionHelper.create(); try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() - .timeout(rtt + 300, TimeUnit.MILLISECONDS))) { + .timeout(300, TimeUnit.MILLISECONDS))) { MongoDatabase database = 
client.getDatabase(namespace.getDatabaseName()); GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME).withChunkSizeBytes(2); @@ -360,7 +354,6 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable { @Test public void testGridFsDownloadStreamTimeout() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); chunksCollectionHelper.create(); filesCollectionHelper.create(); @@ -382,18 +375,19 @@ public void testGridFsDownloadStreamTimeout() { + " metadata: {}" + "}" )), WriteConcern.MAJORITY); + collectionHelper.runAdminCommand("{" + " configureFailPoint: \"failCommand\"," + " mode: { skip: 1 }," + " data: {" + " failCommands: [\"find\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 95) + + " blockTimeMS: " + 500 + " }" + "}"); try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder() - .timeout(rtt + 100, TimeUnit.MILLISECONDS))) { + .timeout(300, TimeUnit.MILLISECONDS))) { MongoDatabase database = client.getDatabase(namespace.getDatabaseName()); GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME).withChunkSizeBytes(2); @@ -401,7 +395,9 @@ public void testGridFsDownloadStreamTimeout() { assertThrows(MongoOperationTimeoutException.class, downloadStream::read); List events = commandListener.getCommandStartedEvents(); - List findCommands = events.stream().filter(e -> e.getCommandName().equals("find")).collect(Collectors.toList()); + List findCommands = events.stream() + .filter(e -> e.getCommandName().equals("find")) + .collect(Collectors.toList()); assertEquals(2, findCommands.size()); assertEquals(gridFsFileNamespace.getCollectionName(), findCommands.get(0).getCommand().getString("find").getValue()); @@ -414,7 +410,7 @@ public void testGridFsDownloadStreamTimeout() { @ParameterizedTest(name = "[{index}] {0}") @MethodSource("test8ServerSelectionArguments") public void test8ServerSelection(final String connectionString) { - int timeoutBuffer = 
100; // 5 in spec, Java is slower + int timeoutBuffer = 150; // 5 in spec, Java is slower // 1. Create a MongoClient try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder() .applyConnectionString(new ConnectionString(connectionString))) @@ -450,7 +446,7 @@ public void test8ServerSelectionHandshake(final String ignoredTestName, final in + " data: {" + " failCommands: [\"saslContinue\"]," + " blockConnection: true," - + " blockTimeMS: 350" + + " blockTimeMS: 600" + " }" + "}"); @@ -466,7 +462,7 @@ public void test8ServerSelectionHandshake(final String ignoredTestName, final in .insertOne(new Document("x", 1)); }); long elapsed = msElapsedSince(start); - assertTrue(elapsed <= 310, "Took too long to time out, elapsedMS: " + elapsed); + assertTrue(elapsed <= 350, "Took too long to time out, elapsedMS: " + elapsed); } } @@ -483,23 +479,23 @@ public void test9EndSessionClientTimeout() { + " data: {" + " failCommands: [\"abortTransaction\"]," + " blockConnection: true," - + " blockTimeMS: " + 150 + + " blockTimeMS: " + 500 + " }" + "}"); try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder().retryWrites(false) - .timeout(100, TimeUnit.MILLISECONDS))) { - MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .timeout(250, TimeUnit.MILLISECONDS))) { + MongoDatabase database = mongoClient.getDatabase(namespace.getDatabaseName()); + MongoCollection collection = database .getCollection(namespace.getCollectionName()); try (ClientSession session = mongoClient.startSession()) { session.startTransaction(); collection.insertOne(session, new Document("x", 1)); - long start = System.nanoTime(); session.close(); - long elapsed = msElapsedSince(start) - postSessionCloseSleep(); - assertTrue(elapsed <= 150, "Took too long to time out, elapsedMS: " + elapsed); + long elapsed = msElapsedSince(start); + assertTrue(elapsed <= 300, "Took too long to time out, elapsedMS: " + elapsed); } } CommandFailedEvent 
abortTransactionEvent = assertDoesNotThrow(() -> @@ -520,7 +516,7 @@ public void test9EndSessionSessionTimeout() { + " data: {" + " failCommands: [\"abortTransaction\"]," + " blockConnection: true," - + " blockTimeMS: " + 150 + + " blockTimeMS: " + 400 + " }" + "}"); @@ -529,14 +525,14 @@ public void test9EndSessionSessionTimeout() { .getCollection(namespace.getCollectionName()); try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(100, TimeUnit.MILLISECONDS).build())) { + .defaultTimeout(300, TimeUnit.MILLISECONDS).build())) { session.startTransaction(); collection.insertOne(session, new Document("x", 1)); long start = System.nanoTime(); session.close(); - long elapsed = msElapsedSince(start) - postSessionCloseSleep(); - assertTrue(elapsed <= 150, "Took too long to time out, elapsedMS: " + elapsed); + long elapsed = msElapsedSince(start); + assertTrue(elapsed <= 400, "Took too long to time out, elapsedMS: " + elapsed); } } CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() -> @@ -563,11 +559,12 @@ public void test9EndSessionCustomTesEachOperationHasItsOwnTimeoutWithCommit() { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); + int defaultTimeout = 300; try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) { + .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) { session.startTransaction(); collection.insertOne(session, new Document("x", 1)); - sleep(200); + sleep(defaultTimeout); assertDoesNotThrow(session::commitTransaction); } @@ -594,11 +591,12 @@ public void test9EndSessionCustomTesEachOperationHasItsOwnTimeoutWithAbort() { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); + int defaultTimeout = 300; try (ClientSession session = 
mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) { + .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) { session.startTransaction(); collection.insertOne(session, new Document("x", 1)); - sleep(200); + sleep(defaultTimeout); assertDoesNotThrow(session::close); } @@ -618,12 +616,12 @@ public void test10ConvenientTransactions() { + " data: {" + " failCommands: [\"insert\", \"abortTransaction\"]," + " blockConnection: true," - + " blockTimeMS: " + 150 + + " blockTimeMS: " + 200 + " }" + "}"); try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder() - .timeout(100, TimeUnit.MILLISECONDS))) { + .timeout(150, TimeUnit.MILLISECONDS))) { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); @@ -661,12 +659,13 @@ public void test10CustomTestWithTransactionUsesASingleTimeout() { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); + int defaultTimeout = 200; try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) { + .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) { assertThrows(MongoOperationTimeoutException.class, () -> session.withTransaction(() -> { collection.insertOne(session, new Document("x", 1)); - sleep(200); + sleep(defaultTimeout); return true; }) ); @@ -696,12 +695,13 @@ public void test10CustomTestWithTransactionUsesASingleTimeoutWithLock() { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); + int defaultTimeout = 200; try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) { + .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) { 
assertThrows(MongoOperationTimeoutException.class, () -> session.withTransaction(() -> { collection.insertOne(session, new Document("x", 1)); - sleep(200); + sleep(defaultTimeout); return true; }) ); @@ -710,7 +710,7 @@ public void test10CustomTestWithTransactionUsesASingleTimeoutWithLock() { } @DisplayName("11. Multi-batch bulkWrites") - @Test + @FlakyTest(maxAttempts = 3) @SuppressWarnings("try") protected void test11MultiBatchBulkWrites() throws InterruptedException { assumeTrue(serverVersionAtLeast(8, 0)); @@ -718,12 +718,18 @@ protected void test11MultiBatchBulkWrites() throws InterruptedException { // a workaround for https://jira.mongodb.org/browse/DRIVERS-2997, remove this block when the aforementioned bug is fixed client.getDatabase(namespace.getDatabaseName()).drop(); } - BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) - .append("mode", new BsonDocument("times", new BsonInt32(2))) - .append("data", new BsonDocument("failCommands", new BsonArray(singletonList(new BsonString("bulkWrite")))) - .append("blockConnection", BsonBoolean.TRUE) - .append("blockTimeMS", new BsonInt32(2020))); - try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().timeout(4000, TimeUnit.MILLISECONDS)); + BsonDocument failPointDocument = BsonDocument.parse("{" + + " configureFailPoint: \"failCommand\"," + + " mode: { times: 2}," + + " data: {" + + " failCommands: [\"bulkWrite\" ]," + + " blockConnection: true," + + " blockTimeMS: " + 2020 + + " }" + + "}"); + + long timeout = 4000; + try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().timeout(timeout, TimeUnit.MILLISECONDS)); FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) { MongoDatabase db = client.getDatabase(namespace.getDatabaseName()); db.drop(); @@ -746,8 +752,8 @@ protected void test11MultiBatchBulkWrites() throws InterruptedException { * Not a prose spec test. 
However, it is additional test case for better coverage. */ @Test - @DisplayName("Should ignore wTimeoutMS of WriteConcern to initial and subsequent commitTransaction operations") - public void shouldIgnoreWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTransactionOperations() { + @DisplayName("Should not include wTimeoutMS of WriteConcern to initial and subsequent commitTransaction operations") + public void shouldNotIncludeWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTransactionOperations() { assumeTrue(serverVersionAtLeast(4, 4)); assumeFalse(isStandalone()); @@ -755,14 +761,15 @@ public void shouldIgnoreWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTran MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) .getCollection(namespace.getCollectionName()); + int defaultTimeout = 200; try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder() - .defaultTimeout(200, TimeUnit.MILLISECONDS) + .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS) .build())) { session.startTransaction(TransactionOptions.builder() .writeConcern(WriteConcern.ACKNOWLEDGED.withWTimeout(100, TimeUnit.MILLISECONDS)) .build()); collection.insertOne(session, new Document("x", 1)); - sleep(200); + sleep(defaultTimeout); assertDoesNotThrow(session::commitTransaction); //repeat commit. 
@@ -805,12 +812,12 @@ public void shouldIgnoreWaitQueueTimeoutMSWhenTimeoutMsIsSet() { + " data: {" + " failCommands: [\"find\" ]," + " blockConnection: true," - + " blockTimeMS: " + 300 + + " blockTimeMS: " + 450 + " }" + "}"); - executor.submit(() -> collection.find().first()); - sleep(100); + executor.execute(() -> collection.find().first()); + sleep(150); //when && then assertDoesNotThrow(() -> collection.find().first()); @@ -844,7 +851,7 @@ public void shouldThrowOperationTimeoutExceptionWhenConnectionIsNotAvailableAndT + " }" + "}"); - executor.submit(() -> collection.withTimeout(0, TimeUnit.MILLISECONDS).find().first()); + executor.execute(() -> collection.withTimeout(0, TimeUnit.MILLISECONDS).find().first()); sleep(100); //when && then @@ -863,7 +870,7 @@ public void shouldUseWaitQueueTimeoutMSWhenTimeoutIsNotSet() { //given try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder() .applyToConnectionPoolSettings(builder -> builder - .maxWaitTime(100, TimeUnit.MILLISECONDS) + .maxWaitTime(20, TimeUnit.MILLISECONDS) .maxSize(1) ))) { MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) @@ -875,12 +882,12 @@ public void shouldUseWaitQueueTimeoutMSWhenTimeoutIsNotSet() { + " data: {" + " failCommands: [\"find\" ]," + " blockConnection: true," - + " blockTimeMS: " + 300 + + " blockTimeMS: " + 400 + " }" + "}"); - executor.submit(() -> collection.find().first()); - sleep(100); + executor.execute(() -> collection.find().first()); + sleep(200); //when & then assertThrows(MongoTimeoutException.class, () -> collection.find().first()); @@ -896,7 +903,6 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutMsIsN assumeTrue(serverVersionAtLeast(4, 4)); assumeTrue(isLoadBalanced()); - long rtt = ClusterFixture.getPrimaryRTT(); collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); collectionHelper.insertDocuments(new Document(), new Document()); 
collectionHelper.runAdminCommand("{" @@ -905,7 +911,7 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutMsIsN + " data: {" + " failCommands: [\"getMore\" ]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 600) + + " blockTimeMS: " + 600 + " }" + "}"); @@ -943,7 +949,6 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkError() { assumeTrue(serverVersionAtLeast(4, 4)); assumeTrue(isLoadBalanced()); - long rtt = ClusterFixture.getPrimaryRTT(); collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); collectionHelper.insertDocuments(new Document(), new Document()); collectionHelper.runAdminCommand("{" @@ -952,7 +957,7 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkError() { + " data: {" + " failCommands: [\"getMore\" ]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 600) + + " blockTimeMS: " + 600 + " }" + "}"); @@ -1040,11 +1045,16 @@ public void shouldUseConnectTimeoutMsWhenEstablishingConnectionInBackground() { + " data: {" + " failCommands: [\"hello\", \"isMaster\"]," + " blockConnection: true," - + " blockTimeMS: " + 500 + + " blockTimeMS: " + 500 + "," + // The appName is unique to prevent this failpoint from affecting ClusterFixture's ServerMonitor. + // Without the appName, ClusterFixture's heartbeats would be blocked, polluting RTT measurements with 500ms values, + // which would cause flakiness in other prose tests that use ClusterFixture.getPrimaryRTT() for timeout adjustments. + + " appName: \"connectTimeoutBackgroundTest\"" + " }" + "}"); try (MongoClient ignored = createMongoClient(getMongoClientSettingsBuilder() + .applicationName("connectTimeoutBackgroundTest") .applyToConnectionPoolSettings(builder -> builder.minSize(1)) // Use a very short timeout to ensure that the connection establishment will fail on the first handshake command. 
.timeout(10, TimeUnit.MILLISECONDS))) { @@ -1075,9 +1085,10 @@ private static Stream test8ServerSelectionArguments() { } private static Stream test8ServerSelectionHandshakeArguments() { + return Stream.of( - Arguments.of("timeoutMS honored for connection handshake commands if it's lower than serverSelectionTimeoutMS", 200, 300), - Arguments.of("serverSelectionTimeoutMS honored for connection handshake commands if it's lower than timeoutMS", 300, 200) + Arguments.of("timeoutMS honored for connection handshake commands if it's lower than serverSelectionTimeoutMS", 200, 500), + Arguments.of("serverSelectionTimeoutMS honored for connection handshake commands if it's lower than timeoutMS", 500, 200) ); } @@ -1088,7 +1099,8 @@ protected MongoNamespace generateNamespace() { protected MongoClientSettings.Builder getMongoClientSettingsBuilder() { commandListener.reset(); - return Fixture.getMongoClientSettingsBuilder() + MongoClientSettings.Builder mongoClientSettingsBuilder = Fixture.getMongoClientSettingsBuilder(); + return mongoClientSettingsBuilder .readConcern(ReadConcern.MAJORITY) .writeConcern(WriteConcern.MAJORITY) .readPreference(ReadPreference.primary()) @@ -1103,6 +1115,9 @@ public void setUp() { gridFsChunksNamespace = new MongoNamespace(getDefaultDatabaseName(), GRID_FS_BUCKET_NAME + ".chunks"); collectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), namespace); + // in some test collection might not have been created yet, thus dropping it in afterEach will throw an error + collectionHelper.create(); + filesCollectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), gridFsFileNamespace); chunksCollectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), gridFsChunksNamespace); commandListener = new TestCommandListener(); @@ -1112,10 +1127,13 @@ public void setUp() { public void tearDown() throws InterruptedException { ClusterFixture.disableFailPoint(FAIL_COMMAND_NAME); if (collectionHelper != null) { + // Due to testing abortTransaction 
via failpoint, there may be open transactions + // after the test finishes, thus drop() command hangs for 60 seconds until transaction + // is automatically rolled back. + collectionHelper.runAdminCommand("{killAllSessions: []}"); collectionHelper.drop(); filesCollectionHelper.drop(); chunksCollectionHelper.drop(); - commandListener.reset(); try { ServerHelper.checkPool(getPrimary()); } catch (InterruptedException e) { @@ -1139,7 +1157,7 @@ private MongoClient createMongoClient(final MongoClientSettings.Builder builder) return createMongoClient(builder.build()); } - private long msElapsedSince(final long t1) { + protected long msElapsedSince(final long t1) { return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1); } diff --git a/driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java similarity index 57% rename from driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java rename to driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java index d4239aa44d7..746b0ffd8d9 100644 --- a/driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java @@ -14,44 +14,59 @@ * limitations under the License. 
*/ -package com.mongodb.observability; +package com.mongodb.client; import com.mongodb.MongoClientSettings; -import com.mongodb.client.Fixture; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoClients; -import com.mongodb.client.MongoCollection; -import com.mongodb.client.MongoDatabase; +import com.mongodb.lang.Nullable; +import com.mongodb.observability.ObservabilitySettings; +import com.mongodb.client.observability.SpanTree; +import com.mongodb.client.observability.SpanTree.SpanNode; import com.mongodb.observability.micrometer.MicrometerObservabilitySettings; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.tracing.exporter.FinishedSpan; import io.micrometer.tracing.test.reporter.inmemory.InMemoryOtelSetup; import org.bson.Document; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static com.mongodb.ClusterFixture.getDefaultDatabaseName; +import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; +import static com.mongodb.internal.observability.micrometer.MongodbObservation.HighCardinalityKeyNames.QUERY_TEXT; import static com.mongodb.internal.observability.micrometer.TracingManager.ENV_OBSERVABILITY_ENABLED; import static com.mongodb.internal.observability.micrometer.TracingManager.ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH; -import static 
com.mongodb.internal.observability.micrometer.MongodbObservation.HighCardinalityKeyNames.QUERY_TEXT; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; /** - * Implementation of the prose tests for Micrometer OpenTelemetry tracing. + * Implementation of the prose tests + * for Micrometer OpenTelemetry tracing. */ -public class MicrometerProseTest { +public abstract class AbstractMicrometerProseTest { private final ObservationRegistry observationRegistry = ObservationRegistry.create(); private InMemoryOtelSetup memoryOtelSetup; private InMemoryOtelSetup.Builder.OtelBuildingBlocks inMemoryOtel; private static String previousEnvVarMdbTracingEnabled; private static String previousEnvVarMdbQueryTextLength; + protected abstract MongoClient createMongoClient(MongoClientSettings settings); + @BeforeAll static void beforeAll() { // preserve original env var values @@ -77,18 +92,19 @@ void tearDown() { memoryOtelSetup.close(); } + @DisplayName("Test 1: Tracing Enable/Disable via Environment Variable") @Test void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception { setEnv(ENV_OBSERVABILITY_ENABLED, "false"); // don't enable command payload by default - MongoClientSettings clientSettings = Fixture.getMongoClientSettingsBuilder() + MongoClientSettings clientSettings = getMongoClientSettingsBuilder() .observabilitySettings(ObservabilitySettings.micrometerBuilder() .observationRegistry(observationRegistry) .build()) .build(); - try (MongoClient client = MongoClients.create(clientSettings)) { + try (MongoClient client = createMongoClient(clientSettings)) { MongoDatabase database = client.getDatabase(getDefaultDatabaseName()); MongoCollection collection = database.getCollection("test"); collection.find().first(); @@ -98,7 +114,7 @@ void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception { } setEnv(ENV_OBSERVABILITY_ENABLED, "true"); - try (MongoClient client = 
MongoClients.create(clientSettings)) { + try (MongoClient client = createMongoClient(clientSettings)) { MongoDatabase database = client.getDatabase(getDefaultDatabaseName()); MongoCollection collection = database.getCollection("test"); collection.find().first(); @@ -114,6 +130,7 @@ void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception { } } + @DisplayName("Test 2: Command Payload Emission via Environment Variable") @Test void testControlCommandPayloadViaEnvironmentVariable() throws Exception { setEnv(ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH, "42"); @@ -123,13 +140,13 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception { .maxQueryTextLength(75) // should be overridden by env var .build(); - MongoClientSettings clientSettings = Fixture.getMongoClientSettingsBuilder() + MongoClientSettings clientSettings = getMongoClientSettingsBuilder() .observabilitySettings(ObservabilitySettings.micrometerBuilder() .applySettings(settings) .build()). build(); - try (MongoClient client = MongoClients.create(clientSettings)) { + try (MongoClient client = createMongoClient(clientSettings)) { MongoDatabase database = client.getDatabase(getDefaultDatabaseName()); MongoCollection collection = database.getCollection("test"); collection.find().first(); @@ -153,14 +170,14 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception { setEnv(ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH, null); // Unset the environment variable - clientSettings = Fixture.getMongoClientSettingsBuilder() + clientSettings = getMongoClientSettingsBuilder() .observabilitySettings(ObservabilitySettings.micrometerBuilder() .observationRegistry(observationRegistry) .maxQueryTextLength(42) // setting this will not matter since env var is not set and enableCommandPayloadTracing is false .build()) .build(); - try (MongoClient client = MongoClients.create(clientSettings)) { + try (MongoClient client = createMongoClient(clientSettings)) { MongoDatabase database = 
client.getDatabase(getDefaultDatabaseName()); MongoCollection collection = database.getCollection("test"); collection.find().first(); @@ -182,11 +199,11 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception { .maxQueryTextLength(7) // setting this will be used; .build(); - clientSettings = Fixture.getMongoClientSettingsBuilder() + clientSettings = getMongoClientSettingsBuilder() .observabilitySettings(settings) .build(); - try (MongoClient client = MongoClients.create(clientSettings)) { + try (MongoClient client = createMongoClient(clientSettings)) { MongoDatabase database = client.getDatabase(getDefaultDatabaseName()); MongoCollection collection = database.getCollection("test"); collection.find().first(); @@ -200,8 +217,108 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception { } } + /** + * Verifies that concurrent operations produce isolated span trees with no cross-contamination. + * Each operation should get its own trace ID, correct parent-child linkage, and collection-specific tags, + * even when multiple operations execute simultaneously on the same client. + * + *
+ * <p>This test is not from the specification.</p>
+ */ + @Test + void testConcurrentOperationsHaveSeparateSpans() throws Exception { + setEnv(ENV_OBSERVABILITY_ENABLED, "true"); + int nbrConcurrentOps = 10; + MongoClientSettings clientSettings = getMongoClientSettingsBuilder() + .applyToConnectionPoolSettings(pool -> pool.maxSize(nbrConcurrentOps)) + .observabilitySettings(ObservabilitySettings.micrometerBuilder() + .observationRegistry(observationRegistry) + .build()) + .build(); + + try (MongoClient client = createMongoClient(clientSettings)) { + MongoDatabase database = client.getDatabase(getDefaultDatabaseName()); + + // Warm up connections so the concurrent phase doesn't include handshake overhead + for (int i = 0; i < nbrConcurrentOps; i++) { + database.getCollection("concurrent_test_" + i).find().first(); + } + // Clear spans from warm-up before the actual concurrent test + memoryOtelSetup.close(); + memoryOtelSetup = InMemoryOtelSetup.builder().register(observationRegistry); + inMemoryOtel = memoryOtelSetup.getBuildingBlocks(); + + ExecutorService executor = Executors.newFixedThreadPool(nbrConcurrentOps); + try { + CountDownLatch startLatch = new CountDownLatch(1); + List> futures = new ArrayList<>(); + + for (int i = 0; i < nbrConcurrentOps; i++) { + String collectionName = "concurrent_test_" + i; + futures.add(executor.submit(() -> { + try { + startLatch.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + database.getCollection(collectionName).find().first(); + })); + } + + // Release all threads simultaneously to maximize concurrency + startLatch.countDown(); + + for (Future future : futures) { + future.get(30, TimeUnit.SECONDS); + } + } finally { + executor.shutdown(); + } + + List allSpans = inMemoryOtel.getFinishedSpans(); + + // Each find() produces 2 spans: operation-level span + command-level span + assertEquals(nbrConcurrentOps * 2, allSpans.size(), + "Each concurrent operation should produce exactly 2 spans (operation + 
command)."); + + // Verify trace isolation: each independent operation should get its own traceId + Map> spansByTrace = allSpans.stream() + .collect(Collectors.groupingBy(FinishedSpan::getTraceId)); + assertEquals(nbrConcurrentOps, spansByTrace.size(), + "Each concurrent operation should have its own distinct trace ID."); + + // Use SpanTree to validate parent-child structure built from spanId/parentId linkage + SpanTree spanTree = SpanTree.from(allSpans); + List roots = spanTree.getRoots(); + + // Each operation span is a root; its command span is a child + assertEquals(nbrConcurrentOps, roots.size(), + "SpanTree should have one root per concurrent operation."); + + Set observedCollections = new HashSet<>(); + for (SpanNode root : roots) { + assertTrue(root.getName().startsWith("find " + getDefaultDatabaseName() + ".concurrent_test_"), + "Root span should be an operation span, but was: " + root.getName()); + + assertEquals(1, root.getChildren().size(), + "Each operation span should have exactly one child (command span)."); + assertEquals("find", root.getChildren().get(0).getName(), + "Child span should be the command span 'find'."); + + // Extract collection name from the operation span name to verify no cross-contamination + String collectionName = root.getName().substring( + ("find " + getDefaultDatabaseName() + ".").length()); + assertTrue(observedCollections.add(collectionName), + "Each operation should target a unique collection, but found duplicate: " + collectionName); + } + + assertEquals(nbrConcurrentOps, observedCollections.size(), + "All " + nbrConcurrentOps + " concurrent operations should be represented in distinct traces."); + } + } + @SuppressWarnings("unchecked") - private static void setEnv(final String key, final String value) throws Exception { + private static void setEnv(final String key, @Nullable final String value) throws Exception { // Get the unmodifiable Map from System.getenv() Map env = System.getenv(); diff --git 
a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java index 3682bd64ff0..910cf57edfd 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java @@ -93,7 +93,7 @@ public void shouldCreateServerSessionOnlyAfterConnectionCheckout() throws Interr .addCommandListener(new CommandListener() { @Override public void commandStarted(final CommandStartedEvent event) { - lsidSet.add(event.getCommand().getDocument("lsid")); + lsidSet.add(event.getCommand().getDocument("lsid").clone()); } }) .build())) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java index dd45bc8ae2c..04303833bf5 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java @@ -93,14 +93,13 @@ public abstract class AbstractClientSideOperationsEncryptionTimeoutProseTest { @Test void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); Map> kmsProviders = new HashMap<>(); Map localProviderMap = new HashMap<>(); localProviderMap.put("key", Base64.getDecoder().decode(MASTER_KEY)); kmsProviders.put("local", localProviderMap); - try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 100))) { + try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(100))) { keyVaultCollectionHelper.runAdminCommand("{" + " 
configureFailPoint: \"" + FAIL_COMMAND_NAME + "\"," @@ -108,7 +107,7 @@ void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() { + " data: {" + " failCommands: [\"insert\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 100) + + " blockTimeMS: " + 100 + " }" + "}"); @@ -126,9 +125,8 @@ void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() { @Test void shouldThrowOperationTimeoutExceptionWhenEncryptData() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); - try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 150))) { + try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(150))) { clientEncryption.createDataKey("local"); @@ -138,7 +136,7 @@ void shouldThrowOperationTimeoutExceptionWhenEncryptData() { + " data: {" + " failCommands: [\"find\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 150) + + " blockTimeMS: " + 150 + " }" + "}"); @@ -160,10 +158,9 @@ void shouldThrowOperationTimeoutExceptionWhenEncryptData() { @Test void shouldThrowOperationTimeoutExceptionWhenDecryptData() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); BsonBinary encrypted; - try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 400))) { + try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(400))) { clientEncryption.createDataKey("local"); BsonBinary dataKey = clientEncryption.createDataKey("local"); EncryptOptions encryptOptions = new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"); @@ -171,14 +168,14 @@ void shouldThrowOperationTimeoutExceptionWhenDecryptData() { encrypted = clientEncryption.encrypt(new BsonString("hello"), encryptOptions); } - try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 400))) { + try 
(ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(400))) { keyVaultCollectionHelper.runAdminCommand("{" - + " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\"," + + " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\"," + " mode: { times: 1 }," + " data: {" + " failCommands: [\"find\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 500) + + " blockTimeMS: " + 500 + " }" + "}"); commandListener.reset(); @@ -197,8 +194,7 @@ void shouldThrowOperationTimeoutExceptionWhenDecryptData() { @Test void shouldDecreaseOperationTimeoutForSubsequentOperations() { assumeTrue(serverVersionAtLeast(4, 4)); - long rtt = ClusterFixture.getPrimaryRTT(); - long initialTimeoutMS = rtt + 2500; + long initialTimeoutMS = 2500; keyVaultCollectionHelper.runAdminCommand("{" + " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\"," @@ -206,7 +202,7 @@ void shouldDecreaseOperationTimeoutForSubsequentOperations() { + " data: {" + " failCommands: [\"insert\", \"find\", \"listCollections\"]," + " blockConnection: true," - + " blockTimeMS: " + (rtt + 10) + + " blockTimeMS: " + 10 + " }" + "}"); @@ -272,8 +268,7 @@ void shouldDecreaseOperationTimeoutForSubsequentOperations() { void shouldThrowTimeoutExceptionWhenCreateEncryptedCollection(final String commandToTimeout) { assumeTrue(serverVersionAtLeast(7, 0)); //given - long rtt = ClusterFixture.getPrimaryRTT(); - long initialTimeoutMS = rtt + 200; + long initialTimeoutMS = 200; try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder() .timeout(initialTimeoutMS, MILLISECONDS))) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java new file mode 100644 index 00000000000..38bd4350b1d --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java @@ -0,0 +1,32 @@ +/* + * 
Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.observability; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.AbstractMicrometerProseTest; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +/** + * Sync driver implementation of the Micrometer prose tests. + */ +public class MicrometerProseTest extends AbstractMicrometerProseTest { + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java b/driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java similarity index 98% rename from driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java rename to driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java index aa6697bf3ad..7d3bff3224d 100644 --- a/driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java +++ b/driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.mongodb.observability; +package com.mongodb.client.observability; import com.mongodb.lang.Nullable; import io.micrometer.tracing.exporter.FinishedSpan; @@ -204,6 +204,10 @@ private static void assertValid(final SpanNode reportedNode, final SpanNode expe } } + public List getRoots() { + return Collections.unmodifiableList(roots); + } + @Override public String toString() { return "SpanTree{" diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index cf003078f04..35189aef455 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -28,7 +28,7 @@ import com.mongodb.client.gridfs.GridFSBucket; import com.mongodb.client.model.Filters; import com.mongodb.client.test.CollectionHelper; -import com.mongodb.observability.SpanTree; +import com.mongodb.client.observability.SpanTree; import com.mongodb.client.unified.UnifiedTestModifications.TestDef; import com.mongodb.client.vault.ClientEncryption; import com.mongodb.connection.ClusterDescription; diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java index 2225f837ec5..328c8298b6c 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java @@ -63,6 +63,25 @@ public static void applyCustomizations(final TestDef def) { .file("client-side-encryption/tests/unified", "client bulkWrite with queryable encryption"); // client-side-operation-timeout (CSOT) + def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs") + .whenFailureContains("timeout") + .test("client-side-operations-timeout", + 
"timeoutMS behaves correctly for non-tailable cursors", + "timeoutMS is refreshed for getMore if timeoutMode is iteration - success"); + + def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs") + .whenFailureContains("timeout") + .test("client-side-operations-timeout", + "timeoutMS behaves correctly for tailable non-awaitData cursors", + "timeoutMS is refreshed for getMore - success"); + + def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs") + .whenFailureContains("timeout") + .test("client-side-operations-timeout", + "timeoutMS behaves correctly for tailable non-awaitData cursors", + "timeoutMS is refreshed for getMore - success"); + + //TODO-invistigate /* As to the background connection pooling section: timeoutMS set at the MongoClient level MUST be used as the timeout for all commands sent as part of the handshake. diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 8a08c34f213..b5e561c7f7e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -18,10 +18,10 @@ aws-sdk-v2 = "2.30.31" graal-sdk = "24.0.0" jna = "5.11.0" jnr-unixsocket = "0.38.17" -netty-bom = "4.1.87.Final" +netty-bom = "4.2.9.Final" project-reactor-bom = "2022.0.0" reactive-streams = "1.0.4" -snappy = "1.1.10.3" +snappy = "1.1.10.4" zstd = "1.5.5-3" jetbrains-annotations = "26.0.2" micrometer-tracing = "1.6.0-M3" # This version has a fix for https://github.com/micrometer-metrics/tracing/issues/1092 diff --git a/testing/resources/specifications b/testing/resources/specifications index de684cf1ef9..bb9dddd8176 160000 --- a/testing/resources/specifications +++ b/testing/resources/specifications @@ -1 +1 @@ -Subproject commit de684cf1ef9feede71d358cbb7d253840f1a8647 +Subproject commit bb9dddd8176eddbb9424f9bebedfe8c6bbf28c3a From 8711eef93c9ef60c2b806642ef5a08602771a088 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 24 Mar 2026 15:28:35 +0000 Subject: [PATCH 4/4] ByteBuf leak 
fixes (#1876) - Ensure Default Server Monitor calls close on resources before interrupt - Update ByteBufferBsonOutput documentation - Improve ReplyHeader testing and ensure resources are closed - Improve ServerSessionPool testing - Ensure reactive client session closing is idempotent - Added System.gc to unified test cleanup. Should cause more gc when testing. JAVA-6081 --- bson/src/main/org/bson/BsonDocument.java | 4 +- config/spotbugs/exclude.xml | 6 + .../connection/ByteBufferBsonOutput.java | 117 ++++++- .../connection/DefaultServerMonitor.java | 215 +++++++++--- .../ReplyHeaderSpecification.groovy | 201 ----------- .../internal/connection/ReplyHeaderTest.java | 213 ++++++++++++ .../connection/DefaultServerMonitorTest.java | 50 +++ .../ServerSessionPoolSpecification.groovy | 229 ------------- .../session/ServerSessionPoolTest.java | 320 ++++++++++++++++++ .../main/com/mongodb/DBDecoderAdapter.java | 9 +- .../internal/ClientSessionPublisherImpl.java | 20 +- .../reactivestreams/client/Fixture.java | 18 +- .../mongodb/client/unified/UnifiedTest.java | 3 + 13 files changed, 913 insertions(+), 492 deletions(-) delete mode 100644 driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy create mode 100644 driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy create mode 100644 driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java diff --git a/bson/src/main/org/bson/BsonDocument.java b/bson/src/main/org/bson/BsonDocument.java index 87625de8dbd..f52a1f25f7d 100644 --- a/bson/src/main/org/bson/BsonDocument.java +++ b/bson/src/main/org/bson/BsonDocument.java @@ -921,10 +921,12 @@ private static class SerializationProxy implements Serializable { new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); this.bytes = new 
byte[buffer.size()]; int curPos = 0; - for (ByteBuf cur : buffer.getByteBuffers()) { + List byteBuffers = buffer.getByteBuffers(); + for (ByteBuf cur : byteBuffers) { System.arraycopy(cur.array(), cur.position(), bytes, curPos, cur.limit()); curPos += cur.position(); } + byteBuffers.forEach(ByteBuf::release); } private Object readResolve() { diff --git a/config/spotbugs/exclude.xml b/config/spotbugs/exclude.xml index 20684680865..e54d445dc83 100644 --- a/config/spotbugs/exclude.xml +++ b/config/spotbugs/exclude.xml @@ -290,4 +290,10 @@ + + + + + + diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index 1edbb0f4c2f..0988679f28d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -16,6 +16,9 @@ package com.mongodb.internal.connection; +import com.mongodb.annotations.Sealed; +import com.mongodb.internal.ResourceUtil; +import com.mongodb.internal.VisibleForTesting; import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.io.OutputBuffer; @@ -28,11 +31,28 @@ import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; /** + * A BSON output implementation that uses pooled {@link ByteBuf} instances for efficient memory management. + * + *

ByteBuf Ownership and Lifecycle

+ *

This class manages the lifecycle of {@link ByteBuf} instances obtained from the {@link BufferProvider}. + * The ownership model is as follows:

+ *
    + *
  • Internal buffers are owned by this output and released when {@link #close()} is called or + * when {@link #truncateToPosition(int)} removes them.
  • + *
  • Methods that return {@link ByteBuf} instances (e.g., {@link #getByteBuffers()}) return + * duplicates with their own reference counts. Callers are responsible for releasing + * these buffers to prevent memory leaks.
  • + *
  • The {@link Branch} subclass merges its buffers into the parent on close, transferring + * ownership by retaining buffers before the branch releases them.
  • + *
+ * *

This class is not part of the public API and may be removed or changed at any time

*/ +@Sealed public class ByteBufferBsonOutput extends OutputBuffer { private static final int MAX_SHIFT = 31; @@ -50,6 +70,9 @@ public class ByteBufferBsonOutput extends OutputBuffer { /** * Construct an instance that uses the given buffer provider to allocate byte buffers as needs as it grows. * + *

The buffer provider is used to allocate new {@link ByteBuf} instances as the output grows. + * All allocated buffers are owned by this output and will be released when {@link #close()} is called.

+ * * @param bufferProvider the non-null buffer provider */ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { @@ -63,6 +86,10 @@ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { * If multiple branches are created, they are merged in the order they are {@linkplain ByteBufferBsonOutput.Branch#close() closed}. * {@linkplain #close() Closing} this {@link ByteBufferBsonOutput} does not {@linkplain ByteBufferBsonOutput.Branch#close() close} the branch. * + *

ByteBuf Ownership: The branch allocates its own buffers. When the branch is closed, + * ownership of these buffers is transferred to the parent by retaining them before the branch releases + * its references. The parent then becomes responsible for releasing these buffers when it is closed.

+ * * @return A new {@link ByteBufferBsonOutput.Branch}. */ public ByteBufferBsonOutput.Branch branch() { @@ -223,10 +250,28 @@ protected void write(final int absolutePosition, final int value) { byteBuffer.put(bufferPositionPair.position++, (byte) value); } + /** + * Returns a list of duplicated byte buffers containing the written data, flipped for reading. + * + *

ByteBuf Ownership: The returned buffers are duplicates with their own + * reference counts (each starts with a reference count of 1). The caller is responsible + * for releasing each buffer when done to prevent memory leaks. Example usage:

+ *
{@code
+     * List buffers = output.getByteBuffers();
+     * try {
+     *     // use buffers
+     * } finally {
+     *     ResourceUtil.release(buffers);
+     * }
+     * }
+ *

Note: These buffers must be released before this {@code ByteBufferBsonOutput} is closed. + * Otherwise there is a risk of the buffers being released back to the bufferProvider and data corruption.

+ * + * @return a list of duplicated buffers, flipped for reading + */ @Override public List getByteBuffers() { ensureOpen(); - List buffers = new ArrayList<>(bufferList.size()); for (final ByteBuf cur : bufferList) { buffers.add(cur.duplicate().order(ByteOrder.LITTLE_ENDIAN).flip()); @@ -234,6 +279,17 @@ public List getByteBuffers() { return buffers; } + /** + * Returns a list of duplicated byte buffers without flipping them. + * + *

ByteBuf Ownership: The returned buffers are duplicates with their own + * reference counts (each starts with a reference count of 1). The caller is responsible + * for releasing each buffer when done to prevent memory leaks.

+ * + * @return a list of duplicated buffers + * @see #getByteBuffers() + */ + @VisibleForTesting(otherwise = PRIVATE) public List getDuplicateByteBuffers() { ensureOpen(); @@ -245,6 +301,13 @@ public List getDuplicateByteBuffers() { } + /** + * {@inheritDoc} + * + *

ByteBuf Management: This method obtains duplicated buffers via + * {@link #getByteBuffers()} and releases them after writing to the output stream, + * ensuring no buffer leaks occur.

+ */ @Override public int pipe(final OutputStream out) throws IOException { ensureOpen(); @@ -263,11 +326,20 @@ public int pipe(final OutputStream out) throws IOException { total += cur.limit(); } } finally { - byteBuffers.forEach(ByteBuf::release); + ResourceUtil.release(byteBuffers); } return total; } + /** + * Truncates this output to the specified position, releasing any buffers that are no longer needed. + * + *

ByteBuf Management: Any buffers beyond the new position are removed from + * the internal buffer list and released. This ensures no memory leaks when truncating.

+ * + * @param newPosition the new position to truncate to + * @throws IllegalArgumentException if newPosition is negative or greater than the current position + */ @Override public void truncateToPosition(final int newPosition) { ensureOpen(); @@ -306,13 +378,15 @@ public final void flush() throws IOException { * {@inheritDoc} *

* Idempotent.

+ * + *

ByteBuf Management: Releases internal buffers and clears the buffer list. + * After this method returns, all buffers that were allocated by this output will have been fully released + * back to the buffer provider.

*/ @Override public void close() { if (isOpen()) { - for (final ByteBuf cur : bufferList) { - cur.release(); - } + ResourceUtil.release(bufferList); currentByteBuffer = null; bufferList.clear(); closed = true; @@ -345,7 +419,14 @@ boolean isOpen() { } /** - * @see #branch() + * Merges a branch's buffers into this output. + * + *

ByteBuf Ownership: This method retains each buffer from the branch before + * adding it to this output's buffer list. This is necessary because the branch will release its + * references when it closes. The retain ensures the buffers remain valid and are now owned by + * this output.

+ * + * @param branch the branch to merge */ private void merge(final ByteBufferBsonOutput branch) { assertTrue(branch instanceof ByteBufferBsonOutput.Branch); @@ -356,6 +437,20 @@ private void merge(final ByteBufferBsonOutput branch) { currentByteBuffer = null; } + /** + * A branch of a {@link ByteBufferBsonOutput} that can be merged back into its parent. + * + *

ByteBuf Ownership: A branch allocates its own buffers independently. + * When {@link #close()} is called:

+ *
    + *
  1. The parent's {@link ByteBufferBsonOutput#merge(ByteBufferBsonOutput)} method is called, + * which retains all buffers in this branch.
  2. + *
  3. Then {@code super.close()} is called, which releases the branch's references to the buffers.
  4. + *
+ *

The retain/release sequence ensures buffers are safely transferred to the parent without leaks.

+ * + * @see #branch() + */ public static final class Branch extends ByteBufferBsonOutput { private final ByteBufferBsonOutput parent; @@ -365,6 +460,16 @@ private Branch(final ByteBufferBsonOutput parent) { } /** + * Closes this branch and merges its data into the parent output. + * + *

ByteBuf Ownership: On close, this branch's buffers are transferred + * to the parent. The parent retains the buffers (incrementing reference counts), and then + * this branch releases only its own single reference. The parent + * becomes the sole owner of the buffers and is responsible for releasing them.

+ * + *

Idempotent. If already closed, this method does nothing.

+ * + * @throws AssertionError if the parent has been closed before this branch * @see #branch() */ @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java index bb97517d315..e4ca569761a 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java @@ -187,11 +187,17 @@ class ServerMonitor extends Thread implements AutoCloseable { @Override public void close() { - interrupt(); - InternalConnection connection = this.connection; - if (connection != null) { - connection.close(); + InternalConnection localConnection = withLock(lock, () -> { + InternalConnection result = connection; + connection = null; + return result; + }); + + if (localConnection != null) { + localConnection.close(); } + + interrupt(); } @Override @@ -219,8 +225,9 @@ public void run() { logStateChange(previousServerDescription, currentServerDescription); sdamProvider.get().monitorUpdate(currentServerDescription); + InternalConnection localConnection = connection; if ((shouldStreamResponses && currentServerDescription.getType() != UNKNOWN) - || (connection != null && connection.hasMoreToCome()) + || (localConnection != null && localConnection.hasMoreToCome()) || (currentServerDescription.getException() instanceof MongoSocketException && previousServerDescription.getType() != UNKNOWN)) { continue; @@ -233,8 +240,9 @@ public void run() { LOGGER.error(format("%s for %s stopped working. 
You may want to recreate the MongoClient", this, serverId), t); throw t; } finally { - if (connection != null) { - connection.close(); + InternalConnection localConnection = connection; + if (localConnection != null) { + localConnection.close(); } } } @@ -249,7 +257,8 @@ private ServerDescription lookupServerDescription(final ServerDescription curren lookupStartTimeNanos = System.nanoTime(); // Handle connection setup - if (connection == null || connection.isClosed()) { + InternalConnection localConnection = connection; + if (localConnection == null || localConnection.isClosed()) { return setupNewConnectionAndGetInitialDescription(shouldStreamResponses); } @@ -275,17 +284,47 @@ private ServerDescription lookupServerDescription(final ServerDescription curren } private ServerDescription setupNewConnectionAndGetInitialDescription(final boolean shouldStreamResponses) { - connection = internalConnectionFactory.create(serverId); + InternalConnection newConnection = internalConnectionFactory.create(serverId); + + // Publish the connection to the field under the lock so that heartbeat + // started logging (which reads the field) can see it, but only if the + // monitor has not been closed in the meantime. + boolean published = withLock(lock, () -> { + if (!isClosed) { + connection = newConnection; + return true; + } + return false; + }); + + if (!published) { + newConnection.close(); + throw new MongoSocketException("Monitor closed", serverId.getAddress()); + } + logAndNotifyHeartbeatStarted(shouldStreamResponses); try { - connection.open(operationContextFactory.create()); - roundTripTimeSampler.addSample(connection.getInitialServerDescription().getRoundTripTimeNanos()); - return connection.getInitialServerDescription(); + newConnection.open(operationContextFactory.create()); } catch (Exception e) { logAndNotifyHeartbeatFailed(shouldStreamResponses, e); throw e; } + + // After the potentially long open(), verify the monitor is still open + // before using the connection. 
If close() ran during open(), it already + // nulled the field and closed the connection, so we must not use it. + boolean stillValid = withLock(lock, () -> !isClosed && connection == newConnection); + + if (!stillValid) { + // close() may or may not have closed newConnection already; + // closing an already-closed connection is a safe no-op. + newConnection.close(); + throw new MongoSocketException("Monitor closed during connection open", serverId.getAddress()); + } + + roundTripTimeSampler.addSample(newConnection.getInitialServerDescription().getRoundTripTimeNanos()); + return newConnection.getInitialServerDescription(); } /** @@ -293,24 +332,36 @@ private ServerDescription setupNewConnectionAndGetInitialDescription(final boole */ private ServerDescription doHeartbeat(final ServerDescription currentServerDescription, final boolean shouldStreamResponses) { + // Check if monitor was closed or connection is unusable + InternalConnection localConnection = withLock(lock, () -> { + if (isClosed || connection == null || connection.isClosed()) { + return null; + } + return connection; + }); + + if (localConnection == null) { + throw new MongoSocketException("Monitor closed", serverId.getAddress()); + } + try { OperationContext operationContext = operationContextFactory.create(); - if (!connection.hasMoreToCome()) { + if (!localConnection.hasMoreToCome()) { BsonDocument helloDocument = new BsonDocument(getHandshakeCommandName(currentServerDescription), new BsonInt32(1)) .append("helloOk", BsonBoolean.TRUE); if (shouldStreamResponses) { helloDocument.append("topologyVersion", assertNotNull(currentServerDescription.getTopologyVersion()).asDocument()); helloDocument.append("maxAwaitTimeMS", new BsonInt64(serverSettings.getHeartbeatFrequency(MILLISECONDS))); } - connection.send(createCommandMessage(helloDocument, connection, currentServerDescription), new BsonDocumentCodec(), + localConnection.send(createCommandMessage(helloDocument, localConnection, 
currentServerDescription), new BsonDocumentCodec(), operationContext); } BsonDocument helloResult; if (shouldStreamResponses) { - helloResult = connection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext)); + helloResult = localConnection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext)); } else { - helloResult = connection.receive(new BsonDocumentCodec(), operationContext); + helloResult = localConnection.receive(new BsonDocumentCodec(), operationContext); } logAndNotifyHeartbeatSucceeded(shouldStreamResponses, helloResult); return createServerDescription(serverId.getAddress(), helloResult, roundTripTimeSampler.getAverage(), @@ -322,10 +373,23 @@ private ServerDescription doHeartbeat(final ServerDescription currentServerDescr } private void logAndNotifyHeartbeatStarted(final boolean shouldStreamResponses) { - alreadyLoggedHeartBeatStarted = true; - logHeartbeatStarted(serverId, connection.getDescription(), shouldStreamResponses); - serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent( - connection.getDescription().getConnectionId(), shouldStreamResponses)); + ConnectionDescription description = withLock(lock, () -> { + if (connection != null) { + return connection.getDescription(); + } + return null; + }); + if (description != null) { + alreadyLoggedHeartBeatStarted = true; + logHeartbeatStarted(serverId, description, shouldStreamResponses); + serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent( + description.getConnectionId(), shouldStreamResponses)); + } else { + // Connection not fully established yet - skip logging for this heartbeat + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Skipping heartbeat started event for %s - connection description not available", serverId)); + } + } } private void logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses, final BsonDocument helloResult) { @@ -334,19 +398,42 @@ private void 
logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses, if (!shouldStreamResponses) { roundTripTimeSampler.addSample(elapsedTimeNanos); } - logHeartbeatSucceeded(serverId, connection.getDescription(), shouldStreamResponses, elapsedTimeNanos, helloResult); - serverMonitorListener.serverHeartbeatSucceeded( - new ServerHeartbeatSucceededEvent(connection.getDescription().getConnectionId(), helloResult, - elapsedTimeNanos, shouldStreamResponses)); + + ConnectionDescription description = withLock(lock, () -> { + if (connection != null) { + return connection.getDescription(); + } + return null; + }); + if (description != null) { + logHeartbeatSucceeded(serverId, description, shouldStreamResponses, elapsedTimeNanos, helloResult); + serverMonitorListener.serverHeartbeatSucceeded( + new ServerHeartbeatSucceededEvent(description.getConnectionId(), helloResult, + elapsedTimeNanos, shouldStreamResponses)); + } } private void logAndNotifyHeartbeatFailed(final boolean shouldStreamResponses, final Exception e) { alreadyLoggedHeartBeatStarted = false; long elapsedTimeNanos = getElapsedTimeNanos(); - logHeartbeatFailed(serverId, connection.getDescription(), shouldStreamResponses, elapsedTimeNanos, e); - serverMonitorListener.serverHeartbeatFailed( - new ServerHeartbeatFailedEvent(connection.getDescription().getConnectionId(), elapsedTimeNanos, - shouldStreamResponses, e)); + + ConnectionDescription description = withLock(lock, () -> { + if (connection != null) { + return connection.getDescription(); + } + return null; + }); + if (description != null) { + logHeartbeatFailed(serverId, description, shouldStreamResponses, elapsedTimeNanos, e); + serverMonitorListener.serverHeartbeatFailed( + new ServerHeartbeatFailedEvent(description.getConnectionId(), elapsedTimeNanos, + shouldStreamResponses, e)); + } else { + // Log failure without connection details + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Heartbeat failed for %s but connection description not available", 
serverId), e); + } + } } private long getElapsedTimeNanos() { @@ -514,11 +601,17 @@ private class RoundTripTimeMonitor extends Thread implements AutoCloseable { @Override public void close() { - interrupt(); - InternalConnection connection = this.connection; - if (connection != null) { - connection.close(); + InternalConnection localConnection = withLock(lock, () -> { + InternalConnection result = connection; + connection = null; + return result; + }); + + if (localConnection != null) { + localConnection.close(); } + + interrupt(); } @Override @@ -526,15 +619,20 @@ public void run() { try { while (!isClosed) { try { - if (connection == null) { + InternalConnection localConnection = connection; + if (localConnection == null) { initialize(); } else { - pingServer(connection); + pingServer(localConnection); } } catch (Exception t) { - if (connection != null) { - connection.close(); + InternalConnection localConnection = withLock(lock, () -> { + InternalConnection result = connection; connection = null; + return result; + }); + if (localConnection != null) { + localConnection.close(); } } waitForNext(); @@ -545,20 +643,53 @@ public void run() { LOGGER.error(format("%s for %s stopped working. 
You may want to recreate the MongoClient", this, serverId), t); throw t; } finally { - if (connection != null) { - connection.close(); + InternalConnection localConnection = connection; + if (localConnection != null) { + localConnection.close(); } } } private void initialize() { - connection = null; - connection = internalConnectionFactory.create(serverId); - connection.open(operationContextFactory.create()); - roundTripTimeSampler.addSample(connection.getInitialServerDescription().getRoundTripTimeNanos()); + boolean shouldProceed = withLock(lock, () -> !isClosed); + + if (!shouldProceed) { + return; + } + + InternalConnection newConnection = internalConnectionFactory.create(serverId); + newConnection.open(operationContextFactory.create()); + + // Check again after the potentially long open() operation + boolean stillValid = withLock(lock, () -> { + if (!isClosed) { + connection = newConnection; + return true; + } + return false; + }); + + if (stillValid) { + roundTripTimeSampler.addSample(newConnection.getInitialServerDescription().getRoundTripTimeNanos()); + } else { + // Monitor was closed during open(), clean up the connection + newConnection.close(); + } } private void pingServer(final InternalConnection connection) { + // Atomically check if monitor was closed and connection is still valid + boolean shouldProceed = withLock(lock, () -> + !isClosed && this.connection == connection + ); + + if (!shouldProceed) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Skipping ping for %s - monitor closed or connection changed", serverId)); + } + return; // Monitor closed or connection changed, skip ping + } + long start = System.nanoTime(); OperationContext operationContext = operationContextFactory.create(); executeCommand("admin", diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy deleted file mode 100644 
index 0407baeca8a..00000000000 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package com.mongodb.internal.connection - -import com.mongodb.MongoInternalException -import org.bson.io.BasicOutputBuffer -import spock.lang.Specification - -import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize - -class ReplyHeaderSpecification extends Specification { - - def 'should parse reply header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(responseFlags) - writeLong(9000) - writeInt(4) - writeInt(1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - def replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - replyHeader.messageLength == 186 - replyHeader.requestId == 45 - replyHeader.responseTo == 23 - - where: - responseFlags << [0, 1, 2, 3] - cursorNotFound << [false, true, false, true] - queryFailure << [false, false, true, true] - } - - def 'should parse reply header with compressed header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(2012) - writeInt(1) - writeInt(258) - writeByte(2) - writeInt(responseFlags) - 
writeLong(9000) - writeInt(4) - writeInt(1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - when: - def replyHeader = new ReplyHeader(byteBuf, compressedHeader) - - then: - replyHeader.messageLength == 274 - replyHeader.requestId == 45 - replyHeader.responseTo == 23 - - where: - responseFlags << [0, 1, 2, 3] - cursorNotFound << [false, true, false, true] - queryFailure << [false, false, true, true] - } - - def 'should throw MongoInternalException on incorrect opCode'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(36) - writeInt(45) - writeInt(23) - writeInt(2) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'Unexpected reply message opCode 2' - } - - def 'should throw MongoInternalException on message size < 36'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(35) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message length 35 is less than the minimum message length 36' - } - - def 'should throw MongoInternalException on message size > max message size'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(400) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399)) - - then: - def ex = 
thrown(MongoInternalException) - ex.getMessage() == 'The reply message length 400 is greater than the maximum message length 399' - } - - def 'should throw MongoInternalException on num documents < 0'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(1) - writeLong(9000) - writeInt(4) - writeInt(-1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' - } - - def 'should throw MongoInternalException on num documents < 0 with compressed header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(2012) - writeInt(1) - writeInt(258) - writeByte(2) - writeInt(1) - writeLong(9000) - writeInt(4) - writeInt(-1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - when: - new ReplyHeader(byteBuf, compressedHeader) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' - } -} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java new file mode 100644 index 00000000000..38bc96731c2 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java @@ -0,0 +1,213 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoInternalException; +import org.bson.ByteBuf; +import org.bson.io.BasicOutputBuffer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.List; + +import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@DisplayName("ReplyHeader") +class ReplyHeaderTest { + + @ParameterizedTest(name = "with responseFlags {0}") + @ValueSource(ints = {0, 1, 2, 3}) + @DisplayName("should parse reply header with various response flags") + void testParseReplyHeader(final int responseFlags) { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(responseFlags); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(1); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + ReplyHeader replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + + assertEquals(186, replyHeader.getMessageLength()); + assertEquals(45, replyHeader.getRequestId()); + assertEquals(23, replyHeader.getResponseTo()); + + byteBuffers.forEach(ByteBuf::release); + } + 
} + + @ParameterizedTest(name = "with responseFlags {0}") + @ValueSource(ints = {0, 1, 2, 3}) + @DisplayName("should parse reply header with compressed header and various response flags") + void testParseReplyHeaderWithCompressedHeader(final int responseFlags) { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2012); + outputBuffer.writeInt(1); + outputBuffer.writeInt(258); + outputBuffer.writeByte(2); + outputBuffer.writeInt(responseFlags); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(1); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + CompressedHeader compressedHeader = new CompressedHeader(byteBuf, + new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + ReplyHeader replyHeader = new ReplyHeader(byteBuf, compressedHeader); + + assertEquals(274, replyHeader.getMessageLength()); + assertEquals(45, replyHeader.getRequestId()); + assertEquals(23, replyHeader.getResponseTo()); + byteBuffers.forEach(ByteBuf::release); + } + } + + @Test + @DisplayName("should throw MongoInternalException on incorrect opCode") + void testThrowExceptionOnIncorrectOpCode() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(36); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("Unexpected reply message opCode 2", ex.getMessage()); + byteBuffers.forEach(ByteBuf::release); + } + } + + @Test + @DisplayName("should 
throw MongoInternalException on message size less than 36 bytes") + void testThrowExceptionOnMessageSizeLessThan36() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(35); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("The reply message length 35 is less than the minimum message length 36", ex.getMessage()); + byteBuffers.forEach(ByteBuf::release); + } + } + + @Test + @DisplayName("should throw MongoInternalException on message size exceeding max message size") + void testThrowExceptionOnMessageSizeExceedingMax() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(400); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399))); + + assertEquals("The reply message length 400 is greater than the maximum message length 399", ex.getMessage()); + byteBuffers.forEach(ByteBuf::release); + } + } + + @Test + @DisplayName("should throw MongoInternalException on negative number of returned documents") + void testThrowExceptionOnNegativeNumberOfDocuments() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + 
outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(1); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(-1); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage()); + byteBuffers.forEach(ByteBuf::release); + } + } + + @Test + @DisplayName("should throw MongoInternalException on negative number of documents with compressed header") + void testThrowExceptionOnNegativeNumberOfDocumentsWithCompressedHeader() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2012); + outputBuffer.writeInt(1); + outputBuffer.writeInt(258); + outputBuffer.writeByte(2); + outputBuffer.writeInt(1); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(-1); + + List byteBuffers = outputBuffer.getByteBuffers(); + ByteBuf byteBuf = byteBuffers.get(0); + CompressedHeader compressedHeader = new CompressedHeader(byteBuf, + new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, compressedHeader)); + + assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage()); + byteBuffers.forEach(ByteBuf::release); + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java index 3aff244ea1e..bd587464c23 100644 --- 
a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java @@ -58,6 +58,8 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -254,6 +256,54 @@ public void serverHeartbeatFailed(final ServerHeartbeatFailedEvent event) { assertEquals(expectedEvents, events); } + @Test + void closeDuringConnectionShouldNotLeakBuffers() throws Exception { + CountDownLatch connectionStarted = new CountDownLatch(1); + CountDownLatch proceedWithOpen = new CountDownLatch(1); + + InternalConnection mockConnection = mock(InternalConnection.class); + doAnswer(invocation -> { + connectionStarted.countDown(); + assertTrue(proceedWithOpen.await(5, TimeUnit.SECONDS)); + return null; + }).when(mockConnection).open(any()); + + when(mockConnection.getInitialServerDescription()) + .thenReturn(createDefaultServerDescription()); + + InternalConnectionFactory factory = createConnectionFactory(mockConnection); + + monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class)); + + // Wait for connection to start opening + assertTrue(connectionStarted.await(5, TimeUnit.SECONDS)); + + // Close monitor while connection is opening + monitor.close(); + + // Allow connection to complete + proceedWithOpen.countDown(); + + // Verify no leaks by checking connection was properly closed + monitor.getServerMonitor().join(5000); + assertFalse(monitor.getServerMonitor().isAlive()); + verify(mockConnection, timeout(500)).close(); + } + + @Test + void heartbeatWithNullConnectionDescriptionShouldNotCrash() throws Exception { + InternalConnection mockConnection = mock(InternalConnection.class); + when(mockConnection.getDescription()).thenReturn(null); + 
when(mockConnection.getInitialServerDescription()) + .thenReturn(createDefaultServerDescription()); + when(mockConnection.isClosed()).thenReturn(false); + + InternalConnectionFactory factory = createConnectionFactory(mockConnection); + monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class)); + + // Monitor should handle null description gracefully + verify(mockConnection, timeout(500).atLeast(1)).open(any()); + } private InternalConnectionFactory createConnectionFactory(final InternalConnection connection) { InternalConnectionFactory factory = mock(InternalConnectionFactory.class); diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy deleted file mode 100644 index 19bfa994200..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.mongodb.internal.session - -import com.mongodb.ServerAddress -import com.mongodb.connection.ClusterDescription -import com.mongodb.connection.ClusterSettings -import com.mongodb.connection.ServerDescription -import com.mongodb.connection.ServerSettings -import com.mongodb.internal.connection.Cluster -import com.mongodb.internal.connection.Connection -import com.mongodb.internal.connection.Server -import com.mongodb.internal.connection.ServerTuple -import com.mongodb.internal.validator.NoOpFieldNameValidator -import org.bson.BsonArray -import org.bson.BsonBinarySubType -import org.bson.BsonDocument -import org.bson.codecs.BsonDocumentCodec -import spock.lang.Specification - -import static com.mongodb.ClusterFixture.OPERATION_CONTEXT -import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS -import static com.mongodb.ClusterFixture.getServerApi -import static com.mongodb.ReadPreference.primaryPreferred -import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE -import static com.mongodb.connection.ClusterType.REPLICA_SET -import static com.mongodb.connection.ServerConnectionState.CONNECTED -import static com.mongodb.connection.ServerConnectionState.CONNECTING -import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY -import static com.mongodb.connection.ServerType.UNKNOWN -import static java.util.concurrent.TimeUnit.MINUTES - -class ServerSessionPoolSpecification extends Specification { - - def connectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET, - [ - ServerDescription.builder().ok(true) - .state(CONNECTED) - .address(new ServerAddress()) - .type(REPLICA_SET_PRIMARY) - .logicalSessionTimeoutMinutes(30) - .build() - ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build()) - - def unconnectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET, - [ - ServerDescription.builder().ok(true) - .state(CONNECTING) - .address(new ServerAddress()) - .type(UNKNOWN) - 
.logicalSessionTimeoutMinutes(null) - .build() - ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build()) - - def 'should get session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - - when: - def session = pool.get() - - then: - session != null - } - - def 'should throw IllegalStateException if pool is closed'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - pool.close() - - when: - pool.get() - - then: - thrown(IllegalStateException) - } - - def 'should pool session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - def session = pool.get() - - when: - pool.release(session) - def pooledSession = pool.get() - - then: - session == pooledSession - } - - def 'should prune sessions when getting'() { - given: - def cluster = Mock(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >>> [0, MINUTES.toMillis(29) + 1, - ] - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - def sessionOne = pool.get() - - when: - pool.release(sessionOne) - - then: - !sessionOne.closed - - when: - def sessionTwo = pool.get() - - then: - sessionTwo != sessionOne - sessionOne.closed - 0 * cluster.selectServer(_) - } - - def 'should not prune session when timeout is null'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> unconnectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >>> [0, 0, 0] - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - def session = pool.get() - - when: - pool.release(session) - 
def newSession = pool.get() - - then: - session == newSession - } - - def 'should initialize session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >> 42 - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - - when: - def session = pool.get() as ServerSessionPool.ServerSessionImpl - - then: - session.lastUsedAtMillis == 42 - session.transactionNumber == 0 - def uuid = session.identifier.getBinary('id') - uuid != null - uuid.type == BsonBinarySubType.UUID_STANDARD.value - uuid.data.length == 16 - } - - def 'should advance transaction number'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >> 42 - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - - when: - def session = pool.get() as ServerSessionPool.ServerSessionImpl - - then: - session.transactionNumber == 0 - session.advanceTransactionNumber() == 1 - session.transactionNumber == 1 - } - - def 'should end pooled sessions when pool is closed'() { - given: - def connection = Mock(Connection) - def server = Stub(Server) { - getConnection(_) >> connection - } - def cluster = Mock(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - def sessions = [] - 10.times { sessions.add(pool.get()) } - - for (def cur : sessions) { - pool.release(cur) - } - - when: - pool.close() - - then: - 1 * cluster.selectServer(_, _) >> new ServerTuple(server, connectedDescription.serverDescriptions[0]) - 1 * connection.command('admin', - new BsonDocument('endSessions', new BsonArray(sessions*.getIdentifier())), - { it instanceof NoOpFieldNameValidator }, primaryPreferred(), - { it instanceof BsonDocumentCodec }, _) >> new BsonDocument() - 1 * connection.release() - } -} diff --git 
a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java new file mode 100644 index 00000000000..0322d0f4063 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java @@ -0,0 +1,320 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.session; + +import com.mongodb.MongoException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.ClusterDescription; +import com.mongodb.connection.ClusterSettings; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerSettings; +import com.mongodb.internal.connection.Cluster; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.Server; +import com.mongodb.internal.connection.ServerTuple; +import com.mongodb.session.ServerSession; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentMatcher; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.ArrayList; +import java.util.List; + +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; +import static 
com.mongodb.ClusterFixture.TIMEOUT_SETTINGS; +import static com.mongodb.ClusterFixture.getServerApi; +import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE; +import static com.mongodb.connection.ClusterType.REPLICA_SET; +import static com.mongodb.connection.ServerConnectionState.CONNECTED; +import static com.mongodb.connection.ServerConnectionState.CONNECTING; +import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY; +import static com.mongodb.connection.ServerType.UNKNOWN; +import static java.util.Collections.singletonList; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@DisplayName("ServerSessionPool") +@ExtendWith(MockitoExtension.class) +class ServerSessionPoolTest { + + private ClusterDescription connectedDescription; + private ClusterDescription unconnectedDescription; + + @Mock + private Cluster clusterMock; + + @BeforeEach + void setUp() { + connectedDescription = new ClusterDescription( + MULTIPLE, + REPLICA_SET, + singletonList( + ServerDescription.builder() + .ok(true) + .state(CONNECTED) + .address(new ServerAddress()) + .type(REPLICA_SET_PRIMARY) + .logicalSessionTimeoutMinutes(30) + .build() + ), + ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(), + ServerSettings.builder().build() + ); + + unconnectedDescription = new ClusterDescription( + MULTIPLE, + 
REPLICA_SET, + singletonList( + ServerDescription.builder() + .ok(true) + .state(CONNECTING) + .address(new ServerAddress()) + .type(UNKNOWN) + .logicalSessionTimeoutMinutes(null) + .build() + ), + ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(), + ServerSettings.builder().build() + ); + } + + @Test + @DisplayName("should get session from pool") + void testGetSession() { + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + + ServerSession session = pool.get(); + + assertNotNull(session); + } + + @Test + @DisplayName("should throw IllegalStateException when pool is closed") + void testThrowExceptionIfPoolClosed() { + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + pool.close(); + + assertThrows(IllegalStateException.class, pool::get); + } + + @Test + @DisplayName("should reuse released session from pool") + void testPoolSession() { + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + + ServerSession session = pool.get(); + pool.release(session); + ServerSession pooledSession = pool.get(); + + assertEquals(session, pooledSession); + } + + @Test + @DisplayName("should prune expired sessions when getting new session") + void testPruneSessionsWhenGetting() { + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(0L, MINUTES.toMillis(29) + 1); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession sessionOne = pool.get(); + + pool.release(sessionOne); + assertFalse(sessionOne.isClosed()); + + ServerSession sessionTwo = pool.get(); + + assertNotEquals(sessionTwo, sessionOne); + assertTrue(sessionOne.isClosed()); + } + + @Test + @DisplayName("should 
not prune session when timeout is null") + void testNotPruneSessionWhenTimeoutIsNull() { + when(clusterMock.getCurrentDescription()).thenReturn(unconnectedDescription); + + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(0L, 0L, 0L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + pool.release(session); + ServerSession newSession = pool.get(); + + assertEquals(session, newSession); + } + + @Test + @DisplayName("should initialize session with correct properties") + void testInitializeSession() { + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(42L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session; + assertEquals(42L, sessionImpl.getLastUsedAtMillis()); + assertEquals(0L, sessionImpl.getTransactionNumber()); + + BsonDocument identifier = sessionImpl.getIdentifier(); + assertNotNull(identifier); + byte[] uuid = identifier.getBinary("id").getData(); + assertNotNull(uuid); + assertEquals(16, uuid.length); + } + + @Test + @DisplayName("should advance transaction number") + void testAdvanceTransactionNumber() { + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(42L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session; + assertEquals(0L, sessionImpl.getTransactionNumber()); + assertEquals(1L, sessionImpl.advanceTransactionNumber()); + assertEquals(1L, sessionImpl.getTransactionNumber()); + } + + @Test + @DisplayName("should end pooled sessions when pool is closed") + void 
testEndPooledSessionsWhenPoolClosed() { + Connection connection = mock(Connection.class); + Server server = mock(Server.class); + when(server.getConnection(any())).thenReturn(connection); + + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + when(clusterMock.selectServer(any(), any())) + .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0))); + + when(connection.command( + any(String.class), + any(BsonDocument.class), + any(), + any(), + any(), + any() + )).thenReturn(new BsonDocument()); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + List sessions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + sessions.add(pool.get()); + } + + for (ServerSession session : sessions) { + pool.release(session); + } + + pool.close(); + + verify(clusterMock, times(1)).selectServer(any(), any()); + verify(connection, times(1)).command( + any(String.class), + argThat(endSessionsDocMatcher(sessions)), + any(), + any(), + any(), + any() + ); + verify(connection, times(1)).release(); + } + + @Test + @DisplayName("should handle MongoException during endSessions without leaking resources") + void testHandleMongoExceptionDuringEndSessionsWithoutLeakingResources() { + Connection connection = mock(Connection.class); + Server server = mock(Server.class); + when(server.getConnection(any())).thenReturn(connection); + + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + when(clusterMock.selectServer(any(), any())) + .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0))); + + when(connection.command( + any(String.class), + any(BsonDocument.class), + any(), + any(), + any(), + any() + )).thenThrow(new MongoException("Simulated error")); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + List sessions = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + 
sessions.add(pool.get()); + } + + for (ServerSession session : sessions) { + pool.release(session); + } + + // Should not throw - exception is handled internally + pool.close(); + + verify(clusterMock, times(1)).selectServer(any(), any()); + verify(connection, times(1)).release(); + } + + /** + * Matcher to verify the endSessions document contains the correct session identifiers. + */ + private ArgumentMatcher endSessionsDocMatcher(List sessions) { + return doc -> { + if (!doc.containsKey("endSessions")) { + return false; + } + BsonArray endSessionsArray = doc.getArray("endSessions"); + if (endSessionsArray.size() != sessions.size()) { + return false; + } + for (int i = 0; i < sessions.size(); i++) { + ServerSession session = sessions.get(i); + BsonDocument sessionIdentifier = session.getIdentifier(); + BsonDocument arrayElement = endSessionsArray.get(i).asDocument(); + if (!sessionIdentifier.equals(arrayElement)) { + return false; + } + } + return true; + }; + } +} diff --git a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java index dd761234df9..75a60ca382f 100644 --- a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java +++ b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java @@ -39,9 +39,9 @@ class DBDecoderAdapter implements Decoder { @Override public DBObject decode(final BsonReader reader, final DecoderContext decoderContext) { - ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider); - BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput); - try { + + try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider); + BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput)) { binaryWriter.pipe(reader); BufferExposingByteArrayOutputStream byteArrayOutputStream = new BufferExposingByteArrayOutputStream(binaryWriter.getBsonOutput().getSize()); @@ -50,9 +50,6 @@ public DBObject decode(final BsonReader reader, final DecoderContext 
decoderCont } catch (IOException e) { // impossible with a byte array output stream throw new MongoInternalException("An unlikely IOException thrown.", e); - } finally { - binaryWriter.close(); - bsonOutput.close(); } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java index 511f9f62c6b..6d38a0731ab 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java @@ -39,6 +39,8 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.util.concurrent.atomic.AtomicBoolean; + import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL; import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL; import static com.mongodb.assertions.Assertions.assertNotNull; @@ -48,6 +50,7 @@ final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements ClientSession { + private final AtomicBoolean closed = new AtomicBoolean(); private final MongoClientImpl mongoClient; private final OperationExecutor executor; private final TracingManager tracingManager; @@ -58,7 +61,6 @@ final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements @Nullable private TransactionSpan transactionSpan; - ClientSessionPublisherImpl(final ServerSessionPool serverSessionPool, final MongoClientImpl mongoClient, final ClientSessionOptions options, final OperationExecutor executor, final TracingManager tracingManager) { super(serverSessionPool, mongoClient, options); @@ -256,10 +258,18 @@ public TransactionSpan getTransactionSpan() { @Override public void close() { - if (transactionState == TransactionState.IN) { - Mono.from(abortTransaction()).doFinally(it 
-> super.close()).subscribe(); - } else { - super.close(); + if (closed.compareAndSet(false, true)) { + if (transactionState == TransactionState.IN) { + Mono.from(abortTransaction()) + .doFinally(it -> { + clearTransactionContext(); + super.close(); + }) + .subscribe(); + } else { + clearTransactionContext(); + super.close(); + } } } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java index 2881b47e38e..05ca89dd048 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java @@ -24,6 +24,7 @@ import com.mongodb.MongoTimeoutException; import com.mongodb.connection.ClusterType; import com.mongodb.connection.ServerVersion; +import com.mongodb.connection.TransportSettings; import com.mongodb.reactivestreams.client.internal.MongoClientImpl; import org.bson.Document; import org.bson.conversions.Bson; @@ -33,6 +34,7 @@ import java.util.List; import static com.mongodb.ClusterFixture.TIMEOUT_DURATION; +import static com.mongodb.ClusterFixture.getOverriddenTransportSettings; import static com.mongodb.ClusterFixture.getServerApi; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; import static java.lang.Thread.sleep; @@ -67,11 +69,18 @@ public static MongoClientSettings.Builder getMongoClientSettingsBuilder() { } public static MongoClientSettings.Builder getMongoClientSettingsBuilder(final ConnectionString connectionString) { - MongoClientSettings.Builder builder = MongoClientSettings.builder(); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applyConnectionString(connectionString); + + TransportSettings overriddenTransportSettings = getOverriddenTransportSettings(); + if (overriddenTransportSettings != null) { + 
builder.transportSettings(overriddenTransportSettings); + } + if (getServerApi() != null) { builder.serverApi(getServerApi()); } - return builder.applyConnectionString(connectionString); + return builder; } public static String getDefaultDatabaseName() { @@ -164,6 +173,11 @@ public static synchronized ConnectionString getConnectionString() { public static MongoClientSettings.Builder getMongoClientBuilderFromConnectionString() { MongoClientSettings.Builder builder = MongoClientSettings.builder() .applyConnectionString(getConnectionString()); + + TransportSettings overriddenTransportSettings = getOverriddenTransportSettings(); + if (overriddenTransportSettings != null) { + builder.transportSettings(overriddenTransportSettings); + } if (getServerApi() != null) { builder.serverApi(getServerApi()); } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index 35189aef455..602838cff0c 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -311,6 +311,9 @@ public void cleanUp() { if (testDef != null) { postCleanUp(testDef); } + // Ask the JVM to run garbage collection. + // This should help with Netty's leak detection + System.gc(); } protected void postCleanUp(final TestDef testDef) {