connectionIdSupplier
) {
- if (!isEnabled()) {
+ if (!isEnabled()) {
return null;
}
- BsonDocument command = commandDocumentSupplier.get();
- String commandName = command.getFirstKey();
+
+ String commandName = commandDocument.getFirstKey();
if (isSensitiveCommand.test(commandName)) {
return null;
}
Span operationSpan = operationContext.getTracingSpan();
- Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null);
+ Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null);
- if (command.containsKey("getMore")) {
- long cursorId = command.getInt64("getMore").longValue();
+ if (commandDocument.containsKey("getMore")) {
+ long cursorId = commandDocument.getInt64("getMore").longValue();
span.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId)));
if (operationSpan != null) {
operationSpan.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId)));
@@ -266,4 +268,47 @@ public Span createTracingSpan(final CommandMessage message,
return span;
}
+
+ /**
+ * Creates an operation-level tracing span for a database command.
+ * <p>
+ * The span is named "{commandName} {database}[.{collection}]" and tagged with standard
+ * low-cardinality attributes (system, namespace, collection, operation name, operation summary).
+ * The span is also set on the {@link OperationContext} for use by downstream command-level tracing.
+ *
+ * @param transactionSpan the active transaction span used as the parent context, or {@code null} if there is no active transaction
+ * @param operationContext the operation context to attach the span to
+ * @param commandName the name of the command (e.g. "find", "insert")
+ * @param namespace the MongoDB namespace for the operation
+ * @return the created span, or null if tracing is disabled
+ */
+ @Nullable
+ public Span createOperationSpan(@Nullable final TransactionSpan transactionSpan,
+ final OperationContext operationContext, final String commandName, final MongoNamespace namespace) {
+ if (!isEnabled()) {
+ return null;
+ }
+ TraceContext parentContext = null;
+ if (transactionSpan != null) {
+ parentContext = transactionSpan.getContext();
+ }
+ String name = commandName + " " + namespace.getDatabaseName()
+ + (MongoNamespaceHelper.COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())
+ ? ""
+ : "." + namespace.getCollectionName());
+
+ KeyValues keyValues = KeyValues.of(
+ SYSTEM.withValue("mongodb"),
+ NAMESPACE.withValue(namespace.getDatabaseName()));
+ if (!MongoNamespaceHelper.COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())) {
+ keyValues = keyValues.and(COLLECTION.withValue(namespace.getCollectionName()));
+ }
+ keyValues = keyValues.and(OPERATION_NAME.withValue(commandName),
+ OPERATION_SUMMARY.withValue(name));
+
+ Span span = addSpan(name, parentContext, namespace);
+ span.tagLowCardinality(keyValues);
+ operationContext.setTracingSpan(span);
+ return span;
+ }
}
diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy
index 2d7dc04d758..f1585f82595 100644
--- a/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy
+++ b/driver-core/src/test/functional/com/mongodb/internal/connection/CommandHelperSpecification.groovy
@@ -52,6 +52,7 @@ class CommandHelperSpecification extends Specification {
}
def cleanup() {
+ InternalStreamConnection.setRecordEverything(false)
connection?.close()
}
@@ -81,5 +82,4 @@ class CommandHelperSpecification extends Specification {
!receivedDocument
receivedException instanceof MongoCommandException
}
-
}
diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java
index fc5926b3bad..81e778b4a61 100644
--- a/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java
+++ b/driver-core/src/test/functional/com/mongodb/internal/connection/DefaultConnectionPoolTest.java
@@ -127,7 +127,7 @@ public void shouldThrowOnTimeout() throws InterruptedException {
// when
TimeoutTrackingConnectionGetter connectionGetter = new TimeoutTrackingConnectionGetter(provider, timeoutSettings);
- cachedExecutor.submit(connectionGetter);
+ cachedExecutor.execute(connectionGetter);
connectionGetter.getLatch().await();
@@ -152,7 +152,7 @@ public void shouldNotUseMaxAwaitTimeMSWhenTimeoutMsIsSet() throws InterruptedExc
// when
TimeoutTrackingConnectionGetter connectionGetter = new TimeoutTrackingConnectionGetter(provider, timeoutSettings);
- cachedExecutor.submit(connectionGetter);
+ cachedExecutor.execute(connectionGetter);
sleep(70); // wait for more than maxWaitTimeMS but less than timeoutMs.
internalConnection.close();
diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy
deleted file mode 100644
index 0407baeca8a..00000000000
--- a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package com.mongodb.internal.connection
-
-import com.mongodb.MongoInternalException
-import org.bson.io.BasicOutputBuffer
-import spock.lang.Specification
-
-import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize
-
-class ReplyHeaderSpecification extends Specification {
-
- def 'should parse reply header'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(186)
- writeInt(45)
- writeInt(23)
- writeInt(1)
- writeInt(responseFlags)
- writeLong(9000)
- writeInt(4)
- writeInt(1)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
-
- when:
- def replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- then:
- replyHeader.messageLength == 186
- replyHeader.requestId == 45
- replyHeader.responseTo == 23
-
- where:
- responseFlags << [0, 1, 2, 3]
- cursorNotFound << [false, true, false, true]
- queryFailure << [false, false, true, true]
- }
-
- def 'should parse reply header with compressed header'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(186)
- writeInt(45)
- writeInt(23)
- writeInt(2012)
- writeInt(1)
- writeInt(258)
- writeByte(2)
- writeInt(responseFlags)
- writeLong(9000)
- writeInt(4)
- writeInt(1)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
- def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- when:
- def replyHeader = new ReplyHeader(byteBuf, compressedHeader)
-
- then:
- replyHeader.messageLength == 274
- replyHeader.requestId == 45
- replyHeader.responseTo == 23
-
- where:
- responseFlags << [0, 1, 2, 3]
- cursorNotFound << [false, true, false, true]
- queryFailure << [false, false, true, true]
- }
-
- def 'should throw MongoInternalException on incorrect opCode'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(36)
- writeInt(45)
- writeInt(23)
- writeInt(2)
- writeInt(0)
- writeLong(2)
- writeInt(0)
- writeInt(0)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
-
- when:
- new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- then:
- def ex = thrown(MongoInternalException)
- ex.getMessage() == 'Unexpected reply message opCode 2'
- }
-
- def 'should throw MongoInternalException on message size < 36'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(35)
- writeInt(45)
- writeInt(23)
- writeInt(1)
- writeInt(0)
- writeLong(2)
- writeInt(0)
- writeInt(0)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
-
- when:
- new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- then:
- def ex = thrown(MongoInternalException)
- ex.getMessage() == 'The reply message length 35 is less than the minimum message length 36'
- }
-
- def 'should throw MongoInternalException on message size > max message size'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(400)
- writeInt(45)
- writeInt(23)
- writeInt(1)
- writeInt(0)
- writeLong(2)
- writeInt(0)
- writeInt(0)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
-
- when:
- new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399))
-
- then:
- def ex = thrown(MongoInternalException)
- ex.getMessage() == 'The reply message length 400 is greater than the maximum message length 399'
- }
-
- def 'should throw MongoInternalException on num documents < 0'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(186)
- writeInt(45)
- writeInt(23)
- writeInt(1)
- writeInt(1)
- writeLong(9000)
- writeInt(4)
- writeInt(-1)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
-
- when:
- new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- then:
- def ex = thrown(MongoInternalException)
- ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1'
- }
-
- def 'should throw MongoInternalException on num documents < 0 with compressed header'() {
- def outputBuffer = new BasicOutputBuffer()
- outputBuffer.with {
- writeInt(186)
- writeInt(45)
- writeInt(23)
- writeInt(2012)
- writeInt(1)
- writeInt(258)
- writeByte(2)
- writeInt(1)
- writeLong(9000)
- writeInt(4)
- writeInt(-1)
- }
- def byteBuf = outputBuffer.byteBuffers.get(0)
- def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))
-
- when:
- new ReplyHeader(byteBuf, compressedHeader)
-
- then:
- def ex = thrown(MongoInternalException)
- ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1'
- }
-}
diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java
new file mode 100644
index 00000000000..38bc96731c2
--- /dev/null
+++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java
@@ -0,0 +1,213 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.connection;
+
+import com.mongodb.MongoInternalException;
+import org.bson.ByteBuf;
+import org.bson.io.BasicOutputBuffer;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import java.util.List;
+
+import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+@DisplayName("ReplyHeader")
+class ReplyHeaderTest {
+
+ @ParameterizedTest(name = "with responseFlags {0}")
+ @ValueSource(ints = {0, 1, 2, 3})
+ @DisplayName("should parse reply header with various response flags")
+ void testParseReplyHeader(final int responseFlags) {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(186);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(responseFlags);
+ outputBuffer.writeLong(9000);
+ outputBuffer.writeInt(4);
+ outputBuffer.writeInt(1);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ ReplyHeader replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()));
+
+ assertEquals(186, replyHeader.getMessageLength());
+ assertEquals(45, replyHeader.getRequestId());
+ assertEquals(23, replyHeader.getResponseTo());
+
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @ParameterizedTest(name = "with responseFlags {0}")
+ @ValueSource(ints = {0, 1, 2, 3})
+ @DisplayName("should parse reply header with compressed header and various response flags")
+ void testParseReplyHeaderWithCompressedHeader(final int responseFlags) {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(186);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(2012);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(258);
+ outputBuffer.writeByte(2);
+ outputBuffer.writeInt(responseFlags);
+ outputBuffer.writeLong(9000);
+ outputBuffer.writeInt(4);
+ outputBuffer.writeInt(1);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ CompressedHeader compressedHeader = new CompressedHeader(byteBuf,
+ new MessageHeader(byteBuf, getDefaultMaxMessageSize()));
+ ReplyHeader replyHeader = new ReplyHeader(byteBuf, compressedHeader);
+
+ assertEquals(274, replyHeader.getMessageLength());
+ assertEquals(45, replyHeader.getRequestId());
+ assertEquals(23, replyHeader.getResponseTo());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @Test
+ @DisplayName("should throw MongoInternalException on incorrect opCode")
+ void testThrowExceptionOnIncorrectOpCode() {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(36);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(2);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeLong(2);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeInt(0);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+
+ MongoInternalException ex = assertThrows(MongoInternalException.class,
+ () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())));
+
+ assertEquals("Unexpected reply message opCode 2", ex.getMessage());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @Test
+ @DisplayName("should throw MongoInternalException on message size less than 36 bytes")
+ void testThrowExceptionOnMessageSizeLessThan36() {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(35);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeLong(2);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeInt(0);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ MongoInternalException ex = assertThrows(MongoInternalException.class,
+ () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())));
+
+ assertEquals("The reply message length 35 is less than the minimum message length 36", ex.getMessage());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @Test
+ @DisplayName("should throw MongoInternalException on message size exceeding max message size")
+ void testThrowExceptionOnMessageSizeExceedingMax() {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(400);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeLong(2);
+ outputBuffer.writeInt(0);
+ outputBuffer.writeInt(0);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ MongoInternalException ex = assertThrows(MongoInternalException.class,
+ () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399)));
+
+ assertEquals("The reply message length 400 is greater than the maximum message length 399", ex.getMessage());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @Test
+ @DisplayName("should throw MongoInternalException on negative number of returned documents")
+ void testThrowExceptionOnNegativeNumberOfDocuments() {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(186);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeLong(9000);
+ outputBuffer.writeInt(4);
+ outputBuffer.writeInt(-1);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ MongoInternalException ex = assertThrows(MongoInternalException.class,
+ () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())));
+
+ assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+
+ @Test
+ @DisplayName("should throw MongoInternalException on negative number of documents with compressed header")
+ void testThrowExceptionOnNegativeNumberOfDocumentsWithCompressedHeader() {
+ try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) {
+ outputBuffer.writeInt(186);
+ outputBuffer.writeInt(45);
+ outputBuffer.writeInt(23);
+ outputBuffer.writeInt(2012);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeInt(258);
+ outputBuffer.writeByte(2);
+ outputBuffer.writeInt(1);
+ outputBuffer.writeLong(9000);
+ outputBuffer.writeInt(4);
+ outputBuffer.writeInt(-1);
+
+ List byteBuffers = outputBuffer.getByteBuffers();
+ ByteBuf byteBuf = byteBuffers.get(0);
+ CompressedHeader compressedHeader = new CompressedHeader(byteBuf,
+ new MessageHeader(byteBuf, getDefaultMaxMessageSize()));
+
+ MongoInternalException ex = assertThrows(MongoInternalException.class,
+ () -> new ReplyHeader(byteBuf, compressedHeader));
+
+ assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage());
+ byteBuffers.forEach(ByteBuf::release);
+ }
+ }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java
index be4526aada7..5f736f421c2 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java
@@ -331,9 +331,10 @@ static Stream shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS()
);
}
- @ParameterizedTest
- @MethodSource
@DisplayName("should choose timeoutMS when timeoutMS is less than connectTimeoutMS")
+ @ParameterizedTest(name = "should choose timeoutMS when timeoutMS is less than connectTimeoutMS. "
+ + "Parameters: connectTimeoutMS: {0}, timeoutMS: {1}, expected: {2}")
+ @MethodSource
void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTimeoutMS,
final Long timeoutMS,
final long expected) {
@@ -345,7 +346,7 @@ void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTim
0));
long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs();
- assertTrue(expected - calculatedTimeoutMS <= 1);
+ assertTrue(expected - calculatedTimeoutMS <= 2);
}
private TimeoutContextTest() {
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java
index f7cefbf57c0..4b05607a56f 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonArrayTest.java
@@ -46,10 +46,9 @@
import org.bson.io.BasicOutputBuffer;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;
+import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.Iterator;
@@ -63,145 +62,242 @@
import static org.bson.BsonBoolean.FALSE;
import static org.bson.BsonBoolean.TRUE;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+@DisplayName("ByteBufBsonArray")
class ByteBufBsonArrayTest {
+ // Basic Operations
+
@Test
+ @DisplayName("getValues() returns array values")
void testGetValues() {
List values = asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2));
- ByteBufBsonArray bsonArray = fromBsonValues(values);
- assertEquals(values, bsonArray.getValues());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(values)) {
+ assertEquals(values, bsonArray.getValues());
+ }
}
@Test
+ @DisplayName("size() returns correct count")
void testSize() {
- assertEquals(0, fromBsonValues(emptyList()).size());
- assertEquals(1, fromBsonValues(singletonList(TRUE)).size());
- assertEquals(2, fromBsonValues(asList(TRUE, TRUE)).size());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) {
+ assertEquals(0, bsonArray.size());
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) {
+ assertEquals(1, bsonArray.size());
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) {
+ assertEquals(2, bsonArray.size());
+ }
}
@Test
+ @DisplayName("isEmpty() returns correct result")
void testIsEmpty() {
- assertTrue(fromBsonValues(emptyList()).isEmpty());
- assertFalse(fromBsonValues(singletonList(TRUE)).isEmpty());
- assertFalse(fromBsonValues(asList(TRUE, TRUE)).isEmpty());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) {
+ assertTrue(bsonArray.isEmpty());
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) {
+ assertFalse(bsonArray.isEmpty());
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) {
+ assertFalse(bsonArray.isEmpty());
+ }
}
@Test
+ @DisplayName("contains() finds existing values and rejects missing values")
void testContains() {
- assertFalse(fromBsonValues(emptyList()).contains(TRUE));
- assertTrue(fromBsonValues(singletonList(TRUE)).contains(TRUE));
- assertTrue(fromBsonValues(asList(FALSE, TRUE)).contains(TRUE));
- assertFalse(fromBsonValues(singletonList(FALSE)).contains(TRUE));
- assertFalse(fromBsonValues(asList(FALSE, FALSE)).contains(TRUE));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) {
+ assertFalse(bsonArray.contains(TRUE));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) {
+ assertTrue(bsonArray.contains(TRUE));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, TRUE))) {
+ assertTrue(bsonArray.contains(TRUE));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(FALSE))) {
+ assertFalse(bsonArray.contains(TRUE));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, FALSE))) {
+ assertFalse(bsonArray.contains(TRUE));
+ }
}
@Test
+ @DisplayName("iterator() navigates through all elements")
void testIterator() {
- Iterator iterator = fromBsonValues(emptyList()).iterator();
- assertFalse(iterator.hasNext());
- assertThrows(NoSuchElementException.class, iterator::next);
-
- iterator = fromBsonValues(singletonList(TRUE)).iterator();
- assertTrue(iterator.hasNext());
- assertEquals(TRUE, iterator.next());
- assertFalse(iterator.hasNext());
- assertThrows(NoSuchElementException.class, iterator::next);
-
- iterator = fromBsonValues(asList(TRUE, FALSE)).iterator();
- assertTrue(iterator.hasNext());
- assertEquals(TRUE, iterator.next());
- assertTrue(iterator.hasNext());
- assertEquals(FALSE, iterator.next());
- assertFalse(iterator.hasNext());
- assertThrows(NoSuchElementException.class, iterator::next);
+ try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) {
+ Iterator iterator = bsonArray.iterator();
+ assertFalse(iterator.hasNext());
+ assertThrows(NoSuchElementException.class, iterator::next);
+ }
+
+ try (ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE))) {
+ Iterator iterator = bsonArray.iterator();
+ assertTrue(iterator.hasNext());
+ assertEquals(TRUE, iterator.next());
+ assertFalse(iterator.hasNext());
+ assertThrows(NoSuchElementException.class, iterator::next);
+ }
+
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ Iterator iterator = bsonArray.iterator();
+ assertTrue(iterator.hasNext());
+ assertEquals(TRUE, iterator.next());
+ assertTrue(iterator.hasNext());
+ assertEquals(FALSE, iterator.next());
+ assertFalse(iterator.hasNext());
+ assertThrows(NoSuchElementException.class, iterator::next);
+ }
+ }
+
+ @Test
+ @DisplayName("Iterators ensure the resource is still open")
+ void iteratorsEnsureResourceIsStillOpen() {
+ ByteBufBsonArray bsonArray = fromBsonValues(singletonList(TRUE));
+ Iterator arrayIterator = bsonArray.iterator();
+
+ assertDoesNotThrow(arrayIterator::hasNext);
+
+ bsonArray.close();
+ assertThrows(IllegalStateException.class, arrayIterator::hasNext);
}
@Test
+ @DisplayName("toArray() converts array to Object array")
void testToArray() {
- assertArrayEquals(new BsonValue[]{TRUE, FALSE}, fromBsonValues(asList(TRUE, FALSE)).toArray());
- assertArrayEquals(new BsonValue[]{TRUE, FALSE}, fromBsonValues(asList(TRUE, FALSE)).toArray(new BsonValue[0]));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertArrayEquals(new BsonValue[]{TRUE, FALSE}, bsonArray.toArray());
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertArrayEquals(new BsonValue[]{TRUE, FALSE}, bsonArray.toArray(new BsonValue[0]));
+ }
}
@Test
+ @DisplayName("containsAll() checks if all elements are present")
void testContainsAll() {
- assertTrue(fromBsonValues(asList(TRUE, FALSE)).containsAll(asList(TRUE, FALSE)));
- assertFalse(fromBsonValues(asList(TRUE, TRUE)).containsAll(asList(TRUE, FALSE)));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertTrue(bsonArray.containsAll(asList(TRUE, FALSE)));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, TRUE))) {
+ assertFalse(bsonArray.containsAll(asList(TRUE, FALSE)));
+ }
}
@Test
+ @DisplayName("get() retrieves element at index and throws for out of bounds")
void testGet() {
- ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE));
- assertEquals(TRUE, bsonArray.get(0));
- assertEquals(FALSE, bsonArray.get(1));
- assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(-1));
- assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(2));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(TRUE, bsonArray.get(0));
+ assertEquals(FALSE, bsonArray.get(1));
+ assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(-1));
+ assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.get(2));
+ }
}
@Test
+ @DisplayName("indexOf() finds element position or returns -1")
void testIndexOf() {
- ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE));
- assertEquals(0, bsonArray.indexOf(TRUE));
- assertEquals(1, bsonArray.indexOf(FALSE));
- assertEquals(-1, bsonArray.indexOf(BsonNull.VALUE));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(0, bsonArray.indexOf(TRUE));
+ assertEquals(1, bsonArray.indexOf(FALSE));
+ assertEquals(-1, bsonArray.indexOf(BsonNull.VALUE));
+ }
}
@Test
+ @DisplayName("lastIndexOf() finds last element position or returns -1")
void testLastIndexOf() {
- ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE, FALSE));
- assertEquals(2, bsonArray.lastIndexOf(TRUE));
- assertEquals(3, bsonArray.lastIndexOf(FALSE));
- assertEquals(-1, bsonArray.lastIndexOf(BsonNull.VALUE));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE, FALSE))) {
+ assertEquals(2, bsonArray.lastIndexOf(TRUE));
+ assertEquals(3, bsonArray.lastIndexOf(FALSE));
+ assertEquals(-1, bsonArray.lastIndexOf(BsonNull.VALUE));
+ }
}
@Test
+ @DisplayName("listIterator() supports bidirectional iteration")
void testListIterator() {
// implementation is delegated to ArrayList, so not much testing is needed
- ListIterator iterator = fromBsonValues(emptyList()).listIterator();
- assertFalse(iterator.hasNext());
- assertFalse(iterator.hasPrevious());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(emptyList())) {
+ ListIterator iterator = bsonArray.listIterator();
+ assertFalse(iterator.hasNext());
+ assertFalse(iterator.hasPrevious());
+ }
}
@Test
+ @DisplayName("subList() returns subset of array elements")
void testSubList() {
- ByteBufBsonArray bsonArray = fromBsonValues(asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2)));
- assertEquals(emptyList(), bsonArray.subList(0, 0));
- assertEquals(singletonList(new BsonInt32(0)), bsonArray.subList(0, 1));
- assertEquals(singletonList(new BsonInt32(2)), bsonArray.subList(2, 3));
- assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(-1, 1));
- assertThrows(IllegalArgumentException.class, () -> bsonArray.subList(3, 2));
- assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(2, 4));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(new BsonInt32(0), new BsonInt32(1), new BsonInt32(2)))) {
+ assertEquals(emptyList(), bsonArray.subList(0, 0));
+ assertEquals(singletonList(new BsonInt32(0)), bsonArray.subList(0, 1));
+ assertEquals(singletonList(new BsonInt32(2)), bsonArray.subList(2, 3));
+ assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(-1, 1));
+ assertThrows(IllegalArgumentException.class, () -> bsonArray.subList(3, 2));
+ assertThrows(IndexOutOfBoundsException.class, () -> bsonArray.subList(2, 4));
+ }
}
+ // Equality and HashCode
+
@Test
+ @DisplayName("equals() and hashCode() work correctly")
void testEquals() {
- assertEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(TRUE, FALSE)));
- assertEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(TRUE, FALSE)));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray);
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(bsonArray, new BsonArray(asList(TRUE, FALSE)));
+ }
- assertNotEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(FALSE, TRUE)));
- assertNotEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(FALSE, TRUE)));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(FALSE, TRUE))) {
+ assertNotEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray);
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertNotEquals(bsonArray, new BsonArray(asList(FALSE, TRUE)));
+ }
- assertNotEquals(new BsonArray(asList(TRUE, FALSE)), fromBsonValues(asList(TRUE, FALSE, TRUE)));
- assertNotEquals(fromBsonValues(asList(TRUE, FALSE)), new BsonArray(asList(TRUE, FALSE, TRUE)));
- assertNotEquals(fromBsonValues(asList(TRUE, FALSE, TRUE)), new BsonArray(asList(TRUE, FALSE)));
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE))) {
+ assertNotEquals(new BsonArray(asList(TRUE, FALSE)), bsonArray);
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertNotEquals(bsonArray, new BsonArray(asList(TRUE, FALSE, TRUE)));
+ }
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE, TRUE))) {
+ assertNotEquals(bsonArray, new BsonArray(asList(TRUE, FALSE)));
+ }
}
@Test
+ @DisplayName("hashCode() is consistent with equals()")
void testHashCode() {
- assertEquals(new BsonArray(asList(TRUE, FALSE)).hashCode(), fromBsonValues(asList(TRUE, FALSE)).hashCode());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(new BsonArray(asList(TRUE, FALSE)).hashCode(), bsonArray.hashCode());
+ }
}
@Test
+ @DisplayName("toString() returns equivalent string")
void testToString() {
- assertEquals(new BsonArray(asList(TRUE, FALSE)).toString(), fromBsonValues(asList(TRUE, FALSE)).toString());
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(TRUE, FALSE))) {
+ assertEquals(new BsonArray(asList(TRUE, FALSE)).toString(), bsonArray.toString());
+ }
}
+ // Type Support
+
@Test
+ @DisplayName("All BSON types are supported")
void testAllBsonTypes() {
BsonValue bsonNull = new BsonNull();
BsonValue bsonInt32 = new BsonInt32(42);
@@ -225,44 +321,47 @@ void testAllBsonTypes() {
BsonValue document = new BsonDocument("a", new BsonInt32(1));
BsonValue dbPointer = new BsonDbPointer("db.coll", new ObjectId());
- ByteBufBsonArray bsonArray = fromBsonValues(asList(
+ try (ByteBufBsonArray bsonArray = fromBsonValues(asList(
bsonNull, bsonInt32, bsonInt64, bsonDecimal128, bsonBoolean, bsonDateTime, bsonDouble, bsonString, minKey, maxKey,
- javaScript, objectId, scope, regularExpression, symbol, timestamp, undefined, binary, array, document, dbPointer));
- assertEquals(bsonNull, bsonArray.get(0));
- assertEquals(bsonInt32, bsonArray.get(1));
- assertEquals(bsonInt64, bsonArray.get(2));
- assertEquals(bsonDecimal128, bsonArray.get(3));
- assertEquals(bsonBoolean, bsonArray.get(4));
- assertEquals(bsonDateTime, bsonArray.get(5));
- assertEquals(bsonDouble, bsonArray.get(6));
- assertEquals(bsonString, bsonArray.get(7));
- assertEquals(minKey, bsonArray.get(8));
- assertEquals(maxKey, bsonArray.get(9));
- assertEquals(javaScript, bsonArray.get(10));
- assertEquals(objectId, bsonArray.get(11));
- assertEquals(scope, bsonArray.get(12));
- assertEquals(regularExpression, bsonArray.get(13));
- assertEquals(symbol, bsonArray.get(14));
- assertEquals(timestamp, bsonArray.get(15));
- assertEquals(undefined, bsonArray.get(16));
- assertEquals(binary, bsonArray.get(17));
- assertEquals(array, bsonArray.get(18));
- assertEquals(document, bsonArray.get(19));
- assertEquals(dbPointer, bsonArray.get(20));
+ javaScript, objectId, scope, regularExpression, symbol, timestamp, undefined, binary, array, document, dbPointer))) {
+ assertEquals(bsonNull, bsonArray.get(0));
+ assertEquals(bsonInt32, bsonArray.get(1));
+ assertEquals(bsonInt64, bsonArray.get(2));
+ assertEquals(bsonDecimal128, bsonArray.get(3));
+ assertEquals(bsonBoolean, bsonArray.get(4));
+ assertEquals(bsonDateTime, bsonArray.get(5));
+ assertEquals(bsonDouble, bsonArray.get(6));
+ assertEquals(bsonString, bsonArray.get(7));
+ assertEquals(minKey, bsonArray.get(8));
+ assertEquals(maxKey, bsonArray.get(9));
+ assertEquals(javaScript, bsonArray.get(10));
+ assertEquals(objectId, bsonArray.get(11));
+ assertEquals(scope, bsonArray.get(12));
+ assertEquals(regularExpression, bsonArray.get(13));
+ assertEquals(symbol, bsonArray.get(14));
+ assertEquals(timestamp, bsonArray.get(15));
+ assertEquals(undefined, bsonArray.get(16));
+ assertEquals(binary, bsonArray.get(17));
+ assertEquals(array, bsonArray.get(18));
+ assertEquals(document, bsonArray.get(19));
+ assertEquals(dbPointer, bsonArray.get(20));
+ }
}
static ByteBufBsonArray fromBsonValues(final List<? extends BsonValue> values) {
- BsonDocument document = new BsonDocument()
- .append("a", new BsonArray(values));
+ BsonDocument document = new BsonDocument("a", new BsonArray(values));
BasicOutputBuffer buffer = new BasicOutputBuffer();
new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build());
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- try {
- buffer.pipe(baos);
- } catch (IOException e) {
- throw new RuntimeException("impossible!");
- }
- ByteBuf documentByteBuf = new ByteBufNIO(ByteBuffer.wrap(baos.toByteArray()));
- return (ByteBufBsonArray) new ByteBufBsonDocument(documentByteBuf).entrySet().iterator().next().getValue();
+ byte[] bytes = new byte[buffer.getPosition()];
+ System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, bytes.length);
+ // Skip past the outer document header to the array value bytes.
+ // Document format: [4-byte size][type byte (0x04)][field name "a\0"][array bytes...][0x00]
+ int arrayOffset = 4 + 1 + 2; // doc size + type byte + "a" + null terminator
+ int arraySize = (bytes[arrayOffset] & 0xFF)
+ | ((bytes[arrayOffset + 1] & 0xFF) << 8)
+ | ((bytes[arrayOffset + 2] & 0xFF) << 16)
+ | ((bytes[arrayOffset + 3] & 0xFF) << 24);
+ ByteBuf arrayByteBuf = new ByteBufNIO(ByteBuffer.wrap(bytes, arrayOffset, arraySize));
+ return new ByteBufBsonArray(arrayByteBuf);
}
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy
deleted file mode 100644
index 8dc599706a9..00000000000
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy
+++ /dev/null
@@ -1,313 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.internal.connection
-
-import org.bson.BsonArray
-import org.bson.BsonBinaryWriter
-import org.bson.BsonBoolean
-import org.bson.BsonDocument
-import org.bson.BsonInt32
-import org.bson.BsonNull
-import org.bson.BsonValue
-import org.bson.ByteBuf
-import org.bson.ByteBufNIO
-import org.bson.codecs.BsonDocumentCodec
-import org.bson.codecs.DecoderContext
-import org.bson.codecs.EncoderContext
-import org.bson.io.BasicOutputBuffer
-import org.bson.json.JsonMode
-import org.bson.json.JsonWriterSettings
-import spock.lang.Specification
-
-import java.nio.ByteBuffer
-
-import static java.util.Arrays.asList
-
-class ByteBufBsonDocumentSpecification extends Specification {
- def emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap([5, 0, 0, 0, 0] as byte[]))
- ByteBuf documentByteBuf
- ByteBufBsonDocument emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf)
- def document = new BsonDocument()
- .append('a', new BsonInt32(1))
- .append('b', new BsonInt32(2))
- .append('c', new BsonDocument('x', BsonBoolean.TRUE))
- .append('d', new BsonArray(asList(new BsonDocument('y', BsonBoolean.FALSE), new BsonInt32(1))))
-
- ByteBufBsonDocument byteBufDocument
-
- def setup() {
- def buffer = new BasicOutputBuffer()
- new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build())
- ByteArrayOutputStream baos = new ByteArrayOutputStream()
- buffer.pipe(baos)
- documentByteBuf = new ByteBufNIO(ByteBuffer.wrap(baos.toByteArray()))
- byteBufDocument = new ByteBufBsonDocument(documentByteBuf)
- }
-
- def 'get should get the value of the given key'() {
- expect:
- emptyByteBufDocument.get('a') == null
- byteBufDocument.get('z') == null
- byteBufDocument.get('a') == new BsonInt32(1)
- byteBufDocument.get('b') == new BsonInt32(2)
- }
-
- def 'get should throw if the key is null'() {
- when:
- byteBufDocument.get(null)
-
- then:
- thrown(IllegalArgumentException)
- documentByteBuf.referenceCount == 1
- }
-
- def 'containKey should throw if the key name is null'() {
- when:
- byteBufDocument.containsKey(null)
-
- then:
- thrown(IllegalArgumentException)
- documentByteBuf.referenceCount == 1
- }
-
- def 'containsKey should find an existing key'() {
- expect:
- byteBufDocument.containsKey('a')
- byteBufDocument.containsKey('b')
- byteBufDocument.containsKey('c')
- byteBufDocument.containsKey('d')
- documentByteBuf.referenceCount == 1
- }
-
- def 'containsKey should not find a non-existing key'() {
- expect:
- !byteBufDocument.containsKey('e')
- !byteBufDocument.containsKey('x')
- !byteBufDocument.containsKey('y')
- documentByteBuf.referenceCount == 1
- }
-
- def 'containValue should find an existing value'() {
- expect:
- byteBufDocument.containsValue(document.get('a'))
- byteBufDocument.containsValue(document.get('b'))
- byteBufDocument.containsValue(document.get('c'))
- byteBufDocument.containsValue(document.get('d'))
- documentByteBuf.referenceCount == 1
- }
-
- def 'containValue should not find a non-existing value'() {
- expect:
- !byteBufDocument.containsValue(new BsonInt32(3))
- !byteBufDocument.containsValue(new BsonDocument('e', BsonBoolean.FALSE))
- !byteBufDocument.containsValue(new BsonArray(asList(new BsonInt32(2), new BsonInt32(4))))
- documentByteBuf.referenceCount == 1
- }
-
- def 'isEmpty should return false when the document is not empty'() {
- expect:
- !byteBufDocument.isEmpty()
- documentByteBuf.referenceCount == 1
- }
-
- def 'isEmpty should return true when the document is empty'() {
- expect:
- emptyByteBufDocument.isEmpty()
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'should get correct size'() {
- expect:
- emptyByteBufDocument.size() == 0
- byteBufDocument.size() == 4
- documentByteBuf.referenceCount == 1
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'should get correct key set'() {
- expect:
- emptyByteBufDocument.keySet().isEmpty()
- byteBufDocument.keySet() == ['a', 'b', 'c', 'd'] as Set
- documentByteBuf.referenceCount == 1
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'should get correct values set'() {
- expect:
- emptyByteBufDocument.values().isEmpty()
- byteBufDocument.values() as Set == [document.get('a'), document.get('b'), document.get('c'), document.get('d')] as Set
- documentByteBuf.referenceCount == 1
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'should get correct entry set'() {
- expect:
- emptyByteBufDocument.entrySet().isEmpty()
- byteBufDocument.entrySet() == [new TestEntry('a', document.get('a')),
- new TestEntry('b', document.get('b')),
- new TestEntry('c', document.get('c')),
- new TestEntry('d', document.get('d'))] as Set
- documentByteBuf.referenceCount == 1
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'all write methods should throw UnsupportedOperationException'() {
- when:
- byteBufDocument.clear()
-
- then:
- thrown(UnsupportedOperationException)
-
- when:
- byteBufDocument.put('x', BsonNull.VALUE)
-
- then:
- thrown(UnsupportedOperationException)
-
- when:
- byteBufDocument.append('x', BsonNull.VALUE)
-
- then:
- thrown(UnsupportedOperationException)
-
- when:
- byteBufDocument.putAll(new BsonDocument('x', BsonNull.VALUE))
-
- then:
- thrown(UnsupportedOperationException)
-
- when:
- byteBufDocument.remove(BsonNull.VALUE)
-
- then:
- thrown(UnsupportedOperationException)
- }
-
- def 'should get first key'() {
- expect:
- byteBufDocument.getFirstKey() == document.keySet().iterator().next()
- documentByteBuf.referenceCount == 1
- }
-
- def 'getFirstKey should throw NoSuchElementException if the document is empty'() {
- when:
- emptyByteBufDocument.getFirstKey()
-
- then:
- thrown(NoSuchElementException)
- emptyDocumentByteBuf.referenceCount == 1
- }
-
- def 'should create BsonReader'() {
- when:
- def reader = document.asBsonReader()
-
- then:
- new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()) == document
-
- cleanup:
- reader.close()
- }
-
- def 'clone should make a deep copy'() {
- when:
- BsonDocument cloned = byteBufDocument.clone()
-
- then:
- cloned == byteBufDocument
- documentByteBuf.referenceCount == 1
- }
-
- def 'should serialize and deserialize'() {
- given:
- def baos = new ByteArrayOutputStream()
- def oos = new ObjectOutputStream(baos)
-
- when:
- oos.writeObject(byteBufDocument)
- def bais = new ByteArrayInputStream(baos.toByteArray())
- def ois = new ObjectInputStream(bais)
- def deserializedDocument = ois.readObject()
-
- then:
- byteBufDocument == deserializedDocument
- documentByteBuf.referenceCount == 1
- }
-
- def 'toJson should return equivalent'() {
- expect:
- document.toJson() == byteBufDocument.toJson()
- documentByteBuf.referenceCount == 1
- }
-
- def 'toJson should be callable multiple times'() {
- expect:
- byteBufDocument.toJson()
- byteBufDocument.toJson()
- documentByteBuf.referenceCount == 1
- }
-
- def 'size should be callable multiple times'() {
- expect:
- byteBufDocument.size()
- byteBufDocument.size()
- documentByteBuf.referenceCount == 1
- }
-
- def 'toJson should respect JsonWriteSettings'() {
- given:
- def settings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build()
-
- expect:
- document.toJson(settings) == byteBufDocument.toJson(settings)
- }
-
- def 'toJson should return equivalent when a ByteBufBsonDocument is nested in a BsonDocument'() {
- given:
- def topLevel = new BsonDocument('nested', byteBufDocument)
-
- expect:
- new BsonDocument('nested', document).toJson() == topLevel.toJson()
- }
-
- class TestEntry implements Map.Entry {
-
- private final String key
- private BsonValue value
-
- TestEntry(String key, BsonValue value) {
- this.key = key
- this.value = value
- }
-
- @Override
- String getKey() {
- key
- }
-
- @Override
- BsonValue getValue() {
- value
- }
-
- @Override
- BsonValue setValue(final BsonValue value) {
- this.value = value
- }
- }
-
-}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java
new file mode 100644
index 00000000000..f3744057a18
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java
@@ -0,0 +1,779 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.connection;
+
+import org.bson.BsonArray;
+import org.bson.BsonBinaryWriter;
+import org.bson.BsonBoolean;
+import org.bson.BsonDocument;
+import org.bson.BsonInt32;
+import org.bson.BsonReader;
+import org.bson.BsonString;
+import org.bson.BsonValue;
+import org.bson.ByteBuf;
+import org.bson.ByteBufNIO;
+import org.bson.RawBsonDocument;
+import org.bson.codecs.BsonDocumentCodec;
+import org.bson.codecs.DecoderContext;
+import org.bson.codecs.EncoderContext;
+import org.bson.io.BasicOutputBuffer;
+import org.bson.json.JsonMode;
+import org.bson.json.JsonWriterSettings;
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+@DisplayName("ByteBufBsonDocument")
+class ByteBufBsonDocumentTest {
+ private ByteBuf documentByteBuf;
+ private ByteBufBsonDocument emptyByteBufDocument;
+ private BsonDocument document;
+
+ @BeforeEach
+ void setUp() {
+ ByteBuf emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap(new byte[]{5, 0, 0, 0, 0}));
+ emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf);
+
+ document = new BsonDocument()
+ .append("a", new BsonInt32(1))
+ .append("b", new BsonInt32(2))
+ .append("c", new BsonDocument("x", BsonBoolean.TRUE))
+ .append("d", new BsonArray(asList(
+ new BsonDocument("y", BsonBoolean.FALSE),
+ new BsonInt32(1)
+ )));
+
+ RawBsonDocument rawBsonDocument = RawBsonDocument.parse(document.toString());
+ documentByteBuf = rawBsonDocument.getByteBuffer();
+ }
+
+ @AfterEach
+ void tearDown() {
+ if (emptyByteBufDocument != null) {
+ emptyByteBufDocument.close();
+ }
+ }
+
+ // Basic Operations
+
+ @Test
+ @DisplayName("get() returns value for existing key, null for missing key")
+ void getShouldReturnCorrectValue() {
+ assertNull(emptyByteBufDocument.get("a"));
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertNull(byteBufDocument.get("z"));
+ assertEquals(new BsonInt32(1), byteBufDocument.get("a"));
+ assertEquals(new BsonInt32(2), byteBufDocument.get("b"));
+ }
+ }
+
+ @Test
+ @DisplayName("get() throws IllegalArgumentException for null key")
+ void getShouldThrowForNullKey() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertThrows(IllegalArgumentException.class, () -> byteBufDocument.get(null));
+ assertEquals(1, documentByteBuf.getReferenceCount());
+ }
+ }
+
+ @Test
+ @DisplayName("containsKey() finds existing keys and rejects missing keys")
+ void containsKeyShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertThrows(IllegalArgumentException.class, () -> byteBufDocument.containsKey(null));
+ assertTrue(byteBufDocument.containsKey("a"));
+ assertTrue(byteBufDocument.containsKey("d"));
+ assertFalse(byteBufDocument.containsKey("z"));
+ assertEquals(1, documentByteBuf.getReferenceCount());
+ }
+ }
+
+ @Test
+ @DisplayName("containsValue() finds existing values and rejects missing values")
+ void containsValueShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertTrue(byteBufDocument.containsValue(document.get("a")));
+ assertTrue(byteBufDocument.containsValue(document.get("c")));
+ assertFalse(byteBufDocument.containsValue(new BsonInt32(999)));
+ assertEquals(1, documentByteBuf.getReferenceCount());
+ }
+ }
+
+ @Test
+ @DisplayName("isEmpty() returns correct result")
+ void isEmptyShouldWork() {
+ assertTrue(emptyByteBufDocument.isEmpty());
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertFalse(byteBufDocument.isEmpty());
+ }
+ }
+
+ @Test
+ @DisplayName("size() returns correct count")
+ void sizeShouldWork() {
+ assertEquals(0, emptyByteBufDocument.size());
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(4, byteBufDocument.size());
+ assertEquals(4, byteBufDocument.size()); // Verify caching works
+ }
+ }
+
+ @Test
+ @DisplayName("getFirstKey() returns first key or throws for empty document")
+ void getFirstKeyShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals("a", byteBufDocument.getFirstKey());
+ }
+ assertThrows(NoSuchElementException.class, () -> emptyByteBufDocument.getFirstKey());
+ }
+
+ // Collection Views
+
+ @Test
+ @DisplayName("keySet() returns all keys")
+ void keySetShouldWork() {
+ assertTrue(emptyByteBufDocument.keySet().isEmpty());
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(new HashSet<>(asList("a", "b", "c", "d")), byteBufDocument.keySet());
+ }
+ }
+
+ @Test
+ @DisplayName("values() returns all values")
+ void valuesShouldWork() {
+ assertTrue(emptyByteBufDocument.values().isEmpty());
+ Set<BsonValue> expected = new HashSet<>(asList(
+ document.get("a"), document.get("b"), document.get("c"), document.get("d")
+ ));
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(expected, new HashSet<>(byteBufDocument.values()));
+ }
+ }
+
+ @Test
+ @DisplayName("entrySet() returns all entries")
+ void entrySetShouldWork() {
+ assertTrue(emptyByteBufDocument.entrySet().isEmpty());
+ Set<Map.Entry<String, BsonValue>> expected = new HashSet<>(asList(
+ new AbstractMap.SimpleImmutableEntry<>("a", document.get("a")),
+ new AbstractMap.SimpleImmutableEntry<>("b", document.get("b")),
+ new AbstractMap.SimpleImmutableEntry<>("c", document.get("c")),
+ new AbstractMap.SimpleImmutableEntry<>("d", document.get("d"))
+ ));
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(expected, byteBufDocument.entrySet());
+ }
+ }
+
+ // Type-Specific Accessors
+
+ @Test
+ @DisplayName("getDocument() returns nested document")
+ void getDocumentShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ BsonDocument nested = byteBufDocument.getDocument("c");
+ assertNotNull(nested);
+ assertEquals(BsonBoolean.TRUE, nested.get("x"));
+ }
+ }
+
+ @Test
+ @DisplayName("getArray() returns array")
+ void getArrayShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ BsonArray array = byteBufDocument.getArray("d");
+ assertNotNull(array);
+ assertEquals(2, array.size());
+ }
+ }
+
+ @Test
+ @DisplayName("get() with default value works correctly")
+ void getWithDefaultShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(new BsonInt32(1), byteBufDocument.get("a", new BsonInt32(999)));
+ assertEquals(new BsonInt32(999), byteBufDocument.get("missing", new BsonInt32(999)));
+ }
+ }
+
+ @Test
+ @DisplayName("Type check methods return correct results")
+ void typeChecksShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertTrue(byteBufDocument.isNumber("a"));
+ assertTrue(byteBufDocument.isInt32("a"));
+ assertTrue(byteBufDocument.isDocument("c"));
+ assertTrue(byteBufDocument.isArray("d"));
+ assertFalse(byteBufDocument.isDocument("a"));
+ }
+ }
+
+ // Immutability
+
+ @Test
+ @DisplayName("All write methods throw UnsupportedOperationException")
+ void writeMethodsShouldThrow() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.clear());
+ assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.put("x", new BsonInt32(1)));
+ assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.append("x", new BsonInt32(1)));
+ assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.putAll(new BsonDocument()));
+ assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.remove("a"));
+ }
+ }
+
+ // Conversion and Serialization
+
+ @Test
+ @DisplayName("toBsonDocument() returns equivalent document and caches result")
+ void toBsonDocumentShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(document, byteBufDocument.toBsonDocument());
+ BsonDocument first = byteBufDocument.toBsonDocument();
+ BsonDocument second = byteBufDocument.toBsonDocument();
+ assertEquals(first, second);
+ } }
+
+ @Test
+ @DisplayName("asBsonReader() creates valid reader")
+ void asBsonReaderShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ try (BsonReader reader = byteBufDocument.asBsonReader()) {
+ BsonDocument decoded = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build());
+ assertEquals(document, decoded);
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("toJson() returns correct JSON")
+ void toJsonShouldWork() {
+ ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf);
+ assertEquals(document.toJson(), byteBufDocument.toJson());
+ byteBufDocument.close();
+
+ assertNotNull(byteBufDocument.toJson()); // Verify caching
+ }
+
+ @Test
+ @DisplayName("toJson() returns correct JSON with different settings")
+ void toJsonWithSettingsShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ JsonWriterSettings shellSettings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build();
+ assertEquals(document.toJson(shellSettings), byteBufDocument.toJson(shellSettings));
+ }
+ }
+
+ @Test
+ @DisplayName("toString() returns equivalent string")
+ void toStringShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(document.toString(), byteBufDocument.toString());
+ }
+ }
+
+ @Test
+ @DisplayName("clone() creates deep copy")
+ void cloneShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ BsonDocument cloned = byteBufDocument.clone();
+ assertNotSame(byteBufDocument, cloned);
+ assertEquals(byteBufDocument, cloned);
+
+ assertNotSame(byteBufDocument.clone(), byteBufDocument.clone());
+ }
+ }
+
+ @Test
+ @DisplayName("Java serialization works correctly")
+ void serializationShouldWork() throws Exception {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ new ObjectOutputStream(baos).writeObject(byteBufDocument);
+ Object deserialized = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())).readObject();
+ assertEquals(byteBufDocument, deserialized);
+ }
+ }
+
+ // Equality and HashCode
+
+ @Test
+ @DisplayName("equals() and hashCode() work correctly")
+ void equalsAndHashCodeShouldWork() {
+ try (ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf)) {
+ assertEquals(document, byteBufDocument);
+ assertEquals(byteBufDocument, document);
+ assertEquals(document.hashCode(), byteBufDocument.hashCode());
+ assertNotEquals(byteBufDocument, new BsonDocument("x", new BsonInt32(99)));
+ }
+ }
+
+ // Resource Management
+
+ @Test
+ @DisplayName("Closed document throws IllegalStateException on all operations")
+ void closedDocumentShouldThrow() {
+ ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf);
+ byteBufDocument.close();
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.size());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.isEmpty());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.containsKey("a"));
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.get("a"));
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.keySet());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.values());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.entrySet());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.getFirstKey());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.toBsonDocument());
+ assertThrows(IllegalStateException.class, () -> byteBufDocument.toJson());
+ }
+
+ @Test
+ @DisplayName("close() can be called multiple times safely")
+ void closeIsIdempotent() {
+ ByteBufBsonDocument byteBufDocument = new ByteBufBsonDocument(documentByteBuf);
+ byteBufDocument.close();
+ byteBufDocument.close(); // Should not throw
+ }
+
+ @Test
+ @DisplayName("Nested documents are closed when parent is closed")
+ void nestedDocumentsClosedWithParent() {
+ BsonDocument doc = new BsonDocument("outer", new BsonDocument("inner", new BsonInt32(42)));
+ ByteBuf buf = createByteBufFromDocument(doc);
+ ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf);
+
+ BsonDocument retrieved = byteBufDoc.getDocument("outer");
+ byteBufDoc.close();
+
+ assertThrows(IllegalStateException.class, byteBufDoc::size);
+ if (retrieved instanceof ByteBufBsonDocument) {
+ assertThrows(IllegalStateException.class, retrieved::size);
+ }
+ }
+
+ @Test
+ @DisplayName("Nested arrays are closed when parent is closed")
+ void nestedArraysClosedWithParent() {
+ BsonDocument doc = new BsonDocument("arr", new BsonArray(asList(
+ new BsonInt32(1), new BsonDocument("x", new BsonInt32(2))
+ )));
+ ByteBuf buf = createByteBufFromDocument(doc);
+ ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf);
+
+ BsonArray retrieved = byteBufDoc.getArray("arr");
+ byteBufDoc.close();
+
+ assertThrows(IllegalStateException.class, byteBufDoc::size);
+ if (retrieved instanceof ByteBufBsonArray) {
+ assertThrows(IllegalStateException.class, retrieved::size);
+ }
+ }
+
+ @Test
+ @DisplayName("Deeply nested structures are closed recursively")
+ void deeplyNestedClosedRecursively() {
+ BsonDocument doc = new BsonDocument()
+ .append("level1", new BsonArray(asList(
+ new BsonDocument("level2", new BsonDocument("level3", new BsonInt32(999))),
+ new BsonInt32(1)
+ )))
+ .append("sibling", new BsonDocument("key", new BsonString("value")));
+
+ ByteBuf buf = createByteBufFromDocument(doc);
+ ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf);
+
+ BsonArray level1 = byteBufDoc.getArray("level1");
+ byteBufDoc.getDocument("sibling");
+
+ if (level1.get(0).isDocument()) {
+ BsonDocument level2Doc = level1.get(0).asDocument();
+ if (level2Doc.containsKey("level2")) {
+ assertEquals(new BsonInt32(999), level2Doc.getDocument("level2").get("level3"));
+ }
+ }
+
+ byteBufDoc.close();
+ assertThrows(IllegalStateException.class, byteBufDoc::size);
+ }
+
+ @Test
+ @DisplayName("Iterators work as expected")
+ void iteratorsWorksAsExpected() {
+ BsonDocument doc = new BsonDocument()
+ .append("doc1", new BsonDocument("a", new BsonInt32(1)))
+ .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3))))
+ .append("primitive", new BsonString("test"));
+
+ try (ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(createByteBufFromDocument(doc))) {
+
+ int count = 0;
+ for (Map.Entry<String, BsonValue> entry : byteBufDoc.entrySet()) {
+ assertNotNull(entry.getKey());
+ assertNotNull(entry.getValue());
+ count++;
+ }
+ assertEquals(3, count);
+
+ Iterator<String> keysIterator = byteBufDoc.keySet().iterator();
+ assertDoesNotThrow(keysIterator::hasNext);
+
+ Iterator<String> nestedKeysIterator = byteBufDoc.getDocument("doc1").keySet().iterator();
+ assertDoesNotThrow(nestedKeysIterator::hasNext);
+
+ Iterator<BsonValue> arrayIterator = byteBufDoc.getArray("arr1").iterator();
+ assertDoesNotThrow(arrayIterator::hasNext);
+ }
+ }
+
+    @Test
+    @DisplayName("toBsonDocument() handles nested structures and allows close")
+    void toBsonDocumentHandlesNestedStructures() {
+        BsonDocument complexDoc = new BsonDocument()
+                .append("doc", new BsonDocument("x", new BsonInt32(1)))
+                .append("arr", new BsonArray(asList(new BsonDocument("y", new BsonInt32(2)), new BsonInt32(3))));
+
+        // try-with-resources guarantees close() runs even if hydration or the
+        // assertion fails (the previous manual close() leaked on failure).
+        try (ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(createByteBufFromDocument(complexDoc))) {
+            // Hydration must produce a plain BsonDocument equal to the source document.
+            BsonDocument hydrated = byteBufDoc.toBsonDocument();
+            assertEquals(complexDoc, hydrated);
+        }
+    }
+
+ @Test
+ @DisplayName("cachedDocument is usable after close")
+ void cachedDocumentIsUsableAfterClose() {
+ BsonDocument complexDoc = new BsonDocument()
+ .append("doc", new BsonDocument("x", new BsonInt32(1)))
+ .append("arr", new BsonArray(asList(new BsonDocument("y", new BsonInt32(2)), new BsonInt32(3))));
+
+ ByteBuf buf = createByteBufFromDocument(complexDoc);
+ ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf);
+ // Hydrate BEFORE closing: toBsonDocument() materializes a standalone copy.
+ BsonDocument hydrated = byteBufDoc.toBsonDocument();
+
+ // Closing the buffer-backed document must not invalidate the hydrated copy;
+ // it is expected to be fully detached from the underlying ByteBuf.
+ byteBufDoc.close();
+ assertEquals(complexDoc, hydrated);
+ assertEquals(complexDoc.toJson(), hydrated.toJson());
+ }
+
+ // Sequence Fields (OP_MSG)
+
+ @Test
+ @DisplayName("Sequence field is accessible as array of ByteBufBsonDocuments")
+ void sequenceFieldAccessibleAsArray() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) {
+
+ // The OP_MSG document sequence must surface as a regular BSON array field.
+ BsonValue documentsValue = commandDoc.get("documents");
+ assertNotNull(documentsValue);
+ assertTrue(documentsValue.isArray());
+
+ BsonArray documents = documentsValue.asArray();
+ assertEquals(3, documents.size());
+
+ // Each element should remain buffer-backed (lazy), not eagerly hydrated,
+ // while still exposing the encoded field values.
+ for (int i = 0; i < 3; i++) {
+ BsonValue doc = documents.get(i);
+ assertInstanceOf(ByteBufBsonDocument.class, doc);
+ assertEquals(new BsonInt32(i), doc.asDocument().get("_id"));
+ assertEquals(new BsonString("doc" + i), doc.asDocument().get("name"));
+ }
+ }
+ }
+
+    @Test
+    @DisplayName("Sequence field is included in size, keySet, values, and entrySet")
+    void sequenceFieldIncludedInCollectionViews() {
+        try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+             ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) {
+
+            // Body contributes at least "insert" and "$db"; the sequence adds "documents".
+            assertTrue(commandDoc.size() >= 3);
+            assertTrue(commandDoc.keySet().contains("documents"));
+            assertTrue(commandDoc.keySet().contains("insert"));
+
+            // values() must expose the sequence as a two-element array.
+            boolean foundDocumentsArray = false;
+            for (BsonValue value : commandDoc.values()) {
+                if (value.isArray() && value.asArray().size() == 2) {
+                    foundDocumentsArray = true;
+                    break;
+                }
+            }
+            assertTrue(foundDocumentsArray);
+
+            // entrySet() must expose the sequence under its identifier.
+            boolean foundDocumentsEntry = false;
+            for (Map.Entry<String, BsonValue> entry : commandDoc.entrySet()) {
+                if ("documents".equals(entry.getKey())) {
+                    foundDocumentsEntry = true;
+                    assertEquals(2, entry.getValue().asArray().size());
+                    break;
+                }
+            }
+            assertTrue(foundDocumentsEntry);
+        }
+    }
+
+ @Test
+ @DisplayName("containsKey and containsValue work with sequence fields")
+ void containsMethodsWorkWithSequenceFields() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) {
+
+ // Key lookups must cover both the body section and the sequence identifier.
+ assertTrue(commandDoc.containsKey("documents"));
+ assertTrue(commandDoc.containsKey("insert"));
+ assertFalse(commandDoc.containsKey("nonexistent"));
+
+ // containsValue must find a document INSIDE the sequence array (index 1 of
+ // the three documents written by the helper).
+ BsonDocument expectedDoc = new BsonDocument()
+ .append("_id", new BsonInt32(1))
+ .append("name", new BsonString("doc1"));
+ assertTrue(commandDoc.containsValue(expectedDoc));
+ }
+ }
+
+    @Test
+    @DisplayName("Sequence field documents are closed when parent is closed")
+    void sequenceFieldDocumentsClosedWithParent() {
+        ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+        ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2);
+
+        // Capture references to the sequence documents before closing the parent.
+        BsonArray documents = commandDoc.getArray("documents");
+        List<BsonDocument> docRefs = new ArrayList<>();
+        for (BsonValue doc : documents) {
+            docRefs.add(doc.asDocument());
+        }
+
+        commandDoc.close();
+        output.close();
+
+        // After close, the parent and any buffer-backed children must reject access.
+        assertThrows(IllegalStateException.class, commandDoc::size);
+        for (BsonDocument doc : docRefs) {
+            if (doc instanceof ByteBufBsonDocument) {
+                assertThrows(IllegalStateException.class, doc::size);
+            }
+        }
+    }
+
+ @Test
+ @DisplayName("Sequence field is cached on multiple access")
+ void sequenceFieldCached() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) {
+
+ // Two consecutive reads of the sequence field must both succeed and agree.
+ // NOTE(review): only sizes are compared; if getArray() is guaranteed to return
+ // the cached instance, assertSame(first, second) would be a stronger check —
+ // confirm the caching contract before tightening.
+ BsonArray first = commandDoc.getArray("documents");
+ BsonArray second = commandDoc.getArray("documents");
+ assertNotNull(first);
+ assertEquals(first.size(), second.size());
+ }
+ }
+
+ @Test
+ @DisplayName("toBsonDocument() hydrates sequence fields to regular BsonDocuments")
+ void toBsonDocumentHydratesSequenceFields() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) {
+
+ BsonDocument hydrated = commandDoc.toBsonDocument();
+ assertTrue(hydrated.containsKey("documents"));
+
+ // Hydration must deep-copy: no element of the result may still be backed
+ // by the (closeable) byte buffer.
+ BsonArray documents = hydrated.getArray("documents");
+ assertEquals(2, documents.size());
+ for (BsonValue doc : documents) {
+ assertFalse(doc instanceof ByteBufBsonDocument);
+ }
+ }
+ }
+
+    @Test
+    @DisplayName("Sequence field with nested documents works correctly")
+    void sequenceFieldWithNestedDocuments() {
+        // Include commandDoc in try-with-resources: the previous manual close() at the
+        // end of the method was skipped whenever an assertion failed, leaking the buffer.
+        try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+             ByteBufBsonDocument commandDoc = createNestedCommandMessageDocument(output)) {
+
+            BsonArray documents = commandDoc.getArray("documents");
+            assertEquals(2, documents.size());
+
+            // First sequence document: "nested.inner" is i * 10 with i == 0.
+            BsonDocument firstDoc = documents.get(0).asDocument();
+            BsonDocument nested = firstDoc.getDocument("nested");
+            assertEquals(new BsonInt32(0), nested.get("inner"));
+
+            BsonArray array = firstDoc.getArray("array");
+            assertEquals(2, array.size());
+        }
+    }
+
+ @Test
+ @DisplayName("Empty sequence field returns empty array")
+ void emptySequenceField() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 0)) {
+
+ // A sequence section with zero documents must still register its identifier
+ // as a key, mapping to an empty array rather than being absent.
+ assertTrue(commandDoc.containsKey("insert"));
+ assertTrue(commandDoc.containsKey("documents"));
+ assertTrue(commandDoc.getArray("documents").isEmpty());
+ }
+ }
+
+ @Test
+ @DisplayName("getFirstKey() returns body field, not sequence field")
+ void getFirstKeyReturnsBodyField() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) {
+
+ // The command name (first key of the body section) must win over the
+ // appended document-sequence identifier.
+ assertEquals("insert", commandDoc.getFirstKey());
+ }
+ }
+
+ @Test
+ @DisplayName("toJson() includes sequence fields")
+ void toJsonIncludesSequenceFields() {
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) {
+
+ // JSON rendering must cover both the sequence identifier and the
+ // fields of the documents inside the sequence.
+ String json = commandDoc.toJson();
+ assertTrue(json.contains("documents"));
+ assertTrue(json.contains("_id"));
+ }
+ }
+
+ @Test
+ @DisplayName("equals() and hashCode() include sequence fields")
+ void equalsAndHashCodeIncludeSequenceFields() {
+ try (ByteBufferBsonOutput output1 = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc1 = createCommandMessageDocument(output1, 2);
+ ByteBufferBsonOutput output2 = new ByteBufferBsonOutput(new SimpleBufferProvider());
+ ByteBufBsonDocument commandDoc2 = createCommandMessageDocument(output2, 2)) {
+
+ // Two independently encoded messages with identical content must hydrate
+ // to equal documents and produce equal hash codes.
+ assertEquals(commandDoc1.toBsonDocument(), commandDoc2.toBsonDocument());
+ assertEquals(commandDoc1.hashCode(), commandDoc2.hashCode());
+ }
+ }
+
+ // --- Helper Methods ---
+
+    /**
+     * Builds an OP_MSG-shaped buffer with an "insert" body section and {@code numDocuments}
+     * documents in a "documents" sequence section, then wraps it as a command-message document.
+     */
+    private ByteBufBsonDocument createCommandMessageDocument(final ByteBufferBsonOutput output, final int numDocuments) {
+        // Body section: the command itself plus the mandatory $db field.
+        BsonDocument bodyDoc = new BsonDocument()
+                .append("insert", new BsonString("test"))
+                .append("$db", new BsonString("db"));
+
+        byte[] bodyBytes = encodeBsonDocument(bodyDoc);
+        List<byte[]> sequenceDocBytes = new ArrayList<>();
+        for (int i = 0; i < numDocuments; i++) {
+            BsonDocument seqDoc = new BsonDocument()
+                    .append("_id", new BsonInt32(i))
+                    .append("name", new BsonString("doc" + i));
+            sequenceDocBytes.add(encodeBsonDocument(seqDoc));
+        }
+
+        writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes);
+
+        List<ByteBuf> buffers = output.getByteBuffers();
+        return ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(buffers));
+    }
+
+    /**
+     * Like {@link #createCommandMessageDocument}, but each of the two sequence documents
+     * carries a nested document ("nested.inner" == i * 10) and a mixed-type array.
+     */
+    private ByteBufBsonDocument createNestedCommandMessageDocument(final ByteBufferBsonOutput output) {
+        BsonDocument bodyDoc = new BsonDocument()
+                .append("insert", new BsonString("test"))
+                .append("$db", new BsonString("db"));
+
+        byte[] bodyBytes = encodeBsonDocument(bodyDoc);
+        List<byte[]> sequenceDocBytes = new ArrayList<>();
+        for (int i = 0; i < 2; i++) {
+            BsonDocument seqDoc = new BsonDocument()
+                    .append("_id", new BsonInt32(i))
+                    .append("nested", new BsonDocument("inner", new BsonInt32(i * 10)))
+                    .append("array", new BsonArray(asList(
+                            new BsonInt32(i),
+                            new BsonDocument("arrayNested", new BsonString("value" + i))
+                    )));
+            sequenceDocBytes.add(encodeBsonDocument(seqDoc));
+        }
+
+        writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes);
+        return ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(output.getByteBuffers()));
+    }
+
+    /**
+     * Writes the section layout of an OP_MSG message (without the wire-protocol header):
+     * the raw body document followed by one kind-1 document-sequence section.
+     */
+    private void writeOpMsgFormat(final ByteBufferBsonOutput output, final byte[] bodyBytes,
+                                  final String sequenceIdentifier, final List<byte[]> sequenceDocBytes) {
+        // Body section: just the raw BSON document bytes.
+        output.writeBytes(bodyBytes, 0, bodyBytes.length);
+
+        // Sequence section size counts itself (4 bytes), the identifier plus its NUL
+        // terminator, and the concatenated document payloads.
+        // NOTE(review): String.length() assumes an ASCII identifier — fine for
+        // "documents", but would under-count the UTF-8 bytes of a non-ASCII one.
+        int sequencePayloadSize = sequenceDocBytes.stream().mapToInt(b -> b.length).sum();
+        int sequenceSectionSize = 4 + sequenceIdentifier.length() + 1 + sequencePayloadSize;
+
+        output.writeByte(1); // payload type 1 = document sequence
+        output.writeInt32(sequenceSectionSize);
+        output.writeCString(sequenceIdentifier);
+        for (byte[] docBytes : sequenceDocBytes) {
+            output.writeBytes(docBytes, 0, docBytes.length);
+        }
+    }
+
+    /** Encodes {@code doc} to its raw BSON bytes using an in-memory output buffer. */
+    private static byte[] encodeBsonDocument(final BsonDocument doc) {
+        // OutputBuffer.toByteArray() returns the encoded bytes directly, avoiding the
+        // ByteArrayOutputStream/pipe round-trip and its checked IOException.
+        BasicOutputBuffer buffer = new BasicOutputBuffer();
+        new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), doc, EncoderContext.builder().build());
+        return buffer.toByteArray();
+    }
+
+ // Wraps the encoded BSON bytes of a document in a read-only NIO-backed ByteBuf.
+ private static ByteBuf createByteBufFromDocument(final BsonDocument doc) {
+ return new ByteBufNIO(ByteBuffer.wrap(encodeBsonDocument(doc)));
+ }
+
+ // Minimal BufferProvider for tests: allocates plain heap buffers in little-endian
+ // order (the byte order BSON mandates).
+ private static class SimpleBufferProvider implements BufferProvider {
+ @NotNull
+ @Override
+ public ByteBuf getBuffer(final int size) {
+ return new ByteBufNIO(ByteBuffer.allocate(size).order(ByteOrder.LITTLE_ENDIAN));
+ }
+ }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy
deleted file mode 100644
index 77bdd5e2045..00000000000
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy
+++ /dev/null
@@ -1,365 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.internal.connection
-
-
-import com.mongodb.MongoNamespace
-import com.mongodb.ReadConcern
-import com.mongodb.ReadPreference
-import com.mongodb.connection.ClusterConnectionMode
-import com.mongodb.connection.ServerType
-import com.mongodb.internal.IgnorableRequestContext
-import com.mongodb.internal.TimeoutContext
-import com.mongodb.internal.bulk.InsertRequest
-import com.mongodb.internal.bulk.WriteRequestWithIndex
-import com.mongodb.internal.session.SessionContext
-import com.mongodb.internal.validator.NoOpFieldNameValidator
-import org.bson.BsonArray
-import org.bson.BsonBinary
-import org.bson.BsonDocument
-import org.bson.BsonInt32
-import org.bson.BsonMaximumSizeExceededException
-import org.bson.BsonString
-import org.bson.BsonTimestamp
-import org.bson.ByteBuf
-import org.bson.ByteBufNIO
-import org.bson.codecs.BsonDocumentCodec
-import spock.lang.Specification
-
-import java.nio.ByteBuffer
-
-import static com.mongodb.internal.connection.SplittablePayload.Type.INSERT
-import static com.mongodb.internal.operation.ServerVersionHelper.LATEST_WIRE_VERSION
-
-/**
- * New tests must be added to {@link CommandMessageTest}.
- */
-class CommandMessageSpecification extends Specification {
-
- def namespace = new MongoNamespace('db.test')
- def command = new BsonDocument('find', new BsonString(namespace.collectionName))
- def fieldNameValidator = NoOpFieldNameValidator.INSTANCE
-
- def 'should encode command message with OP_MSG when server version is >= 3.6'() {
- given:
- def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, readPreference,
- MessageSettings.builder()
- .maxWireVersion(LATEST_WIRE_VERSION)
- .serverType(serverType as ServerType)
- .sessionSupported(true)
- .build(),
- responseExpected, MessageSequences.EmptyMessageSequences.INSTANCE, clusterConnectionMode, null)
- def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
-
- when:
- message.encode(output, operationContext)
-
- then:
- def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- def messageHeader = new MessageHeader(byteBuf, 512)
- def replyHeader = new ReplyHeader(byteBuf, messageHeader)
- messageHeader.opCode == OpCode.OP_MSG.value
- replyHeader.requestId < RequestMessage.currentGlobalId
- replyHeader.responseTo == 0
- replyHeader.hasMoreToCome() != responseExpected
-
- def expectedCommandDocument = command.clone()
- .append('$db', new BsonString(namespace.databaseName))
-
- if (operationContext.getSessionContext().clusterTime != null) {
- expectedCommandDocument.append('$clusterTime', operationContext.getSessionContext().clusterTime)
- }
- if (operationContext.getSessionContext().hasSession() && responseExpected) {
- expectedCommandDocument.append('lsid', operationContext.getSessionContext().sessionId)
- }
-
- if (readPreference != ReadPreference.primary()) {
- expectedCommandDocument.append('$readPreference', readPreference.toDocument())
- } else if (clusterConnectionMode == ClusterConnectionMode.SINGLE && serverType != ServerType.SHARD_ROUTER) {
- expectedCommandDocument.append('$readPreference', ReadPreference.primaryPreferred().toDocument())
- }
- getCommandDocument(byteBuf, replyHeader) == expectedCommandDocument
-
- cleanup:
- output.close()
-
- where:
- [readPreference, serverType, clusterConnectionMode, operationContext, responseExpected, isCryptd] << [
- [ReadPreference.primary(), ReadPreference.secondary()],
- [ServerType.REPLICA_SET_PRIMARY, ServerType.SHARD_ROUTER],
- [ClusterConnectionMode.SINGLE, ClusterConnectionMode.MULTIPLE],
- [
- new OperationContext(
- IgnorableRequestContext.INSTANCE,
- Stub(SessionContext) {
- hasSession() >> false
- getClusterTime() >> null
- getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[]))
- getReadConcern() >> ReadConcern.DEFAULT
- }, Stub(TimeoutContext), null),
- new OperationContext(
- IgnorableRequestContext.INSTANCE,
- Stub(SessionContext) {
- hasSession() >> false
- getClusterTime() >> new BsonDocument('clusterTime', new BsonTimestamp(42, 1))
- getReadConcern() >> ReadConcern.DEFAULT
- }, Stub(TimeoutContext), null),
- new OperationContext(
- IgnorableRequestContext.INSTANCE,
- Stub(SessionContext) {
- hasSession() >> true
- getClusterTime() >> null
- getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[]))
- getReadConcern() >> ReadConcern.DEFAULT
- }, Stub(TimeoutContext), null),
- new OperationContext(
- IgnorableRequestContext.INSTANCE,
- Stub(SessionContext) {
- hasSession() >> true
- getClusterTime() >> new BsonDocument('clusterTime', new BsonTimestamp(42, 1))
- getSessionId() >> new BsonDocument('id', new BsonBinary([1, 2, 3] as byte[]))
- getReadConcern() >> ReadConcern.DEFAULT
- }, Stub(TimeoutContext), null)
- ],
- [true, false],
- [true, false]
- ].combinations()
- }
-
- String getString(final ByteBuf byteBuf) {
- def byteArrayOutputStream = new ByteArrayOutputStream()
- def cur = byteBuf.get()
- while (cur != 0) {
- byteArrayOutputStream.write(cur)
- cur = byteBuf.get()
- }
- new String(byteArrayOutputStream.toByteArray(), 'UTF-8')
- }
-
- def 'should get command document'() {
- given:
- def message = new CommandMessage(namespace.getDatabaseName(), originalCommandDocument, fieldNameValidator,
- ReadPreference.primary(), MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true,
- payload == null ? MessageSequences.EmptyMessageSequences.INSTANCE : payload,
- ClusterConnectionMode.MULTIPLE, null)
- def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
- Stub(TimeoutContext), null))
-
- when:
- def commandDocument = message.getCommandDocument(output)
-
- def expectedCommandDocument = new BsonDocument('insert', new BsonString('coll')).append('documents',
- new BsonArray([new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))]))
- expectedCommandDocument.append('$db', new BsonString(namespace.getDatabaseName()))
- then:
- commandDocument == expectedCommandDocument
-
-
- where:
- [maxWireVersion, originalCommandDocument, payload] << [
- [
- LATEST_WIRE_VERSION,
- new BsonDocument('insert', new BsonString('coll')),
- new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)),
- new BsonDocument('_id', new BsonInt32(2))]
- .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) },
- true, NoOpFieldNameValidator.INSTANCE),
- ],
- [
- LATEST_WIRE_VERSION,
- new BsonDocument('insert', new BsonString('coll')).append('documents',
- new BsonArray([new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))])),
- null
- ]
- ]
- }
-
- def 'should respect the max message size'() {
- given:
- def maxMessageSize = 1024
- def messageSettings = MessageSettings.builder().maxMessageSize(maxMessageSize).maxWireVersion(LATEST_WIRE_VERSION).build()
- def insertCommand = new BsonDocument('insert', new BsonString(namespace.collectionName))
- def payload = new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)).append('a', new BsonBinary(new byte[913])),
- new BsonDocument('_id', new BsonInt32(2)).append('b', new BsonBinary(new byte[441])),
- new BsonDocument('_id', new BsonInt32(3)).append('c', new BsonBinary(new byte[450])),
- new BsonDocument('_id', new BsonInt32(4)).append('b', new BsonBinary(new byte[441])),
- new BsonDocument('_id', new BsonInt32(5)).append('c', new BsonBinary(new byte[451]))]
- .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
- def message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
- def sessionContext = Stub(SessionContext) {
- getReadConcern() >> ReadConcern.DEFAULT
- }
-
- when:
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
- Stub(TimeoutContext), null))
- def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- def messageHeader = new MessageHeader(byteBuf, maxMessageSize)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- messageHeader.messageLength == 1024
- byteBuf.getInt() == 0
- payload.getPosition() == 1
- payload.hasAnotherSplit()
-
- when:
- payload = payload.getNextSplit()
- message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- output.truncateToPosition(0)
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null))
- byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- messageHeader = new MessageHeader(byteBuf, maxMessageSize)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- messageHeader.messageLength == 1024
- byteBuf.getInt() == 0
- payload.getPosition() == 2
- payload.hasAnotherSplit()
-
- when:
- payload = payload.getNextSplit()
- message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- output.truncateToPosition(0)
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null))
- byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- messageHeader = new MessageHeader(byteBuf, maxMessageSize)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- messageHeader.messageLength == 552
- byteBuf.getInt() == 0
- payload.getPosition() == 1
- payload.hasAnotherSplit()
-
- when:
- payload = payload.getNextSplit()
- message = new CommandMessage(namespace.getDatabaseName(), insertCommand, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- output.truncateToPosition(0)
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE,
- sessionContext,
- Stub(TimeoutContext),
- null))
- byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- messageHeader = new MessageHeader(byteBuf, maxMessageSize)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- messageHeader.messageLength == 562
- byteBuf.getInt() == 1 << 1
- payload.getPosition() == 1
- !payload.hasAnotherSplit()
-
- cleanup:
- output.close()
- }
-
- def 'should respect the max batch count'() {
- given:
- def messageSettings = MessageSettings.builder().maxBatchCount(2).maxWireVersion(LATEST_WIRE_VERSION).build()
- def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900])),
- new BsonDocument('b', new BsonBinary(new byte[450])),
- new BsonDocument('c', new BsonBinary(new byte[450]))]
- .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
- def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
- def sessionContext = Stub(SessionContext) {
- getReadConcern() >> ReadConcern.DEFAULT
- }
-
- when:
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
- Stub(TimeoutContext),
- null))
- def byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- def messageHeader = new MessageHeader(byteBuf, 2048)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- messageHeader.messageLength == 1497
- byteBuf.getInt() == 0
- payload.getPosition() == 2
- payload.hasAnotherSplit()
-
- when:
- payload = payload.getNextSplit()
- message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(), messageSettings,
- false, payload, ClusterConnectionMode.MULTIPLE, null)
- output.truncateToPosition(0)
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
- Stub(TimeoutContext), null))
- byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
- messageHeader = new MessageHeader(byteBuf, 1024)
-
- then:
- messageHeader.opCode == OpCode.OP_MSG.value
- messageHeader.requestId < RequestMessage.currentGlobalId
- messageHeader.responseTo == 0
- byteBuf.getInt() == 1 << 1
- payload.getPosition() == 1
- !payload.hasAnotherSplit()
-
- cleanup:
- output.close()
- }
-
- def 'should throw if payload document bigger than max document size'() {
- given:
- def messageSettings = MessageSettings.builder().maxDocumentSize(900)
- .maxWireVersion(LATEST_WIRE_VERSION).build()
- def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900]))]
- .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
- def message = new CommandMessage(namespace.getDatabaseName(), command, fieldNameValidator, ReadPreference.primary(),
- messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null)
- def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
- def sessionContext = Stub(SessionContext) {
- getReadConcern() >> ReadConcern.DEFAULT
- }
-
- when:
- message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
- Stub(TimeoutContext), null))
-
- then:
- thrown(BsonMaximumSizeExceededException)
-
- cleanup:
- output.close()
- }
-
- private static BsonDocument getCommandDocument(ByteBufNIO byteBuf, ReplyHeader replyHeader) {
- new ReplyMessage(new ResponseBuffers(replyHeader, byteBuf), new BsonDocumentCodec(), 0).document
- }
-}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java
index 091518c715c..e5eab18869b 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java
@@ -27,6 +27,8 @@
import com.mongodb.internal.IgnorableRequestContext;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
+import com.mongodb.internal.bulk.InsertRequest;
+import com.mongodb.internal.bulk.WriteRequestWithIndex;
import com.mongodb.internal.client.model.bulk.ConcreteClientBulkWriteOptions;
import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences;
import com.mongodb.internal.operation.ClientBulkWriteOperation;
@@ -34,11 +36,14 @@
import com.mongodb.internal.session.SessionContext;
import com.mongodb.internal.validator.NoOpFieldNameValidator;
import org.bson.BsonArray;
+import org.bson.BsonBinary;
import org.bson.BsonBoolean;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
+import org.bson.BsonMaximumSizeExceededException;
import org.bson.BsonString;
import org.bson.BsonTimestamp;
+import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import java.util.List;
@@ -53,17 +58,20 @@
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+@DisplayName("CommandMessage")
class CommandMessageTest {
private static final MongoNamespace NAMESPACE = new MongoNamespace("db.test");
private static final BsonDocument COMMAND = new BsonDocument("find", new BsonString(NAMESPACE.getCollectionName()));
@Test
+ @DisplayName("encode should throw timeout exception when timeout context is called")
void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() {
//given
CommandMessage commandMessage = new CommandMessage(NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(),
@@ -91,6 +99,7 @@ void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() {
}
@Test
+ @DisplayName("encode should not add extra elements from timeout context when connected to mongocryptd")
void encodeShouldNotAddExtraElementsFromTimeoutContextWhenConnectedToMongoCrypt() {
//given
CommandMessage commandMessage = new CommandMessage(NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(),
@@ -126,6 +135,7 @@ void encodeShouldNotAddExtraElementsFromTimeoutContextWhenConnectedToMongoCrypt(
}
@Test
+ @DisplayName("get command document from client bulk write operation")
void getCommandDocumentFromClientBulkWrite() {
MongoNamespace ns = new MongoNamespace("db", "test");
boolean retryWrites = false;
@@ -164,8 +174,466 @@ void getCommandDocumentFromClientBulkWrite() {
new OperationContext(
IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
new TimeoutContext(TimeoutSettings.DEFAULT), null));
- BsonDocument actualCommandDocument = commandMessage.getCommandDocument(output);
- assertEquals(expectedCommandDocument, actualCommandDocument);
+
+ try (ByteBufBsonDocument actualCommandDocument = commandMessage.getCommandDocument(output)) {
+ assertEquals(expectedCommandDocument, actualCommandDocument);
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("get command document with payload containing documents")
+ void getCommandDocumentWithPayload() {
+ // given
+ BsonDocument originalCommandDocument = new BsonDocument("insert", new BsonString("coll"));
+ List<BsonDocument> documents = asList(
+ new BsonDocument("_id", new BsonInt32(1)),
+ new BsonDocument("_id", new BsonInt32(2))
+ );
+ List<WriteRequestWithIndex> requestsFromDocs = IntStream.range(0, documents.size())
+ .mapToObj(i -> new WriteRequestWithIndex(new InsertRequest(documents.get(i)), i))
+ .collect(Collectors.toList());
+
+ SplittablePayload payload = new SplittablePayload(
+ SplittablePayload.Type.INSERT,
+ requestsFromDocs,
+ true,
+ NoOpFieldNameValidator.INSTANCE
+ );
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), originalCommandDocument, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(), MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), true,
+ payload, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(
+ output,
+ new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null)
+ );
+
+ // when
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ // then
+ assertEquals("coll", commandDoc.getString("insert").getValue());
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ BsonArray docsArray = commandDoc.getArray("documents");
+ assertEquals(2, docsArray.size());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("get command document with pre-encoded documents")
+ void getCommandDocumentWithPreEncodedDocuments() {
+ // given
+ BsonDocument originalCommandDocument = new BsonDocument("insert", new BsonString("coll"))
+ .append("documents", new BsonArray(asList(
+ new BsonDocument("_id", new BsonInt32(1)),
+ new BsonDocument("_id", new BsonInt32(2))
+ )));
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), originalCommandDocument, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(), MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(), true,
+ EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(
+ output,
+ new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null)
+ );
+
+ // when
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ // then
+ assertEquals("coll", commandDoc.getString("insert").getValue());
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ BsonArray docsArray = commandDoc.getArray("documents");
+ assertEquals(2, docsArray.size());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode respects max message size constraint")
+ void encodeShouldRespectMaxMessageSize() {
+ // given
+ int maxMessageSize = 1024;
+ MessageSettings messageSettings = MessageSettings.builder()
+ .maxMessageSize(maxMessageSize)
+ .maxWireVersion(LATEST_WIRE_VERSION)
+ .build();
+ BsonDocument insertCommand = new BsonDocument("insert", new BsonString(NAMESPACE.getCollectionName()));
+
+ List<WriteRequestWithIndex> requests = asList(
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(1)).append("a", new BsonBinary(new byte[913]))),
+ 0),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(2)).append("b", new BsonBinary(new byte[441]))),
+ 1),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(3)).append("c", new BsonBinary(new byte[450]))),
+ 2),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(4)).append("b", new BsonBinary(new byte[441]))),
+ 3),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(5)).append("c", new BsonBinary(new byte[451]))),
+ 4)
+ );
+
+ SplittablePayload payload = new SplittablePayload(
+ SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE
+ );
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), insertCommand, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ // when - encode first batch
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ // then - first batch respects size constraint
+ assertTrue(output.size() <= maxMessageSize, "Output size " + output.size() + " should not exceed max " + maxMessageSize);
+ assertEquals(1, payload.getPosition());
+
+ // Verify multiple splits were created
+ assertTrue(payload.hasAnotherSplit());
+ }
+ }
+
+ @Test
+ @DisplayName("encode respects max batch count constraint")
+ void encodeShouldRespectMaxBatchCount() {
+ // given
+ MessageSettings messageSettings = MessageSettings.builder()
+ .maxBatchCount(2)
+ .maxWireVersion(LATEST_WIRE_VERSION)
+ .build();
+
+ List<WriteRequestWithIndex> requests = asList(
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("a", new BsonBinary(new byte[900]))),
+ 0),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("b", new BsonBinary(new byte[450]))),
+ 1),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("c", new BsonBinary(new byte[450]))),
+ 2)
+ );
+
+ SplittablePayload payload = new SplittablePayload(
+ SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE
+ );
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ // when - encode first batch with max 2 documents
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ // then - first batch has 2 documents
+ assertEquals(2, payload.getPosition());
+ assertTrue(payload.hasAnotherSplit());
}
}
+
+ @Test
+ @DisplayName("encode throws exception when payload document exceeds max document size")
+ void encodeShouldThrowWhenPayloadDocumentExceedsMaxSize() {
+ // given
+ MessageSettings messageSettings = MessageSettings.builder()
+ .maxDocumentSize(900)
+ .maxWireVersion(LATEST_WIRE_VERSION)
+ .build();
+
+ List<WriteRequestWithIndex> requests = singletonList(
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("a", new BsonBinary(new byte[900]))),
+ 0)
+ );
+
+ SplittablePayload payload = new SplittablePayload(
+ SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE
+ );
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(), messageSettings, false, payload, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ // when & then
+ assertThrows(BsonMaximumSizeExceededException.class, () ->
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ))
+ );
+ }
+ }
+
+ @Test
+ @DisplayName("encode message with cluster time encodes successfully")
+ void encodeWithClusterTime() {
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertTrue(output.size() > 0, "Output should contain encoded message");
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode message with active session encodes successfully")
+ void encodeWithActiveSession() {
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertTrue(output.size() > 0, "Output should contain encoded message");
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode message with secondary read preference encodes successfully")
+ void encodeWithSecondaryReadPreference() {
+ ReadPreference secondary = ReadPreference.secondary();
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ secondary,
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertTrue(output.size() > 0, "Output should contain encoded message");
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode message in single cluster mode encodes successfully")
+ void encodeInSingleClusterMode() {
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder()
+ .maxWireVersion(LATEST_WIRE_VERSION)
+ .serverType(ServerType.REPLICA_SET_PRIMARY)
+ .build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.SINGLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertTrue(output.size() > 0, "Output should contain encoded message");
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode includes database name in command document")
+ void encodeIncludesDatabaseName() {
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertEquals(NAMESPACE.getDatabaseName(), commandDoc.getString("$db").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("command document can be accessed multiple times")
+ void commandDocumentCanBeAccessedMultipleTimes() {
+ BsonDocument originalCommand = new BsonDocument("find", new BsonString("coll"))
+ .append("filter", new BsonDocument("_id", new BsonInt32(1)));
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), originalCommand, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ // Access same fields multiple times
+ assertEquals("coll", commandDoc.getString("find").getValue());
+ assertEquals("coll", commandDoc.getString("find").getValue());
+ BsonDocument filter = commandDoc.getDocument("filter");
+ BsonDocument filter2 = commandDoc.getDocument("filter");
+ assertEquals(filter, filter2);
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode with multiple document sequences creates proper arrays")
+ void encodeWithMultipleDocumentsInSequence() {
+ BsonDocument insertCommand = new BsonDocument("insert", new BsonString("coll"));
+ List<WriteRequestWithIndex> requests = asList(
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(1)).append("name", new BsonString("doc1"))),
+ 0),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(2)).append("name", new BsonString("doc2"))),
+ 1),
+ new WriteRequestWithIndex(
+ new InsertRequest(new BsonDocument("_id", new BsonInt32(3)).append("name", new BsonString("doc3"))),
+ 2)
+ );
+
+ SplittablePayload payload = new SplittablePayload(
+ SplittablePayload.Type.INSERT, requests, true, NoOpFieldNameValidator.INSTANCE
+ );
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), insertCommand, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, payload, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ BsonArray documents = commandDoc.getArray("documents");
+ assertEquals(3, documents.size());
+ assertEquals(1, documents.get(0).asDocument().getInt32("_id").getValue());
+ assertEquals(2, documents.get(1).asDocument().getInt32("_id").getValue());
+ assertEquals(3, documents.get(2).asDocument().getInt32("_id").getValue());
+ }
+ }
+ }
+
+ @Test
+ @DisplayName("encode with response not expected sets continuation flag")
+ void encodeWithResponseNotExpected() {
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), COMMAND, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ false, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ // Message should encode without error when no response is expected; the moreToCome flag bits are not asserted directly here
+ assertTrue(output.size() > 0, "Output should contain encoded message");
+ }
+ }
+
+ @Test
+ @DisplayName("encode preserves original command structure")
+ void encodePreservesCommandStructure() {
+ BsonDocument complexCommand = new BsonDocument("aggregate", new BsonString("coll"))
+ .append("pipeline", new BsonArray(asList(
+ new BsonDocument("$match", new BsonDocument("status", new BsonString("active"))),
+ new BsonDocument("$group", new BsonDocument("_id", new BsonString("$category")))
+ )))
+ .append("cursor", new BsonDocument("batchSize", new BsonInt32(100)));
+
+ CommandMessage message = new CommandMessage(
+ NAMESPACE.getDatabaseName(), complexCommand, NoOpFieldNameValidator.INSTANCE,
+ ReadPreference.primary(),
+ MessageSettings.builder().maxWireVersion(LATEST_WIRE_VERSION).build(),
+ true, EmptyMessageSequences.INSTANCE, ClusterConnectionMode.MULTIPLE, null
+ );
+
+ try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) {
+ message.encode(output, new OperationContext(
+ IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
+ new TimeoutContext(TimeoutSettings.DEFAULT), null
+ ));
+
+ try (ByteBufBsonDocument commandDoc = message.getCommandDocument(output)) {
+ assertEquals("coll", commandDoc.getString("aggregate").getValue());
+ BsonArray pipeline = commandDoc.getArray("pipeline");
+ assertEquals(2, pipeline.size());
+ BsonDocument cursor = commandDoc.getDocument("cursor");
+ assertEquals(100, cursor.getInt32("batchSize").getValue());
+ }
+ }
+ }
+
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
index 3aff244ea1e..bd587464c23 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
@@ -58,6 +58,8 @@
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -254,6 +256,54 @@ public void serverHeartbeatFailed(final ServerHeartbeatFailedEvent event) {
assertEquals(expectedEvents, events);
}
+ @Test
+ void closeDuringConnectionShouldNotLeakBuffers() throws Exception {
+ CountDownLatch connectionStarted = new CountDownLatch(1);
+ CountDownLatch proceedWithOpen = new CountDownLatch(1);
+
+ InternalConnection mockConnection = mock(InternalConnection.class);
+ doAnswer(invocation -> {
+ connectionStarted.countDown();
+ assertTrue(proceedWithOpen.await(5, TimeUnit.SECONDS));
+ return null;
+ }).when(mockConnection).open(any());
+
+ when(mockConnection.getInitialServerDescription())
+ .thenReturn(createDefaultServerDescription());
+
+ InternalConnectionFactory factory = createConnectionFactory(mockConnection);
+
+ monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class));
+
+ // Wait for connection to start opening
+ assertTrue(connectionStarted.await(5, TimeUnit.SECONDS));
+
+ // Close monitor while connection is opening
+ monitor.close();
+
+ // Allow connection to complete
+ proceedWithOpen.countDown();
+
+ // Verify no leaks by checking connection was properly closed
+ monitor.getServerMonitor().join(5000);
+ assertFalse(monitor.getServerMonitor().isAlive());
+ verify(mockConnection, timeout(500)).close();
+ }
+
+ @Test
+ void heartbeatWithNullConnectionDescriptionShouldNotCrash() throws Exception {
+ InternalConnection mockConnection = mock(InternalConnection.class);
+ when(mockConnection.getDescription()).thenReturn(null);
+ when(mockConnection.getInitialServerDescription())
+ .thenReturn(createDefaultServerDescription());
+ when(mockConnection.isClosed()).thenReturn(false);
+
+ InternalConnectionFactory factory = createConnectionFactory(mockConnection);
+ monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class));
+
+ // Monitor should handle null description gracefully
+ verify(mockConnection, timeout(500).atLeast(1)).open(any());
+ }
private InternalConnectionFactory createConnectionFactory(final InternalConnection connection) {
InternalConnectionFactory factory = mock(InternalConnectionFactory.class);
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
index e6f6afb02e0..8e7a7b9d78d 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
@@ -64,8 +64,10 @@ class LoggingCommandEventSenderSpecification extends Specification {
isDebugEnabled() >> debugLoggingEnabled
}
def operationContext = OPERATION_CONTEXT
+ def commandMessageDocument = message.getCommandDocument(bsonOutput)
+
def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener,
- operationContext, message, message.getCommandDocument(bsonOutput),
+ operationContext, message, commandMessageDocument,
new StructuredLogger(logger), LoggerSettings.builder().build())
when:
@@ -87,6 +89,9 @@ class LoggingCommandEventSenderSpecification extends Specification {
database, commandDocument.getFirstKey(), 1, failureException)
])
+ cleanup:
+ commandMessageDocument?.close()
+
where:
debugLoggingEnabled << [true, false]
}
@@ -110,8 +115,10 @@ class LoggingCommandEventSenderSpecification extends Specification {
isDebugEnabled() >> true
}
def operationContext = OPERATION_CONTEXT
+ def commandMessageDocument = message.getCommandDocument(bsonOutput)
+
def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener,
- operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger),
+ operationContext, message, commandMessageDocument, new StructuredLogger(logger),
LoggerSettings.builder().build())
when:
sender.sendStartedEvent()
@@ -146,6 +153,9 @@ class LoggingCommandEventSenderSpecification extends Specification {
"request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}.")
}, failureException)
+ cleanup:
+ commandMessageDocument?.close()
+
where:
commandListener << [null, Stub(CommandListener)]
}
@@ -167,6 +177,7 @@ class LoggingCommandEventSenderSpecification extends Specification {
isDebugEnabled() >> true
}
def operationContext = OPERATION_CONTEXT
+ def commandMessageDocument = message.getCommandDocument(bsonOutput)
- def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, null, operationContext,
- message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger), LoggerSettings.builder().build())
+ def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, null, operationContext,
+ message, commandMessageDocument, new StructuredLogger(logger), LoggerSettings.builder().build())
@@ -182,6 +193,9 @@ class LoggingCommandEventSenderSpecification extends Specification {
"request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. " +
"Command: {\"fake\": {\"\$binary\": {\"base64\": \"${'A' * 967} ..."
}
+
+ cleanup:
+ commandMessageDocument?.close()
}
def 'should log redacted command with ellipses'() {
@@ -201,8 +215,9 @@ class LoggingCommandEventSenderSpecification extends Specification {
isDebugEnabled() >> true
}
def operationContext = OPERATION_CONTEXT
+ def commandMessageDocument = message.getCommandDocument(bsonOutput)
def sender = new LoggingCommandEventSender(['createUser'] as Set, [] as Set, connectionDescription, null,
- operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger),
+ operationContext, message, commandMessageDocument, new StructuredLogger(logger),
LoggerSettings.builder().build())
when:
@@ -215,5 +230,8 @@ class LoggingCommandEventSenderSpecification extends Specification {
"${connectionDescription.connectionId.serverValue} to 127.0.0.1:27017. The " +
"request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. Command: {}"
}
+
+ cleanup:
+ commandMessageDocument?.close()
}
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy
deleted file mode 100644
index 19bfa994200..00000000000
--- a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy
+++ /dev/null
@@ -1,229 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.internal.session
-
-import com.mongodb.ServerAddress
-import com.mongodb.connection.ClusterDescription
-import com.mongodb.connection.ClusterSettings
-import com.mongodb.connection.ServerDescription
-import com.mongodb.connection.ServerSettings
-import com.mongodb.internal.connection.Cluster
-import com.mongodb.internal.connection.Connection
-import com.mongodb.internal.connection.Server
-import com.mongodb.internal.connection.ServerTuple
-import com.mongodb.internal.validator.NoOpFieldNameValidator
-import org.bson.BsonArray
-import org.bson.BsonBinarySubType
-import org.bson.BsonDocument
-import org.bson.codecs.BsonDocumentCodec
-import spock.lang.Specification
-
-import static com.mongodb.ClusterFixture.OPERATION_CONTEXT
-import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS
-import static com.mongodb.ClusterFixture.getServerApi
-import static com.mongodb.ReadPreference.primaryPreferred
-import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE
-import static com.mongodb.connection.ClusterType.REPLICA_SET
-import static com.mongodb.connection.ServerConnectionState.CONNECTED
-import static com.mongodb.connection.ServerConnectionState.CONNECTING
-import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY
-import static com.mongodb.connection.ServerType.UNKNOWN
-import static java.util.concurrent.TimeUnit.MINUTES
-
-class ServerSessionPoolSpecification extends Specification {
-
- def connectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET,
- [
- ServerDescription.builder().ok(true)
- .state(CONNECTED)
- .address(new ServerAddress())
- .type(REPLICA_SET_PRIMARY)
- .logicalSessionTimeoutMinutes(30)
- .build()
- ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build())
-
- def unconnectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET,
- [
- ServerDescription.builder().ok(true)
- .state(CONNECTING)
- .address(new ServerAddress())
- .type(UNKNOWN)
- .logicalSessionTimeoutMinutes(null)
- .build()
- ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build())
-
- def 'should get session'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi())
-
- when:
- def session = pool.get()
-
- then:
- session != null
- }
-
- def 'should throw IllegalStateException if pool is closed'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi())
- pool.close()
-
- when:
- pool.get()
-
- then:
- thrown(IllegalStateException)
- }
-
- def 'should pool session'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi())
- def session = pool.get()
-
- when:
- pool.release(session)
- def pooledSession = pool.get()
-
- then:
- session == pooledSession
- }
-
- def 'should prune sessions when getting'() {
- given:
- def cluster = Mock(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def clock = Stub(ServerSessionPool.Clock) {
- millis() >>> [0, MINUTES.toMillis(29) + 1,
- ]
- }
- def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock)
- def sessionOne = pool.get()
-
- when:
- pool.release(sessionOne)
-
- then:
- !sessionOne.closed
-
- when:
- def sessionTwo = pool.get()
-
- then:
- sessionTwo != sessionOne
- sessionOne.closed
- 0 * cluster.selectServer(_)
- }
-
- def 'should not prune session when timeout is null'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> unconnectedDescription
- }
- def clock = Stub(ServerSessionPool.Clock) {
- millis() >>> [0, 0, 0]
- }
- def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock)
- def session = pool.get()
-
- when:
- pool.release(session)
- def newSession = pool.get()
-
- then:
- session == newSession
- }
-
- def 'should initialize session'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def clock = Stub(ServerSessionPool.Clock) {
- millis() >> 42
- }
- def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock)
-
- when:
- def session = pool.get() as ServerSessionPool.ServerSessionImpl
-
- then:
- session.lastUsedAtMillis == 42
- session.transactionNumber == 0
- def uuid = session.identifier.getBinary('id')
- uuid != null
- uuid.type == BsonBinarySubType.UUID_STANDARD.value
- uuid.data.length == 16
- }
-
- def 'should advance transaction number'() {
- given:
- def cluster = Stub(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def clock = Stub(ServerSessionPool.Clock) {
- millis() >> 42
- }
- def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock)
-
- when:
- def session = pool.get() as ServerSessionPool.ServerSessionImpl
-
- then:
- session.transactionNumber == 0
- session.advanceTransactionNumber() == 1
- session.transactionNumber == 1
- }
-
- def 'should end pooled sessions when pool is closed'() {
- given:
- def connection = Mock(Connection)
- def server = Stub(Server) {
- getConnection(_) >> connection
- }
- def cluster = Mock(Cluster) {
- getCurrentDescription() >> connectedDescription
- }
- def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi())
- def sessions = []
- 10.times { sessions.add(pool.get()) }
-
- for (def cur : sessions) {
- pool.release(cur)
- }
-
- when:
- pool.close()
-
- then:
- 1 * cluster.selectServer(_, _) >> new ServerTuple(server, connectedDescription.serverDescriptions[0])
- 1 * connection.command('admin',
- new BsonDocument('endSessions', new BsonArray(sessions*.getIdentifier())),
- { it instanceof NoOpFieldNameValidator }, primaryPreferred(),
- { it instanceof BsonDocumentCodec }, _) >> new BsonDocument()
- 1 * connection.release()
- }
-}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java
new file mode 100644
index 00000000000..0322d0f4063
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java
@@ -0,0 +1,320 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.session;
+
+import com.mongodb.MongoException;
+import com.mongodb.ServerAddress;
+import com.mongodb.connection.ClusterDescription;
+import com.mongodb.connection.ClusterSettings;
+import com.mongodb.connection.ServerDescription;
+import com.mongodb.connection.ServerSettings;
+import com.mongodb.internal.connection.Cluster;
+import com.mongodb.internal.connection.Connection;
+import com.mongodb.internal.connection.Server;
+import com.mongodb.internal.connection.ServerTuple;
+import com.mongodb.session.ServerSession;
+import org.bson.BsonArray;
+import org.bson.BsonDocument;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentMatcher;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.mongodb.ClusterFixture.OPERATION_CONTEXT;
+import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS;
+import static com.mongodb.ClusterFixture.getServerApi;
+import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE;
+import static com.mongodb.connection.ClusterType.REPLICA_SET;
+import static com.mongodb.connection.ServerConnectionState.CONNECTED;
+import static com.mongodb.connection.ServerConnectionState.CONNECTING;
+import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY;
+import static com.mongodb.connection.ServerType.UNKNOWN;
+import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.MINUTES;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@DisplayName("ServerSessionPool")
+@ExtendWith(MockitoExtension.class)
+class ServerSessionPoolTest {
+
+ private ClusterDescription connectedDescription;
+ private ClusterDescription unconnectedDescription;
+
+ @Mock
+ private Cluster clusterMock;
+
+ @BeforeEach
+ void setUp() {
+ connectedDescription = new ClusterDescription(
+ MULTIPLE,
+ REPLICA_SET,
+ singletonList(
+ ServerDescription.builder()
+ .ok(true)
+ .state(CONNECTED)
+ .address(new ServerAddress())
+ .type(REPLICA_SET_PRIMARY)
+ .logicalSessionTimeoutMinutes(30)
+ .build()
+ ),
+ ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(),
+ ServerSettings.builder().build()
+ );
+
+ unconnectedDescription = new ClusterDescription(
+ MULTIPLE,
+ REPLICA_SET,
+ singletonList(
+ ServerDescription.builder()
+ .ok(true)
+ .state(CONNECTING)
+ .address(new ServerAddress())
+ .type(UNKNOWN)
+ .logicalSessionTimeoutMinutes(null)
+ .build()
+ ),
+ ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(),
+ ServerSettings.builder().build()
+ );
+ }
+
+ @Test
+ @DisplayName("should get session from pool")
+ void testGetSession() {
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi());
+
+ ServerSession session = pool.get();
+
+ assertNotNull(session);
+ }
+
+ @Test
+ @DisplayName("should throw IllegalStateException when pool is closed")
+ void testThrowExceptionIfPoolClosed() {
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi());
+ pool.close();
+
+ assertThrows(IllegalStateException.class, pool::get);
+ }
+
+ @Test
+ @DisplayName("should reuse released session from pool")
+ void testPoolSession() {
+ when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription);
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi());
+
+ ServerSession session = pool.get();
+ pool.release(session);
+ ServerSession pooledSession = pool.get();
+
+ assertEquals(session, pooledSession);
+ }
+
+ @Test
+ @DisplayName("should prune expired sessions when getting new session")
+ void testPruneSessionsWhenGetting() {
+ when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription);
+
+ ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class);
+ when(clock.millis()).thenReturn(0L, MINUTES.toMillis(29) + 1);
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock);
+ ServerSession sessionOne = pool.get();
+
+ pool.release(sessionOne);
+ assertFalse(sessionOne.isClosed());
+
+ ServerSession sessionTwo = pool.get();
+
+ assertNotEquals(sessionTwo, sessionOne);
+ assertTrue(sessionOne.isClosed());
+ }
+
+ @Test
+ @DisplayName("should not prune session when timeout is null")
+ void testNotPruneSessionWhenTimeoutIsNull() {
+ when(clusterMock.getCurrentDescription()).thenReturn(unconnectedDescription);
+
+ ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class);
+ when(clock.millis()).thenReturn(0L, 0L, 0L);
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock);
+ ServerSession session = pool.get();
+
+ pool.release(session);
+ ServerSession newSession = pool.get();
+
+ assertEquals(session, newSession);
+ }
+
+ @Test
+ @DisplayName("should initialize session with correct properties")
+ void testInitializeSession() {
+ ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class);
+ when(clock.millis()).thenReturn(42L);
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock);
+ ServerSession session = pool.get();
+
+ ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session;
+ assertEquals(42L, sessionImpl.getLastUsedAtMillis());
+ assertEquals(0L, sessionImpl.getTransactionNumber());
+
+ BsonDocument identifier = sessionImpl.getIdentifier();
+ assertNotNull(identifier);
+ byte[] uuid = identifier.getBinary("id").getData();
+ assertNotNull(uuid);
+ assertEquals(16, uuid.length);
+ }
+
+ @Test
+ @DisplayName("should advance transaction number")
+ void testAdvanceTransactionNumber() {
+ ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class);
+ when(clock.millis()).thenReturn(42L);
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock);
+ ServerSession session = pool.get();
+
+ ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session;
+ assertEquals(0L, sessionImpl.getTransactionNumber());
+ assertEquals(1L, sessionImpl.advanceTransactionNumber());
+ assertEquals(1L, sessionImpl.getTransactionNumber());
+ }
+
+ @Test
+ @DisplayName("should end pooled sessions when pool is closed")
+ void testEndPooledSessionsWhenPoolClosed() {
+ Connection connection = mock(Connection.class);
+ Server server = mock(Server.class);
+ when(server.getConnection(any())).thenReturn(connection);
+
+ when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription);
+ when(clusterMock.selectServer(any(), any()))
+ .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0)));
+
+ when(connection.command(
+ any(String.class),
+ any(BsonDocument.class),
+ any(),
+ any(),
+ any(),
+ any()
+ )).thenReturn(new BsonDocument());
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi());
+ List<ServerSession> sessions = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ sessions.add(pool.get());
+ }
+
+ for (ServerSession session : sessions) {
+ pool.release(session);
+ }
+
+ pool.close();
+
+ verify(clusterMock, times(1)).selectServer(any(), any());
+ verify(connection, times(1)).command(
+ any(String.class),
+ argThat(endSessionsDocMatcher(sessions)),
+ any(),
+ any(),
+ any(),
+ any()
+ );
+ verify(connection, times(1)).release();
+ }
+
+ @Test
+ @DisplayName("should handle MongoException during endSessions without leaking resources")
+ void testHandleMongoExceptionDuringEndSessionsWithoutLeakingResources() {
+ Connection connection = mock(Connection.class);
+ Server server = mock(Server.class);
+ when(server.getConnection(any())).thenReturn(connection);
+
+ when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription);
+ when(clusterMock.selectServer(any(), any()))
+ .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0)));
+
+ when(connection.command(
+ any(String.class),
+ any(BsonDocument.class),
+ any(),
+ any(),
+ any(),
+ any()
+ )).thenThrow(new MongoException("Simulated error"));
+
+ ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi());
+ List<ServerSession> sessions = new ArrayList<>();
+ for (int i = 0; i < 5; i++) {
+ sessions.add(pool.get());
+ }
+
+ for (ServerSession session : sessions) {
+ pool.release(session);
+ }
+
+ // Should not throw - exception is handled internally
+ pool.close();
+
+ verify(clusterMock, times(1)).selectServer(any(), any());
+ verify(connection, times(1)).release();
+ }
+
+ /**
+ * Matcher to verify the endSessions document contains the correct session identifiers.
+ */
+ private ArgumentMatcher<BsonDocument> endSessionsDocMatcher(List<ServerSession> sessions) {
+ return doc -> {
+ if (!doc.containsKey("endSessions")) {
+ return false;
+ }
+ BsonArray endSessionsArray = doc.getArray("endSessions");
+ if (endSessionsArray.size() != sessions.size()) {
+ return false;
+ }
+ for (int i = 0; i < sessions.size(); i++) {
+ ServerSession session = sessions.get(i);
+ BsonDocument sessionIdentifier = session.getIdentifier();
+ BsonDocument arrayElement = endSessionsArray.get(i).asDocument();
+ if (!sessionIdentifier.equals(arrayElement)) {
+ return false;
+ }
+ }
+ return true;
+ };
+ }
+}
diff --git a/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt b/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt
index 6c53a1faf47..cbe308eece0 100644
--- a/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt
+++ b/driver-kotlin-coroutine/src/main/kotlin/com/mongodb/kotlin/client/coroutine/ClientSession.kt
@@ -19,6 +19,7 @@ import com.mongodb.ClientSessionOptions
import com.mongodb.ServerAddress
import com.mongodb.TransactionOptions
import com.mongodb.internal.TimeoutContext
+import com.mongodb.internal.observability.micrometer.TransactionSpan
import com.mongodb.reactivestreams.client.ClientSession as reactiveClientSession
import com.mongodb.session.ClientSession as jClientSession
import com.mongodb.session.ServerSession
@@ -58,6 +59,9 @@ public class ClientSession(public val wrapped: reactiveClientSession) : jClientS
*/
public fun notifyOperationInitiated(operation: Any): Unit = wrapped.notifyOperationInitiated(operation)
+ /** Get the transaction span (if started). */
+ public fun getTransactionSpan(): TransactionSpan? = wrapped.transactionSpan
+
/**
* Get the server address of the pinned mongos on this session. For internal use only.
*
diff --git a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
index dd761234df9..75a60ca382f 100644
--- a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
+++ b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
@@ -39,9 +39,9 @@ class DBDecoderAdapter implements Decoder<DBObject> {
@Override
public DBObject decode(final BsonReader reader, final DecoderContext decoderContext) {
- ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider);
- BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput);
- try {
+
+ try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider);
+ BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput)) {
binaryWriter.pipe(reader);
BufferExposingByteArrayOutputStream byteArrayOutputStream =
new BufferExposingByteArrayOutputStream(binaryWriter.getBsonOutput().getSize());
@@ -50,9 +50,6 @@ public DBObject decode(final BsonReader reader, final DecoderContext decoderCont
} catch (IOException e) {
// impossible with a byte array output stream
throw new MongoInternalException("An unlikely IOException thrown.", e);
- } finally {
- binaryWriter.close();
- bsonOutput.close();
}
}
diff --git a/driver-reactive-streams/build.gradle.kts b/driver-reactive-streams/build.gradle.kts
index dab192e2583..b55dd95d683 100644
--- a/driver-reactive-streams/build.gradle.kts
+++ b/driver-reactive-streams/build.gradle.kts
@@ -15,6 +15,7 @@
*/
import ProjectExtensions.configureJarManifest
import ProjectExtensions.configureMavenPublication
+import project.DEFAULT_JAVA_VERSION
plugins {
id("project.java")
@@ -36,6 +37,9 @@ dependencies {
implementation(libs.project.reactor.core)
compileOnly(project(path = ":mongodb-crypt", configuration = "default"))
+ optionalImplementation(platform(libs.micrometer.observation.bom))
+ optionalImplementation(libs.micrometer.observation)
+
testImplementation(libs.project.reactor.test)
testImplementation(project(path = ":driver-sync", configuration = "default"))
testImplementation(project(path = ":bson", configuration = "testArtifacts"))
@@ -45,11 +49,20 @@ dependencies {
// Reactive Streams TCK testing
testImplementation(libs.reactive.streams.tck)
- // Tracing
+ // Tracing testing
testImplementation(platform(libs.micrometer.tracing.integration.test.bom))
testImplementation(libs.micrometer.tracing.integration.test) { exclude(group = "org.junit.jupiter") }
}
+tasks.withType<Test> {
+ // Needed for MicrometerProseTest to set env variable programmatically (calls
+ // `field.setAccessible(true)`)
+ val testJavaVersion: Int = findProperty("javaVersion")?.toString()?.toInt() ?: DEFAULT_JAVA_VERSION
+ if (testJavaVersion >= DEFAULT_JAVA_VERSION) {
+ jvmArgs("--add-opens=java.base/java.util=ALL-UNNAMED")
+ }
+}
+
configureMavenPublication {
pom {
name.set("The MongoDB Reactive Streams Driver")
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java
index 3d9354e9ae9..fe58864fad0 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/ClientSession.java
@@ -18,6 +18,8 @@
package com.mongodb.reactivestreams.client;
import com.mongodb.TransactionOptions;
+import com.mongodb.internal.observability.micrometer.TransactionSpan;
+import com.mongodb.lang.Nullable;
import org.reactivestreams.Publisher;
/**
@@ -94,4 +96,13 @@ public interface ClientSession extends com.mongodb.session.ClientSession {
* @mongodb.server.release 4.0
*/
Publisher<Void> abortTransaction();
+
+ /**
+ * Get the transaction span (if started).
+ *
+ * @return the transaction span
+ * @since 5.7
+ */
+ @Nullable
+ TransactionSpan getTransactionSpan();
}
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java
index 30714a6a576..b5e94c02975 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionHelper.java
@@ -18,6 +18,7 @@
import com.mongodb.ClientSessionOptions;
import com.mongodb.TransactionOptions;
+import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.internal.session.ServerSessionPool;
import com.mongodb.lang.Nullable;
import com.mongodb.reactivestreams.client.ClientSession;
@@ -31,10 +32,13 @@
public class ClientSessionHelper {
private final MongoClientImpl mongoClient;
private final ServerSessionPool serverSessionPool;
+ private final TracingManager tracingManager;
- public ClientSessionHelper(final MongoClientImpl mongoClient, final ServerSessionPool serverSessionPool) {
+ public ClientSessionHelper(final MongoClientImpl mongoClient, final ServerSessionPool serverSessionPool,
+ final TracingManager tracingManager) {
this.mongoClient = mongoClient;
this.serverSessionPool = serverSessionPool;
+ this.tracingManager = tracingManager;
}
Mono<ClientSession> withClientSession(@Nullable final ClientSession clientSessionFromOperation, final OperationExecutor executor) {
@@ -62,6 +66,6 @@ ClientSession createClientSession(final ClientSessionOptions options, final Oper
.readPreference(mongoClient.getSettings().getReadPreference())
.build()))
.build();
- return new ClientSessionPublisherImpl(serverSessionPool, mongoClient, mergedOptions, executor);
+ return new ClientSessionPublisherImpl(serverSessionPool, mongoClient, mergedOptions, executor, tracingManager);
}
}
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
index 5cf0ea103bd..6d38a0731ab 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
@@ -24,6 +24,8 @@
import com.mongodb.TransactionOptions;
import com.mongodb.WriteConcern;
import com.mongodb.internal.TimeoutContext;
+import com.mongodb.internal.observability.micrometer.TracingManager;
+import com.mongodb.internal.observability.micrometer.TransactionSpan;
import com.mongodb.internal.operation.AbortTransactionOperation;
import com.mongodb.internal.operation.CommitTransactionOperation;
import com.mongodb.internal.operation.ReadOperation;
@@ -37,6 +39,8 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
+import java.util.concurrent.atomic.AtomicBoolean;
+
import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL;
import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL;
import static com.mongodb.assertions.Assertions.assertNotNull;
@@ -46,19 +50,23 @@
final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements ClientSession {
+ private final AtomicBoolean closed = new AtomicBoolean();
private final MongoClientImpl mongoClient;
private final OperationExecutor executor;
+ private final TracingManager tracingManager;
private TransactionState transactionState = TransactionState.NONE;
private boolean messageSentInCurrentTransaction;
private boolean commitInProgress;
private TransactionOptions transactionOptions;
-
+ @Nullable
+ private TransactionSpan transactionSpan;
ClientSessionPublisherImpl(final ServerSessionPool serverSessionPool, final MongoClientImpl mongoClient,
- final ClientSessionOptions options, final OperationExecutor executor) {
+ final ClientSessionOptions options, final OperationExecutor executor, final TracingManager tracingManager) {
super(serverSessionPool, mongoClient, options);
this.executor = executor;
this.mongoClient = mongoClient;
+ this.tracingManager = tracingManager;
}
@Override
@@ -128,6 +136,10 @@ public void startTransaction(final TransactionOptions transactionOptions) {
if (!writeConcern.isAcknowledged()) {
throw new MongoClientException("Transactions do not support unacknowledged write concern");
}
+
+ if (tracingManager.isEnabled()) {
+ transactionSpan = new TransactionSpan(tracingManager);
+ }
clearTransactionContext();
setTimeoutContext(timeoutContext);
}
@@ -152,6 +164,9 @@ public Publisher<Void> commitTransaction() {
}
if (!messageSentInCurrentTransaction) {
cleanupTransaction(TransactionState.COMMITTED);
+ if (transactionSpan != null) {
+ transactionSpan.finalizeTransactionSpan(TransactionState.COMMITTED.name());
+ }
return Mono.create(MonoSink::success);
} else {
ReadConcern readConcern = transactionOptions.getReadConcern();
@@ -171,7 +186,17 @@ public Publisher commitTransaction() {
commitInProgress = false;
transactionState = TransactionState.COMMITTED;
})
- .doOnError(MongoException.class, this::clearTransactionContextOnError);
+ .doOnError(MongoException.class, e -> {
+ clearTransactionContextOnError(e);
+ if (transactionSpan != null) {
+ transactionSpan.handleTransactionSpanError(e);
+ }
+ })
+ .doOnSuccess(v -> {
+ if (transactionSpan != null) {
+ transactionSpan.finalizeTransactionSpan(TransactionState.COMMITTED.name());
+ }
+ });
}
});
}
@@ -191,6 +216,9 @@ public Publisher<Void> abortTransaction() {
}
if (!messageSentInCurrentTransaction) {
cleanupTransaction(TransactionState.ABORTED);
+ if (transactionSpan != null) {
+ transactionSpan.finalizeTransactionSpan(TransactionState.ABORTED.name());
+ }
return Mono.create(MonoSink::success);
} else {
ReadConcern readConcern = transactionOptions.getReadConcern();
@@ -208,6 +236,9 @@ public Publisher abortTransaction() {
.doOnTerminate(() -> {
clearTransactionContext();
cleanupTransaction(TransactionState.ABORTED);
+ if (transactionSpan != null) {
+ transactionSpan.finalizeTransactionSpan(TransactionState.ABORTED.name());
+ }
});
}
});
@@ -219,12 +250,26 @@ private void clearTransactionContextOnError(final MongoException e) {
}
}
+ @Override
+ @Nullable
+ public TransactionSpan getTransactionSpan() {
+ return transactionSpan;
+ }
+
@Override
public void close() {
- if (transactionState == TransactionState.IN) {
- Mono.from(abortTransaction()).doFinally(it -> super.close()).subscribe();
- } else {
- super.close();
+ if (closed.compareAndSet(false, true)) {
+ if (transactionState == TransactionState.IN) {
+ Mono.from(abortTransaction())
+ .doFinally(it -> {
+ clearTransactionContext();
+ super.close();
+ })
+ .subscribe();
+ } else {
+ clearTransactionContext();
+ super.close();
+ }
}
}
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java
index 07a17badcd7..8fda2e9294d 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MongoClientImpl.java
@@ -33,6 +33,7 @@
import com.mongodb.internal.connection.Cluster;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
+import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.internal.session.ServerSessionPool;
import com.mongodb.lang.Nullable;
import com.mongodb.reactivestreams.client.ChangeStreamPublisher;
@@ -88,9 +89,10 @@ private MongoClientImpl(final MongoClientSettings settings, final MongoDriverInf
notNull("settings", settings);
notNull("cluster", cluster);
+ TracingManager tracingManager = new TracingManager(settings.getObservabilitySettings());
TimeoutSettings timeoutSettings = TimeoutSettings.create(settings);
ServerSessionPool serverSessionPool = new ServerSessionPool(cluster, timeoutSettings, settings.getServerApi());
- ClientSessionHelper clientSessionHelper = new ClientSessionHelper(this, serverSessionPool);
+ ClientSessionHelper clientSessionHelper = new ClientSessionHelper(this, serverSessionPool, tracingManager);
AutoEncryptionSettings autoEncryptSettings = settings.getAutoEncryptionSettings();
Crypt crypt = autoEncryptSettings != null ? Crypts.createCrypt(settings, autoEncryptSettings) : null;
@@ -100,7 +102,8 @@ private MongoClientImpl(final MongoClientSettings settings, final MongoDriverInf
+ ReactiveContextProvider.class.getName() + " when using the Reactive Streams driver");
}
OperationExecutor operationExecutor = executor != null ? executor
- : new OperationExecutorImpl(this, clientSessionHelper, timeoutSettings, (ReactiveContextProvider) contextProvider);
+ : new OperationExecutorImpl(this, clientSessionHelper, timeoutSettings, (ReactiveContextProvider) contextProvider,
+ tracingManager);
MongoOperationPublisher<Document> mongoOperationPublisher = new MongoOperationPublisher<>(Document.class,
withUuidRepresentation(settings.getCodecRegistry(),
settings.getUuidRepresentation()),
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java
index ef18c2c6b1f..62a4431cc9a 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java
@@ -31,10 +31,11 @@
import com.mongodb.internal.binding.AsyncReadWriteBinding;
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext;
+import com.mongodb.internal.observability.micrometer.Span;
+import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.internal.operation.OperationHelper;
import com.mongodb.internal.operation.ReadOperation;
import com.mongodb.internal.operation.WriteOperation;
-import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.lang.Nullable;
import com.mongodb.reactivestreams.client.ClientSession;
import com.mongodb.reactivestreams.client.ReactiveContextProvider;
@@ -63,13 +64,16 @@ public class OperationExecutorImpl implements OperationExecutor {
@Nullable
private final ReactiveContextProvider contextProvider;
private final TimeoutSettings timeoutSettings;
+ private final TracingManager tracingManager;
OperationExecutorImpl(final MongoClientImpl mongoClient, final ClientSessionHelper clientSessionHelper,
- final TimeoutSettings timeoutSettings, @Nullable final ReactiveContextProvider contextProvider) {
+ final TimeoutSettings timeoutSettings, @Nullable final ReactiveContextProvider contextProvider,
+ final TracingManager tracingManager) {
this.mongoClient = mongoClient;
this.clientSessionHelper = clientSessionHelper;
this.timeoutSettings = timeoutSettings;
this.contextProvider = contextProvider;
+ this.tracingManager = tracingManager;
}
@Override
@@ -93,22 +97,37 @@ public Mono execute(final ReadOperation, T> operation, final ReadPrefer
OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName())
.withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession,
isImplicitSession(session), readConcern));
+ Span span = tracingManager.createOperationSpan(actualClientSession.getTransactionSpan(),
+ operationContext, operation.getCommandName(), operation.getNamespace());
if (session != null && session.hasActiveTransaction() && !binding.getReadPreference().equals(primary())) {
binding.release();
- return Mono.error(new MongoClientException("Read preference in a transaction must be primary"));
+ MongoClientException error = new MongoClientException("Read preference in a transaction must be primary");
+ if (span != null) {
+ span.error(error);
+ span.end();
+ }
+ return Mono.error(error);
} else {
return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> {
try {
binding.release();
} finally {
+ if (t != null) {
+ Throwable exceptionToHandle = t instanceof MongoException
+ ? OperationHelper.unwrap((MongoException) t) : t;
+ labelException(session, exceptionToHandle);
+ unpinServerAddressOnTransientTransactionError(session, exceptionToHandle);
+ if (span != null) {
+ span.error(t);
+ }
+ }
+ if (span != null) {
+ span.end();
+ }
sinkToCallback(sink).onResult(result, t);
}
- })).doOnError((t) -> {
- Throwable exceptionToHandle = t instanceof MongoException ? OperationHelper.unwrap((MongoException) t) : t;
- labelException(session, exceptionToHandle);
- unpinServerAddressOnTransientTransactionError(session, exceptionToHandle);
- });
+ }));
}
}).subscribe(subscriber)
);
@@ -133,18 +152,28 @@ public <T> Mono<T> execute(final WriteOperation<T> operation, final ReadConcern
OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName())
.withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession,
isImplicitSession(session), readConcern));
+ Span span = tracingManager.createOperationSpan(actualClientSession.getTransactionSpan(),
+ operationContext, operation.getCommandName(), operation.getNamespace());
return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> {
try {
binding.release();
} finally {
+ if (t != null) {
+ Throwable exceptionToHandle = t instanceof MongoException
+ ? OperationHelper.unwrap((MongoException) t) : t;
+ labelException(session, exceptionToHandle);
+ unpinServerAddressOnTransientTransactionError(session, exceptionToHandle);
+ if (span != null) {
+ span.error(t);
+ }
+ }
+ if (span != null) {
+ span.end();
+ }
sinkToCallback(sink).onResult(result, t);
}
- })).doOnError((t) -> {
- Throwable exceptionToHandle = t instanceof MongoException ? OperationHelper.unwrap((MongoException) t) : t;
- labelException(session, exceptionToHandle);
- unpinServerAddressOnTransientTransactionError(session, exceptionToHandle);
- });
+ }));
}
).subscribe(subscriber)
);
@@ -155,7 +184,7 @@ public OperationExecutor withTimeoutSettings(final TimeoutSettings newTimeoutSet
if (Objects.equals(timeoutSettings, newTimeoutSettings)) {
return this;
}
- return new OperationExecutorImpl(mongoClient, clientSessionHelper, newTimeoutSettings, contextProvider);
+ return new OperationExecutorImpl(mongoClient, clientSessionHelper, newTimeoutSettings, contextProvider, tracingManager);
}
@Override
@@ -214,7 +243,7 @@ private OperationContext getOperationContext(final RequestContext requestContext
requestContext,
new ReadConcernAwareNoOpSessionContext(readConcern),
createTimeoutContext(session, timeoutSettings),
- TracingManager.NO_OP,
+ tracingManager,
mongoClient.getSettings().getServerApi(),
commandName);
}
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java
index bc4da3026a9..cefdf7184d8 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java
@@ -55,8 +55,14 @@ public static MongoCollection collectionWithTimeout(final MongoCollection
public static <T> Mono<MongoCollection<T>> collectionWithTimeoutMono(final MongoCollection<T> collection,
@Nullable final Timeout timeout) {
+ return collectionWithTimeoutMono(collection, timeout, DEFAULT_TIMEOUT_MESSAGE);
+ }
+
+ public static <T> Mono<MongoCollection<T>> collectionWithTimeoutMono(final MongoCollection<T> collection,
+ @Nullable final Timeout timeout,
+ final String message) {
try {
- return Mono.just(collectionWithTimeout(collection, timeout));
+ return Mono.just(collectionWithTimeout(collection, timeout, message));
} catch (MongoOperationTimeoutException e) {
return Mono.error(e);
}
@@ -64,9 +70,14 @@ public static Mono> collectionWithTimeoutMono(final Mongo
public static <T> Mono<MongoCollection<T>> collectionWithTimeoutDeferred(final MongoCollection<T> collection,
@Nullable final Timeout timeout) {
- return Mono.defer(() -> collectionWithTimeoutMono(collection, timeout));
+ return collectionWithTimeoutDeferred(collection, timeout, DEFAULT_TIMEOUT_MESSAGE);
}
+ public static <T> Mono<MongoCollection<T>> collectionWithTimeoutDeferred(final MongoCollection<T> collection,
+ @Nullable final Timeout timeout,
+ final String message) {
+ return Mono.defer(() -> collectionWithTimeoutMono(collection, timeout, message));
+ }
public static MongoDatabase databaseWithTimeout(final MongoDatabase database,
@Nullable final Timeout timeout) {
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java
index 7d9a46cdf3f..50586e92102 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSUploadPublisherImpl.java
@@ -54,7 +54,8 @@
*/
public final class GridFSUploadPublisherImpl implements GridFSUploadPublisher {
- private static final String TIMEOUT_ERROR_MESSAGE = "Saving chunks exceeded the timeout limit.";
+ private static final String TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING = "Saving chunks exceeded the timeout limit.";
+ private static final String TIMEOUT_ERROR_MESSAGE_UPLOAD_CANCELLATION = "Upload cancellation exceeded the timeout limit.";
private static final Document PROJECTION = new Document("_id", 1);
private static final Document FILES_INDEX = new Document("filename", 1).append("uploadDate", 1);
private static final Document CHUNKS_INDEX = new Document("files_id", 1).append("n", 1);
@@ -226,8 +227,8 @@ private Mono createSaveChunksMono(final AtomicBoolean terminated, @Nullabl
.append("data", data);
Publisher<InsertOneResult> insertOnePublisher = clientSession == null
- ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument)
- : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE)
+ ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING).insertOne(chunkDocument)
+ : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE_CHUNKS_SAVING)
.insertOne(clientSession, chunkDocument);
return Mono.from(insertOnePublisher).thenReturn(data.length());
@@ -270,7 +271,8 @@ private Mono createSaveFileDataMono(final AtomicBoolean termina
}
private Mono createCancellationMono(final AtomicBoolean terminated, @Nullable final Timeout timeout) {
- Mono> chunksCollectionMono = collectionWithTimeoutDeferred(chunksCollection, timeout);
+ Mono> chunksCollectionMono = collectionWithTimeoutDeferred(chunksCollection, timeout,
+ TIMEOUT_ERROR_MESSAGE_UPLOAD_CANCELLATION);
if (terminated.compareAndSet(false, true)) {
if (clientSession != null) {
return chunksCollectionMono.flatMap(collection -> Mono.from(collection
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java
index b922ec20b71..90446953fc1 100644
--- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java
@@ -16,7 +16,6 @@
package com.mongodb.reactivestreams.client;
-import com.mongodb.ClusterFixture;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
import com.mongodb.MongoNamespace;
@@ -24,7 +23,6 @@
import com.mongodb.ReadPreference;
import com.mongodb.WriteConcern;
import com.mongodb.client.AbstractClientSideOperationsTimeoutProseTest;
-import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.changestream.FullDocument;
import com.mongodb.event.CommandFailedEvent;
import com.mongodb.event.CommandStartedEvent;
@@ -43,6 +41,7 @@
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
+import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.nio.ByteBuffer;
@@ -58,12 +57,16 @@
import static com.mongodb.ClusterFixture.TIMEOUT_DURATION;
import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet;
+import static com.mongodb.ClusterFixture.isStandalone;
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
import static com.mongodb.ClusterFixture.sleep;
+import static com.mongodb.assertions.Assertions.assertTrue;
import static java.util.Collections.singletonList;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
@@ -104,7 +107,6 @@ protected boolean isAsync() {
@Override
public void testGridFSUploadViaOpenUploadStreamTimeout() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
//given
collectionHelper.runAdminCommand("{"
@@ -113,12 +115,12 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() {
+ " data: {"
+ " failCommands: [\"insert\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 405)
+ + " blockTimeMS: " + 600
+ " }"
+ "}");
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 400, TimeUnit.MILLISECONDS))) {
+ .timeout(600, TimeUnit.MILLISECONDS))) {
MongoDatabase database = client.getDatabase(gridFsFileNamespace.getDatabaseName());
GridFSBucket gridFsBucket = createReaciveGridFsBucket(database, GRID_FS_BUCKET_NAME);
@@ -158,7 +160,6 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() {
@Override
public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, InterruptedException, TimeoutException {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
//given
CompletableFuture<Throwable> droppedErrorFuture = new CompletableFuture<>();
@@ -170,12 +171,12 @@ public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, I
+ " data: {"
+ " failCommands: [\"delete\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 405)
+ + " blockTimeMS: " + 405
+ " }"
+ "}");
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 400, TimeUnit.MILLISECONDS))) {
+ .timeout(400, TimeUnit.MILLISECONDS))) {
MongoDatabase database = client.getDatabase(gridFsFileNamespace.getDatabaseName());
GridFSBucket gridFsBucket = createReaciveGridFsBucket(database, GRID_FS_BUCKET_NAME);
@@ -198,12 +199,25 @@ public void testAbortingGridFsUploadStreamTimeout() throws ExecutionException, I
//then
Throwable droppedError = droppedErrorFuture.get(TIMEOUT_DURATION.toMillis(), TimeUnit.MILLISECONDS);
Throwable commandError = droppedError.getCause();
- assertInstanceOf(MongoOperationTimeoutException.class, commandError);
CommandFailedEvent deleteFailedEvent = commandListener.getCommandFailedEvent("delete");
assertNotNull(deleteFailedEvent);
- assertEquals(commandError, commandListener.getCommandFailedEvent("delete").getThrowable());
+ CommandStartedEvent deleteStartedEvent = commandListener.getCommandStartedEvent("delete");
+ assertTrue(deleteStartedEvent.getCommand().containsKey("maxTimeMS"), "Expected delete command to have maxTimeMS");
+ long deleteMaxTimeMS = deleteStartedEvent
+ .getCommand()
+ .get("maxTimeMS")
+ .asNumber()
+ .longValue();
+
+ assertTrue(deleteMaxTimeMS <= 420
+ // some leeway for timing variations, when compression is used it is often less than 300.
+ // Without it, it is more than 300.
+ && deleteMaxTimeMS >= 150,
+ "Expected maxTimeMS for delete command to be between 150ms and 420ms, " + "but was: " + deleteMaxTimeMS + "ms");
+ assertEquals(commandError, deleteFailedEvent.getThrowable());
+
// When subscription is cancelled, we should not receive any more events.
testSubscriber.assertNoTerminalEvent();
}
@@ -219,9 +233,8 @@ public void testTimeoutMSAppliesToFullResumeAttemptInNextCall() {
assumeTrue(isDiscoverableReplicaSet());
//given
- long rtt = ClusterFixture.getPrimaryRTT();
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 500, TimeUnit.MILLISECONDS))) {
+ .timeout(500, TimeUnit.MILLISECONDS))) {
MongoNamespace namespace = generateNamespace();
MongoCollection collection = client.getDatabase(namespace.getDatabaseName())
@@ -273,9 +286,8 @@ public void testTimeoutMSAppliedToInitialAggregate() {
assumeTrue(isDiscoverableReplicaSet());
//given
- long rtt = ClusterFixture.getPrimaryRTT();
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 200, TimeUnit.MILLISECONDS))) {
+ .timeout(200, TimeUnit.MILLISECONDS))) {
MongoNamespace namespace = generateNamespace();
MongoCollection collection = client.getDatabase(namespace.getDatabaseName())
@@ -290,7 +302,7 @@ public void testTimeoutMSAppliedToInitialAggregate() {
+ " data: {"
+ " failCommands: [\"aggregate\" ],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 201)
+ + " blockTimeMS: " + 201
+ " }"
+ "}");
@@ -321,13 +333,10 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsNotSet() {
//given
BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0);
- collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions());
sleep(2000);
-
- long rtt = ClusterFixture.getPrimaryRTT();
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 300, TimeUnit.MILLISECONDS))) {
+ .timeout(500, TimeUnit.MILLISECONDS))) {
MongoCollection collection = client.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary());
@@ -338,7 +347,7 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsNotSet() {
+ " data: {"
+ " failCommands: [\"getMore\", \"aggregate\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 200)
+ + " blockTimeMS: " + 200
+ " }"
+ "}");
@@ -389,12 +398,10 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsSet() {
//given
BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0);
- collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions());
sleep(2000);
- long rtt = ClusterFixture.getPrimaryRTT();
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 300, TimeUnit.MILLISECONDS))) {
+ .timeout(500, TimeUnit.MILLISECONDS))) {
MongoCollection collection = client.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName())
@@ -406,7 +413,7 @@ public void testTimeoutMsRefreshedForGetMoreWhenMaxAwaitTimeMsSet() {
+ " data: {"
+ " failCommands: [\"aggregate\", \"getMore\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 200)
+ + " blockTimeMS: " + 200
+ " }"
+ "}");
@@ -449,9 +456,8 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt
assumeTrue(isDiscoverableReplicaSet());
//given
- long rtt = ClusterFixture.getPrimaryRTT();
try (MongoClient client = createReactiveClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 2500, TimeUnit.MILLISECONDS))) {
+ .timeout(2500, TimeUnit.MILLISECONDS))) {
MongoCollection collection = client.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary());
@@ -468,7 +474,78 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt
List<CommandStartedEvent> commandStartedEvents = commandListener.getCommandStartedEvents();
assertCommandStartedEventsInOder(Arrays.asList("aggregate", "getMore", "getMore", "getMore", "killCursors"),
commandStartedEvents);
- assertOnlyOneCommandTimeoutFailure("getMore");
+
+ }
+ }
+
+ @DisplayName("9. End Session. The timeout specified via the MongoClient timeoutMS option")
+ @Test
+ @Override
+ public void test9EndSessionClientTimeout() {
+ assumeTrue(serverVersionAtLeast(4, 4));
+ assumeFalse(isStandalone());
+
+ collectionHelper.runAdminCommand("{"
+ + " configureFailPoint: \"failCommand\","
+ + " mode: { times: 1 },"
+ + " data: {"
+ + " failCommands: [\"abortTransaction\"],"
+ + " blockConnection: true,"
+ + " blockTimeMS: " + 400
+ + " }"
+ + "}");
+
+ try (MongoClient mongoClient = createReactiveClient(getMongoClientSettingsBuilder().retryWrites(false)
+ .timeout(300, TimeUnit.MILLISECONDS))) {
+ MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
+ .getCollection(namespace.getCollectionName());
+
+ try (ClientSession session = Mono.from(mongoClient.startSession()).block()) {
+ session.startTransaction();
+ Mono.from(collection.insertOne(session, new Document("x", 1))).block();
+ }
+
+ sleep(postSessionCloseSleep());
+ CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() -> commandListener.getCommandFailedEvent("abortTransaction"));
+ long elapsedTime = abortTransactionEvent.getElapsedTime(TimeUnit.MILLISECONDS);
+ assertInstanceOf(MongoOperationTimeoutException.class, abortTransactionEvent.getThrowable());
+ assertTrue(elapsedTime <= 400, "Took too long to time out, elapsedMS: " + elapsedTime);
+ }
+ }
+
+ @Test
+ @DisplayName("9. End Session. The timeout specified via the ClientSession defaultTimeoutMS option")
+ @Override
+ public void test9EndSessionSessionTimeout() {
+ assumeTrue(serverVersionAtLeast(4, 4));
+ assumeFalse(isStandalone());
+
+ collectionHelper.runAdminCommand("{"
+ + " configureFailPoint: \"failCommand\","
+ + " mode: { times: 1 },"
+ + " data: {"
+ + " failCommands: [\"abortTransaction\"],"
+ + " blockConnection: true,"
+ + " blockTimeMS: " + 400
+ + " }"
+ + "}");
+
+ try (MongoClient mongoClient = createReactiveClient(getMongoClientSettingsBuilder())) {
+ MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
+ .getCollection(namespace.getCollectionName());
+
+ try (ClientSession session = Mono.from(mongoClient.startSession(com.mongodb.ClientSessionOptions.builder()
+ .defaultTimeout(300, TimeUnit.MILLISECONDS).build())).block()) {
+
+ session.startTransaction();
+ Mono.from(collection.insertOne(session, new Document("x", 1))).block();
+ }
+
+ sleep(postSessionCloseSleep());
+ CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() -> commandListener.getCommandFailedEvent("abortTransaction"));
+ long elapsedTime = abortTransactionEvent.getElapsedTime(TimeUnit.MILLISECONDS);
+ assertInstanceOf(MongoOperationTimeoutException.class, abortTransactionEvent.getThrowable());
+ assertTrue(elapsedTime <= 400, "Took too long to time out, elapsedMS: " + elapsedTime);
}
}
@@ -512,6 +589,6 @@ public void tearDown() throws InterruptedException {
@Override
protected int postSessionCloseSleep() {
- return 256;
+ return 1000;
}
}
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
index 2881b47e38e..05ca89dd048 100644
--- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
@@ -24,6 +24,7 @@
import com.mongodb.MongoTimeoutException;
import com.mongodb.connection.ClusterType;
import com.mongodb.connection.ServerVersion;
+import com.mongodb.connection.TransportSettings;
import com.mongodb.reactivestreams.client.internal.MongoClientImpl;
import org.bson.Document;
import org.bson.conversions.Bson;
@@ -33,6 +34,7 @@
import java.util.List;
import static com.mongodb.ClusterFixture.TIMEOUT_DURATION;
+import static com.mongodb.ClusterFixture.getOverriddenTransportSettings;
import static com.mongodb.ClusterFixture.getServerApi;
import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException;
import static java.lang.Thread.sleep;
@@ -67,11 +69,18 @@ public static MongoClientSettings.Builder getMongoClientSettingsBuilder() {
}
public static MongoClientSettings.Builder getMongoClientSettingsBuilder(final ConnectionString connectionString) {
- MongoClientSettings.Builder builder = MongoClientSettings.builder();
+ MongoClientSettings.Builder builder = MongoClientSettings.builder()
+ .applyConnectionString(connectionString);
+
+ TransportSettings overriddenTransportSettings = getOverriddenTransportSettings();
+ if (overriddenTransportSettings != null) {
+ builder.transportSettings(overriddenTransportSettings);
+ }
+
if (getServerApi() != null) {
builder.serverApi(getServerApi());
}
- return builder.applyConnectionString(connectionString);
+ return builder;
}
public static String getDefaultDatabaseName() {
@@ -164,6 +173,11 @@ public static synchronized ConnectionString getConnectionString() {
public static MongoClientSettings.Builder getMongoClientBuilderFromConnectionString() {
MongoClientSettings.Builder builder = MongoClientSettings.builder()
.applyConnectionString(getConnectionString());
+
+ TransportSettings overriddenTransportSettings = getOverriddenTransportSettings();
+ if (overriddenTransportSettings != null) {
+ builder.transportSettings(overriddenTransportSettings);
+ }
if (getServerApi() != null) {
builder.serverApi(getServerApi());
}
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java
new file mode 100644
index 00000000000..c58bb98f2cc
--- /dev/null
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/observability/MicrometerProseTest.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.reactivestreams.client.observability;
+
+import com.mongodb.MongoClientSettings;
+import com.mongodb.client.AbstractMicrometerProseTest;
+import com.mongodb.client.MongoClient;
+import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient;
+
+/**
+ * Reactive Streams driver implementation of the Micrometer prose tests.
+ */
+public class MicrometerProseTest extends AbstractMicrometerProseTest {
+ @Override
+ protected MongoClient createMongoClient(final MongoClientSettings settings) {
+ return new SyncMongoClient(settings);
+ }
+}
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java
index e1d765150a7..473d57a3878 100644
--- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/syncadapter/SyncClientSession.java
@@ -192,7 +192,7 @@ public TimeoutContext getTimeoutContext() {
@Override
@Nullable
public TransactionSpan getTransactionSpan() {
- return null;
+ return wrapped.getTransactionSpan();
}
private static void sleep(final long millis) {
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java
new file mode 100644
index 00000000000..bf2e6205ad6
--- /dev/null
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.reactivestreams.client.unified;
+
+import org.junit.jupiter.params.provider.Arguments;
+
+import java.util.Collection;
+
+final class MicrometerTracingTest extends UnifiedReactiveStreamsTest {
+ private static Collection<Arguments> data() {
+ return getTestData("open-telemetry/tests");
+ }
+}
diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java
index c192ae17896..0fda131f4ff 100644
--- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java
+++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/MongoClientImplTest.java
@@ -25,6 +25,7 @@
import com.mongodb.internal.connection.ClientMetadata;
import com.mongodb.internal.connection.Cluster;
import com.mongodb.internal.mockito.MongoMockito;
+import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.internal.session.ServerSessionPool;
import com.mongodb.reactivestreams.client.ChangeStreamPublisher;
import com.mongodb.reactivestreams.client.ClientSession;
@@ -179,7 +180,7 @@ void testWatch() {
@Test
void testStartSession() {
ServerSessionPool serverSessionPool = mock(ServerSessionPool.class);
- ClientSessionHelper clientSessionHelper = new ClientSessionHelper(mongoClient, serverSessionPool);
+ ClientSessionHelper clientSessionHelper = new ClientSessionHelper(mongoClient, serverSessionPool, TracingManager.NO_OP);
assertAll("Start Session Tests",
() -> assertAll("check validation",
diff --git a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java
index 920feb1f986..eb36678761a 100644
--- a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java
+++ b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java
@@ -22,7 +22,6 @@
import com.mongodb.MongoClientException;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
-import com.mongodb.MongoNamespace;
import com.mongodb.MongoQueryException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoTimeoutException;
@@ -53,17 +52,14 @@
import com.mongodb.internal.connection.Cluster;
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext;
+import com.mongodb.internal.observability.micrometer.Span;
+import com.mongodb.internal.observability.micrometer.TracingManager;
import com.mongodb.internal.operation.OperationHelper;
import com.mongodb.internal.operation.Operations;
import com.mongodb.internal.operation.ReadOperation;
import com.mongodb.internal.operation.WriteOperation;
import com.mongodb.internal.session.ServerSessionPool;
-import com.mongodb.internal.observability.micrometer.Span;
-import com.mongodb.internal.observability.micrometer.TraceContext;
-import com.mongodb.internal.observability.micrometer.TracingManager;
-import com.mongodb.internal.observability.micrometer.TransactionSpan;
import com.mongodb.lang.Nullable;
-import io.micrometer.common.KeyValues;
import org.bson.BsonDocument;
import org.bson.Document;
import org.bson.UuidRepresentation;
@@ -77,17 +73,11 @@
import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL;
import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL;
-import static com.mongodb.internal.MongoNamespaceHelper.COMMAND_COLLECTION_NAME;
import static com.mongodb.ReadPreference.primary;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.assertions.Assertions.isTrueArgument;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.internal.TimeoutContext.createTimeoutContext;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COLLECTION;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NAMESPACE;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_NAME;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.OPERATION_SUMMARY;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SYSTEM;
final class MongoClusterImpl implements MongoCluster {
@Nullable
@@ -434,7 +424,8 @@ public T execute(final ReadOperation operation, final ReadPreference r
boolean implicitSession = isImplicitSession(session);
OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName())
.withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, implicitSession));
- Span span = createOperationSpan(actualClientSession, operationContext, operation.getCommandName(), operation.getNamespace());
+ Span span = operationContext.getTracingManager().createOperationSpan(
+ actualClientSession.getTransactionSpan(), operationContext, operation.getCommandName(), operation.getNamespace());
ReadBinding binding = getReadBinding(readPreference, actualClientSession, implicitSession);
@@ -469,7 +460,8 @@ public T execute(final WriteOperation operation, final ReadConcern readCo
ClientSession actualClientSession = getClientSession(session);
OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName())
.withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, isImplicitSession(session)));
- Span span = createOperationSpan(actualClientSession, operationContext, operation.getCommandName(), operation.getNamespace());
+ Span span = operationContext.getTracingManager().createOperationSpan(
+ actualClientSession.getTransactionSpan(), operationContext, operation.getCommandName(), operation.getNamespace());
WriteBinding binding = getWriteBinding(actualClientSession, isImplicitSession(session));
try {
@@ -587,48 +579,6 @@ ClientSession getClientSession(@Nullable final ClientSession clientSessionFromOp
return session;
}
- /**
- * Create a tracing span for the given operation, and set it on operation context.
- *
- * @param actualClientSession the session that the operation is part of
- * @param operationContext the operation context for the operation
- * @param commandName the name of the command
- * @param namespace the namespace of the command
- * @return the created span, or null if tracing is not enabled
- */
- @Nullable
- private Span createOperationSpan(final ClientSession actualClientSession, final OperationContext operationContext, final String commandName, final MongoNamespace namespace) {
- TracingManager tracingManager = operationContext.getTracingManager();
- if (tracingManager.isEnabled()) {
- TraceContext parentContext = null;
- TransactionSpan transactionSpan = actualClientSession.getTransactionSpan();
- if (transactionSpan != null) {
- parentContext = transactionSpan.getContext();
- }
- String name = commandName + " " + namespace.getDatabaseName() + (COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())
- ? ""
- : "." + namespace.getCollectionName());
-
- KeyValues keyValues = KeyValues.of(
- SYSTEM.withValue("mongodb"),
- NAMESPACE.withValue(namespace.getDatabaseName()));
- if (!COMMAND_COLLECTION_NAME.equalsIgnoreCase(namespace.getCollectionName())) {
- keyValues = keyValues.and(COLLECTION.withValue(namespace.getCollectionName()));
- }
- keyValues = keyValues.and(OPERATION_NAME.withValue(commandName),
- OPERATION_SUMMARY.withValue(name));
-
- Span span = tracingManager.addSpan(name, parentContext, namespace);
-
- span.tagLowCardinality(keyValues);
-
- operationContext.setTracingSpan(span);
- return span;
-
- } else {
- return null;
- }
- }
}
private boolean isImplicitSession(@Nullable final ClientSession session) {
diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java
index 9ce58b1654f..7828ecde684 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java
@@ -56,11 +56,8 @@
import com.mongodb.internal.connection.TestCommandListener;
import com.mongodb.internal.connection.TestConnectionPoolListener;
import com.mongodb.test.FlakyTest;
-import org.bson.BsonArray;
-import org.bson.BsonBoolean;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
-import org.bson.BsonString;
import org.bson.BsonTimestamp;
import org.bson.Document;
import org.bson.codecs.BsonDocumentCodec;
@@ -256,7 +253,6 @@ public void testBlockingIterationMethodsChangeStream() {
assumeFalse(isAsync()); // Async change stream cursor is non-deterministic for cursor::next
BsonTimestamp startTime = new BsonTimestamp((int) Instant.now().getEpochSecond(), 0);
- collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions());
sleep(2000);
collectionHelper.insertDocuments(singletonList(BsonDocument.parse("{x: 1}")), WriteConcern.MAJORITY);
@@ -298,7 +294,6 @@ public void testBlockingIterationMethodsChangeStream() {
@FlakyTest(maxAttempts = 3)
public void testGridFSUploadViaOpenUploadStreamTimeout() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
collectionHelper.runAdminCommand("{"
+ " configureFailPoint: \"failCommand\","
@@ -306,7 +301,7 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() {
+ " data: {"
+ " failCommands: [\"insert\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 205)
+ + " blockTimeMS: " + 205
+ " }"
+ "}");
@@ -314,7 +309,7 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() {
filesCollectionHelper.create();
try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 200, TimeUnit.MILLISECONDS))) {
+ .timeout(200, TimeUnit.MILLISECONDS))) {
MongoDatabase database = client.getDatabase(namespace.getDatabaseName());
GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME);
@@ -329,7 +324,6 @@ public void testGridFSUploadViaOpenUploadStreamTimeout() {
@Test
public void testAbortingGridFsUploadStreamTimeout() throws Throwable {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
collectionHelper.runAdminCommand("{"
+ " configureFailPoint: \"failCommand\","
@@ -337,7 +331,7 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable {
+ " data: {"
+ " failCommands: [\"delete\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 305)
+ + " blockTimeMS: " + 320
+ " }"
+ "}");
@@ -345,7 +339,7 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable {
filesCollectionHelper.create();
try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 300, TimeUnit.MILLISECONDS))) {
+ .timeout(300, TimeUnit.MILLISECONDS))) {
MongoDatabase database = client.getDatabase(namespace.getDatabaseName());
GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME).withChunkSizeBytes(2);
@@ -360,7 +354,6 @@ public void testAbortingGridFsUploadStreamTimeout() throws Throwable {
@Test
public void testGridFsDownloadStreamTimeout() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
chunksCollectionHelper.create();
filesCollectionHelper.create();
@@ -382,18 +375,19 @@ public void testGridFsDownloadStreamTimeout() {
+ " metadata: {}"
+ "}"
)), WriteConcern.MAJORITY);
+
collectionHelper.runAdminCommand("{"
+ " configureFailPoint: \"failCommand\","
+ " mode: { skip: 1 },"
+ " data: {"
+ " failCommands: [\"find\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 95)
+ + " blockTimeMS: " + 500
+ " }"
+ "}");
try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder()
- .timeout(rtt + 100, TimeUnit.MILLISECONDS))) {
+ .timeout(300, TimeUnit.MILLISECONDS))) {
MongoDatabase database = client.getDatabase(namespace.getDatabaseName());
GridFSBucket gridFsBucket = createGridFsBucket(database, GRID_FS_BUCKET_NAME).withChunkSizeBytes(2);
@@ -401,7 +395,9 @@ public void testGridFsDownloadStreamTimeout() {
assertThrows(MongoOperationTimeoutException.class, downloadStream::read);
List events = commandListener.getCommandStartedEvents();
- List findCommands = events.stream().filter(e -> e.getCommandName().equals("find")).collect(Collectors.toList());
+ List findCommands = events.stream()
+ .filter(e -> e.getCommandName().equals("find"))
+ .collect(Collectors.toList());
assertEquals(2, findCommands.size());
assertEquals(gridFsFileNamespace.getCollectionName(), findCommands.get(0).getCommand().getString("find").getValue());
@@ -414,7 +410,7 @@ public void testGridFsDownloadStreamTimeout() {
@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("test8ServerSelectionArguments")
public void test8ServerSelection(final String connectionString) {
- int timeoutBuffer = 100; // 5 in spec, Java is slower
+ int timeoutBuffer = 150; // 5 in spec, Java is slower
// 1. Create a MongoClient
try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder()
.applyConnectionString(new ConnectionString(connectionString)))
@@ -450,7 +446,7 @@ public void test8ServerSelectionHandshake(final String ignoredTestName, final in
+ " data: {"
+ " failCommands: [\"saslContinue\"],"
+ " blockConnection: true,"
- + " blockTimeMS: 350"
+ + " blockTimeMS: 600"
+ " }"
+ "}");
@@ -466,7 +462,7 @@ public void test8ServerSelectionHandshake(final String ignoredTestName, final in
.insertOne(new Document("x", 1));
});
long elapsed = msElapsedSince(start);
- assertTrue(elapsed <= 310, "Took too long to time out, elapsedMS: " + elapsed);
+ assertTrue(elapsed <= 350, "Took too long to time out, elapsedMS: " + elapsed);
}
}
@@ -483,23 +479,23 @@ public void test9EndSessionClientTimeout() {
+ " data: {"
+ " failCommands: [\"abortTransaction\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 150
+ + " blockTimeMS: " + 500
+ " }"
+ "}");
try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder().retryWrites(false)
- .timeout(100, TimeUnit.MILLISECONDS))) {
- MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
+ .timeout(250, TimeUnit.MILLISECONDS))) {
+ MongoDatabase database = mongoClient.getDatabase(namespace.getDatabaseName());
+ MongoCollection collection = database
.getCollection(namespace.getCollectionName());
try (ClientSession session = mongoClient.startSession()) {
session.startTransaction();
collection.insertOne(session, new Document("x", 1));
-
long start = System.nanoTime();
session.close();
- long elapsed = msElapsedSince(start) - postSessionCloseSleep();
- assertTrue(elapsed <= 150, "Took too long to time out, elapsedMS: " + elapsed);
+ long elapsed = msElapsedSince(start);
+ assertTrue(elapsed <= 300, "Took too long to time out, elapsedMS: " + elapsed);
}
}
CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() ->
@@ -520,7 +516,7 @@ public void test9EndSessionSessionTimeout() {
+ " data: {"
+ " failCommands: [\"abortTransaction\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 150
+ + " blockTimeMS: " + 400
+ " }"
+ "}");
@@ -529,14 +525,14 @@ public void test9EndSessionSessionTimeout() {
.getCollection(namespace.getCollectionName());
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(100, TimeUnit.MILLISECONDS).build())) {
+ .defaultTimeout(300, TimeUnit.MILLISECONDS).build())) {
session.startTransaction();
collection.insertOne(session, new Document("x", 1));
long start = System.nanoTime();
session.close();
- long elapsed = msElapsedSince(start) - postSessionCloseSleep();
- assertTrue(elapsed <= 150, "Took too long to time out, elapsedMS: " + elapsed);
+ long elapsed = msElapsedSince(start);
+ assertTrue(elapsed <= 400, "Took too long to time out, elapsedMS: " + elapsed);
}
}
CommandFailedEvent abortTransactionEvent = assertDoesNotThrow(() ->
@@ -563,11 +559,12 @@ public void test9EndSessionCustomTesEachOperationHasItsOwnTimeoutWithCommit() {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
+ int defaultTimeout = 300;
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) {
+ .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) {
session.startTransaction();
collection.insertOne(session, new Document("x", 1));
- sleep(200);
+ sleep(defaultTimeout);
assertDoesNotThrow(session::commitTransaction);
}
@@ -594,11 +591,12 @@ public void test9EndSessionCustomTesEachOperationHasItsOwnTimeoutWithAbort() {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
+ int defaultTimeout = 300;
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) {
+ .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) {
session.startTransaction();
collection.insertOne(session, new Document("x", 1));
- sleep(200);
+ sleep(defaultTimeout);
assertDoesNotThrow(session::close);
}
@@ -618,12 +616,12 @@ public void test10ConvenientTransactions() {
+ " data: {"
+ " failCommands: [\"insert\", \"abortTransaction\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 150
+ + " blockTimeMS: " + 200
+ " }"
+ "}");
try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder()
- .timeout(100, TimeUnit.MILLISECONDS))) {
+ .timeout(150, TimeUnit.MILLISECONDS))) {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
@@ -661,12 +659,13 @@ public void test10CustomTestWithTransactionUsesASingleTimeout() {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
+ int defaultTimeout = 200;
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) {
+ .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) {
assertThrows(MongoOperationTimeoutException.class,
() -> session.withTransaction(() -> {
collection.insertOne(session, new Document("x", 1));
- sleep(200);
+ sleep(defaultTimeout);
return true;
})
);
@@ -696,12 +695,13 @@ public void test10CustomTestWithTransactionUsesASingleTimeoutWithLock() {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
+ int defaultTimeout = 200;
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(200, TimeUnit.MILLISECONDS).build())) {
+ .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS).build())) {
assertThrows(MongoOperationTimeoutException.class,
() -> session.withTransaction(() -> {
collection.insertOne(session, new Document("x", 1));
- sleep(200);
+ sleep(defaultTimeout);
return true;
})
);
@@ -710,7 +710,7 @@ public void test10CustomTestWithTransactionUsesASingleTimeoutWithLock() {
}
@DisplayName("11. Multi-batch bulkWrites")
- @Test
+ @FlakyTest(maxAttempts = 3)
@SuppressWarnings("try")
protected void test11MultiBatchBulkWrites() throws InterruptedException {
assumeTrue(serverVersionAtLeast(8, 0));
@@ -718,12 +718,18 @@ protected void test11MultiBatchBulkWrites() throws InterruptedException {
// a workaround for https://jira.mongodb.org/browse/DRIVERS-2997, remove this block when the aforementioned bug is fixed
client.getDatabase(namespace.getDatabaseName()).drop();
}
- BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand"))
- .append("mode", new BsonDocument("times", new BsonInt32(2)))
- .append("data", new BsonDocument("failCommands", new BsonArray(singletonList(new BsonString("bulkWrite"))))
- .append("blockConnection", BsonBoolean.TRUE)
- .append("blockTimeMS", new BsonInt32(2020)));
- try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().timeout(4000, TimeUnit.MILLISECONDS));
+ BsonDocument failPointDocument = BsonDocument.parse("{"
+ + " configureFailPoint: \"failCommand\","
+ + " mode: { times: 2},"
+ + " data: {"
+ + " failCommands: [\"bulkWrite\" ],"
+ + " blockConnection: true,"
+ + " blockTimeMS: " + 2020
+ + " }"
+ + "}");
+
+ long timeout = 4000;
+ try (MongoClient client = createMongoClient(getMongoClientSettingsBuilder().timeout(timeout, TimeUnit.MILLISECONDS));
FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) {
MongoDatabase db = client.getDatabase(namespace.getDatabaseName());
db.drop();
@@ -746,8 +752,8 @@ protected void test11MultiBatchBulkWrites() throws InterruptedException {
* Not a prose spec test. However, it is additional test case for better coverage.
*/
@Test
- @DisplayName("Should ignore wTimeoutMS of WriteConcern to initial and subsequent commitTransaction operations")
- public void shouldIgnoreWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTransactionOperations() {
+ @DisplayName("Should not include wTimeoutMS of WriteConcern to initial and subsequent commitTransaction operations")
+ public void shouldNotIncludeWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTransactionOperations() {
assumeTrue(serverVersionAtLeast(4, 4));
assumeFalse(isStandalone());
@@ -755,14 +761,15 @@ public void shouldIgnoreWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTran
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
.getCollection(namespace.getCollectionName());
+ int defaultTimeout = 200;
try (ClientSession session = mongoClient.startSession(ClientSessionOptions.builder()
- .defaultTimeout(200, TimeUnit.MILLISECONDS)
+ .defaultTimeout(defaultTimeout, TimeUnit.MILLISECONDS)
.build())) {
session.startTransaction(TransactionOptions.builder()
.writeConcern(WriteConcern.ACKNOWLEDGED.withWTimeout(100, TimeUnit.MILLISECONDS))
.build());
collection.insertOne(session, new Document("x", 1));
- sleep(200);
+ sleep(defaultTimeout);
assertDoesNotThrow(session::commitTransaction);
//repeat commit.
@@ -805,12 +812,12 @@ public void shouldIgnoreWaitQueueTimeoutMSWhenTimeoutMsIsSet() {
+ " data: {"
+ " failCommands: [\"find\" ],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 300
+ + " blockTimeMS: " + 450
+ " }"
+ "}");
- executor.submit(() -> collection.find().first());
- sleep(100);
+ executor.execute(() -> collection.find().first());
+ sleep(150);
//when && then
assertDoesNotThrow(() -> collection.find().first());
@@ -844,7 +851,7 @@ public void shouldThrowOperationTimeoutExceptionWhenConnectionIsNotAvailableAndT
+ " }"
+ "}");
- executor.submit(() -> collection.withTimeout(0, TimeUnit.MILLISECONDS).find().first());
+ executor.execute(() -> collection.withTimeout(0, TimeUnit.MILLISECONDS).find().first());
sleep(100);
//when && then
@@ -863,7 +870,7 @@ public void shouldUseWaitQueueTimeoutMSWhenTimeoutIsNotSet() {
//given
try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder()
.applyToConnectionPoolSettings(builder -> builder
- .maxWaitTime(100, TimeUnit.MILLISECONDS)
+ .maxWaitTime(20, TimeUnit.MILLISECONDS)
.maxSize(1)
))) {
MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName())
@@ -875,12 +882,12 @@ public void shouldUseWaitQueueTimeoutMSWhenTimeoutIsNotSet() {
+ " data: {"
+ " failCommands: [\"find\" ],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 300
+ + " blockTimeMS: " + 400
+ " }"
+ "}");
- executor.submit(() -> collection.find().first());
- sleep(100);
+ executor.execute(() -> collection.find().first());
+ sleep(200);
//when & then
assertThrows(MongoTimeoutException.class, () -> collection.find().first());
@@ -896,7 +903,6 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutMsIsN
assumeTrue(serverVersionAtLeast(4, 4));
assumeTrue(isLoadBalanced());
- long rtt = ClusterFixture.getPrimaryRTT();
collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions());
collectionHelper.insertDocuments(new Document(), new Document());
collectionHelper.runAdminCommand("{"
@@ -905,7 +911,7 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutMsIsN
+ " data: {"
+ " failCommands: [\"getMore\" ],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 600)
+ + " blockTimeMS: " + 600
+ " }"
+ "}");
@@ -943,7 +949,6 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkError() {
assumeTrue(serverVersionAtLeast(4, 4));
assumeTrue(isLoadBalanced());
- long rtt = ClusterFixture.getPrimaryRTT();
collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions());
collectionHelper.insertDocuments(new Document(), new Document());
collectionHelper.runAdminCommand("{"
@@ -952,7 +957,7 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkError() {
+ " data: {"
+ " failCommands: [\"getMore\" ],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 600)
+ + " blockTimeMS: " + 600
+ " }"
+ "}");
@@ -1040,11 +1045,16 @@ public void shouldUseConnectTimeoutMsWhenEstablishingConnectionInBackground() {
+ " data: {"
+ " failCommands: [\"hello\", \"isMaster\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + 500
+ + " blockTimeMS: " + 500 + ","
+ // The appName is unique to prevent this failpoint from affecting ClusterFixture's ServerMonitor.
+ // Without the appName, ClusterFixture's heartbeats would be blocked, polluting RTT measurements with 500ms values,
+ // which would cause flakiness in other prose tests that use ClusterFixture.getPrimaryRTT() for timeout adjustments.
+ + " appName: \"connectTimeoutBackgroundTest\""
+ " }"
+ "}");
try (MongoClient ignored = createMongoClient(getMongoClientSettingsBuilder()
+ .applicationName("connectTimeoutBackgroundTest")
.applyToConnectionPoolSettings(builder -> builder.minSize(1))
// Use a very short timeout to ensure that the connection establishment will fail on the first handshake command.
.timeout(10, TimeUnit.MILLISECONDS))) {
@@ -1075,9 +1085,10 @@ private static Stream test8ServerSelectionArguments() {
}
private static Stream test8ServerSelectionHandshakeArguments() {
+
return Stream.of(
- Arguments.of("timeoutMS honored for connection handshake commands if it's lower than serverSelectionTimeoutMS", 200, 300),
- Arguments.of("serverSelectionTimeoutMS honored for connection handshake commands if it's lower than timeoutMS", 300, 200)
+ Arguments.of("timeoutMS honored for connection handshake commands if it's lower than serverSelectionTimeoutMS", 200, 500),
+ Arguments.of("serverSelectionTimeoutMS honored for connection handshake commands if it's lower than timeoutMS", 500, 200)
);
}
@@ -1088,7 +1099,8 @@ protected MongoNamespace generateNamespace() {
protected MongoClientSettings.Builder getMongoClientSettingsBuilder() {
commandListener.reset();
- return Fixture.getMongoClientSettingsBuilder()
+ MongoClientSettings.Builder mongoClientSettingsBuilder = Fixture.getMongoClientSettingsBuilder();
+ return mongoClientSettingsBuilder
.readConcern(ReadConcern.MAJORITY)
.writeConcern(WriteConcern.MAJORITY)
.readPreference(ReadPreference.primary())
@@ -1103,6 +1115,9 @@ public void setUp() {
gridFsChunksNamespace = new MongoNamespace(getDefaultDatabaseName(), GRID_FS_BUCKET_NAME + ".chunks");
collectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), namespace);
+ // in some test collection might not have been created yet, thus dropping it in afterEach will throw an error
+ collectionHelper.create();
+
filesCollectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), gridFsFileNamespace);
chunksCollectionHelper = new CollectionHelper<>(new BsonDocumentCodec(), gridFsChunksNamespace);
commandListener = new TestCommandListener();
@@ -1112,10 +1127,13 @@ public void setUp() {
public void tearDown() throws InterruptedException {
ClusterFixture.disableFailPoint(FAIL_COMMAND_NAME);
if (collectionHelper != null) {
+ // Due to testing abortTransaction via failpoint, there may be open transactions
+ // after the test finishes, thus drop() command hangs for 60 seconds until transaction
+ // is automatically rolled back.
+ collectionHelper.runAdminCommand("{killAllSessions: []}");
collectionHelper.drop();
filesCollectionHelper.drop();
chunksCollectionHelper.drop();
- commandListener.reset();
try {
ServerHelper.checkPool(getPrimary());
} catch (InterruptedException e) {
@@ -1139,7 +1157,7 @@ private MongoClient createMongoClient(final MongoClientSettings.Builder builder)
return createMongoClient(builder.build());
}
- private long msElapsedSince(final long t1) {
+ protected long msElapsedSince(final long t1) {
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1);
}
diff --git a/driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java
similarity index 57%
rename from driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java
rename to driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java
index d4239aa44d7..746b0ffd8d9 100644
--- a/driver-sync/src/test/functional/com/mongodb/observability/MicrometerProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractMicrometerProseTest.java
@@ -14,44 +14,59 @@
* limitations under the License.
*/
-package com.mongodb.observability;
+package com.mongodb.client;
import com.mongodb.MongoClientSettings;
-import com.mongodb.client.Fixture;
-import com.mongodb.client.MongoClient;
-import com.mongodb.client.MongoClients;
-import com.mongodb.client.MongoCollection;
-import com.mongodb.client.MongoDatabase;
+import com.mongodb.lang.Nullable;
+import com.mongodb.observability.ObservabilitySettings;
+import com.mongodb.client.observability.SpanTree;
+import com.mongodb.client.observability.SpanTree.SpanNode;
import com.mongodb.observability.micrometer.MicrometerObservabilitySettings;
import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.tracing.exporter.FinishedSpan;
import io.micrometer.tracing.test.reporter.inmemory.InMemoryOtelSetup;
import org.bson.Document;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
import static com.mongodb.ClusterFixture.getDefaultDatabaseName;
+import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder;
+import static com.mongodb.internal.observability.micrometer.MongodbObservation.HighCardinalityKeyNames.QUERY_TEXT;
import static com.mongodb.internal.observability.micrometer.TracingManager.ENV_OBSERVABILITY_ENABLED;
import static com.mongodb.internal.observability.micrometer.TracingManager.ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH;
-import static com.mongodb.internal.observability.micrometer.MongodbObservation.HighCardinalityKeyNames.QUERY_TEXT;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
- * Implementation of the prose tests for Micrometer OpenTelemetry tracing.
+ * Implementation of the prose tests
+ * for Micrometer OpenTelemetry tracing.
*/
-public class MicrometerProseTest {
+public abstract class AbstractMicrometerProseTest {
private final ObservationRegistry observationRegistry = ObservationRegistry.create();
private InMemoryOtelSetup memoryOtelSetup;
private InMemoryOtelSetup.Builder.OtelBuildingBlocks inMemoryOtel;
private static String previousEnvVarMdbTracingEnabled;
private static String previousEnvVarMdbQueryTextLength;
+ protected abstract MongoClient createMongoClient(MongoClientSettings settings);
+
@BeforeAll
static void beforeAll() {
// preserve original env var values
@@ -77,18 +92,19 @@ void tearDown() {
memoryOtelSetup.close();
}
+ @DisplayName("Test 1: Tracing Enable/Disable via Environment Variable")
@Test
void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception {
setEnv(ENV_OBSERVABILITY_ENABLED, "false");
// don't enable command payload by default
- MongoClientSettings clientSettings = Fixture.getMongoClientSettingsBuilder()
+ MongoClientSettings clientSettings = getMongoClientSettingsBuilder()
.observabilitySettings(ObservabilitySettings.micrometerBuilder()
.observationRegistry(observationRegistry)
.build())
.build();
- try (MongoClient client = MongoClients.create(clientSettings)) {
+ try (MongoClient client = createMongoClient(clientSettings)) {
MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
MongoCollection collection = database.getCollection("test");
collection.find().first();
@@ -98,7 +114,7 @@ void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception {
}
setEnv(ENV_OBSERVABILITY_ENABLED, "true");
- try (MongoClient client = MongoClients.create(clientSettings)) {
+ try (MongoClient client = createMongoClient(clientSettings)) {
MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
MongoCollection collection = database.getCollection("test");
collection.find().first();
@@ -114,6 +130,7 @@ void testControlOtelInstrumentationViaEnvironmentVariable() throws Exception {
}
}
+ @DisplayName("Test 2: Command Payload Emission via Environment Variable")
@Test
void testControlCommandPayloadViaEnvironmentVariable() throws Exception {
setEnv(ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH, "42");
@@ -123,13 +140,13 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception {
.maxQueryTextLength(75) // should be overridden by env var
.build();
- MongoClientSettings clientSettings = Fixture.getMongoClientSettingsBuilder()
+ MongoClientSettings clientSettings = getMongoClientSettingsBuilder()
.observabilitySettings(ObservabilitySettings.micrometerBuilder()
.applySettings(settings)
.build()).
build();
- try (MongoClient client = MongoClients.create(clientSettings)) {
+ try (MongoClient client = createMongoClient(clientSettings)) {
MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
MongoCollection collection = database.getCollection("test");
collection.find().first();
@@ -153,14 +170,14 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception {
setEnv(ENV_OBSERVABILITY_QUERY_TEXT_MAX_LENGTH, null); // Unset the environment variable
- clientSettings = Fixture.getMongoClientSettingsBuilder()
+ clientSettings = getMongoClientSettingsBuilder()
.observabilitySettings(ObservabilitySettings.micrometerBuilder()
.observationRegistry(observationRegistry)
.maxQueryTextLength(42) // setting this will not matter since env var is not set and enableCommandPayloadTracing is false
.build())
.build();
- try (MongoClient client = MongoClients.create(clientSettings)) {
+ try (MongoClient client = createMongoClient(clientSettings)) {
MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
MongoCollection collection = database.getCollection("test");
collection.find().first();
@@ -182,11 +199,11 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception {
.maxQueryTextLength(7) // setting this will be used;
.build();
- clientSettings = Fixture.getMongoClientSettingsBuilder()
+ clientSettings = getMongoClientSettingsBuilder()
.observabilitySettings(settings)
.build();
- try (MongoClient client = MongoClients.create(clientSettings)) {
+ try (MongoClient client = createMongoClient(clientSettings)) {
MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
MongoCollection collection = database.getCollection("test");
collection.find().first();
@@ -200,8 +217,108 @@ void testControlCommandPayloadViaEnvironmentVariable() throws Exception {
}
}
+ /**
+ * Verifies that concurrent operations produce isolated span trees with no cross-contamination.
+ * Each operation should get its own trace ID, correct parent-child linkage, and collection-specific tags,
+ * even when multiple operations execute simultaneously on the same client.
+ *
+ * This test is not from the specification.
+ */
+ @Test
+ void testConcurrentOperationsHaveSeparateSpans() throws Exception {
+ setEnv(ENV_OBSERVABILITY_ENABLED, "true");
+ int nbrConcurrentOps = 10;
+ MongoClientSettings clientSettings = getMongoClientSettingsBuilder()
+ .applyToConnectionPoolSettings(pool -> pool.maxSize(nbrConcurrentOps))
+ .observabilitySettings(ObservabilitySettings.micrometerBuilder()
+ .observationRegistry(observationRegistry)
+ .build())
+ .build();
+
+ try (MongoClient client = createMongoClient(clientSettings)) {
+ MongoDatabase database = client.getDatabase(getDefaultDatabaseName());
+
+ // Warm up connections so the concurrent phase doesn't include handshake overhead
+ for (int i = 0; i < nbrConcurrentOps; i++) {
+ database.getCollection("concurrent_test_" + i).find().first();
+ }
+ // Clear spans from warm-up before the actual concurrent test
+ memoryOtelSetup.close();
+ memoryOtelSetup = InMemoryOtelSetup.builder().register(observationRegistry);
+ inMemoryOtel = memoryOtelSetup.getBuildingBlocks();
+
+ ExecutorService executor = Executors.newFixedThreadPool(nbrConcurrentOps);
+ try {
+ CountDownLatch startLatch = new CountDownLatch(1);
+ List> futures = new ArrayList<>();
+
+ for (int i = 0; i < nbrConcurrentOps; i++) {
+ String collectionName = "concurrent_test_" + i;
+ futures.add(executor.submit(() -> {
+ try {
+ startLatch.await(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ return;
+ }
+ database.getCollection(collectionName).find().first();
+ }));
+ }
+
+ // Release all threads simultaneously to maximize concurrency
+ startLatch.countDown();
+
+ for (Future> future : futures) {
+ future.get(30, TimeUnit.SECONDS);
+ }
+ } finally {
+ executor.shutdown();
+ }
+
+ List allSpans = inMemoryOtel.getFinishedSpans();
+
+ // Each find() produces 2 spans: operation-level span + command-level span
+ assertEquals(nbrConcurrentOps * 2, allSpans.size(),
+ "Each concurrent operation should produce exactly 2 spans (operation + command).");
+
+ // Verify trace isolation: each independent operation should get its own traceId
+ Map> spansByTrace = allSpans.stream()
+ .collect(Collectors.groupingBy(FinishedSpan::getTraceId));
+ assertEquals(nbrConcurrentOps, spansByTrace.size(),
+ "Each concurrent operation should have its own distinct trace ID.");
+
+ // Use SpanTree to validate parent-child structure built from spanId/parentId linkage
+ SpanTree spanTree = SpanTree.from(allSpans);
+ List roots = spanTree.getRoots();
+
+ // Each operation span is a root; its command span is a child
+ assertEquals(nbrConcurrentOps, roots.size(),
+ "SpanTree should have one root per concurrent operation.");
+
+ Set observedCollections = new HashSet<>();
+ for (SpanNode root : roots) {
+ assertTrue(root.getName().startsWith("find " + getDefaultDatabaseName() + ".concurrent_test_"),
+ "Root span should be an operation span, but was: " + root.getName());
+
+ assertEquals(1, root.getChildren().size(),
+ "Each operation span should have exactly one child (command span).");
+ assertEquals("find", root.getChildren().get(0).getName(),
+ "Child span should be the command span 'find'.");
+
+ // Extract collection name from the operation span name to verify no cross-contamination
+ String collectionName = root.getName().substring(
+ ("find " + getDefaultDatabaseName() + ".").length());
+ assertTrue(observedCollections.add(collectionName),
+ "Each operation should target a unique collection, but found duplicate: " + collectionName);
+ }
+
+ assertEquals(nbrConcurrentOps, observedCollections.size(),
+ "All " + nbrConcurrentOps + " concurrent operations should be represented in distinct traces.");
+ }
+ }
+
@SuppressWarnings("unchecked")
- private static void setEnv(final String key, final String value) throws Exception {
+ private static void setEnv(final String key, @Nullable final String value) throws Exception {
// Get the unmodifiable Map from System.getenv()
Map env = System.getenv();
diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
index 3682bd64ff0..910cf57edfd 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
@@ -93,7 +93,7 @@ public void shouldCreateServerSessionOnlyAfterConnectionCheckout() throws Interr
.addCommandListener(new CommandListener() {
@Override
public void commandStarted(final CommandStartedEvent event) {
- lsidSet.add(event.getCommand().getDocument("lsid"));
+ lsidSet.add(event.getCommand().getDocument("lsid").clone());
}
})
.build())) {
diff --git a/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java
index dd45bc8ae2c..04303833bf5 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/csot/AbstractClientSideOperationsEncryptionTimeoutProseTest.java
@@ -93,14 +93,13 @@ public abstract class AbstractClientSideOperationsEncryptionTimeoutProseTest {
@Test
void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
Map> kmsProviders = new HashMap<>();
Map localProviderMap = new HashMap<>();
localProviderMap.put("key", Base64.getDecoder().decode(MASTER_KEY));
kmsProviders.put("local", localProviderMap);
- try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 100))) {
+ try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(100))) {
keyVaultCollectionHelper.runAdminCommand("{"
+ " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\","
@@ -108,7 +107,7 @@ void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() {
+ " data: {"
+ " failCommands: [\"insert\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 100)
+ + " blockTimeMS: " + 100
+ " }"
+ "}");
@@ -126,9 +125,8 @@ void shouldThrowOperationTimeoutExceptionWhenCreateDataKey() {
@Test
void shouldThrowOperationTimeoutExceptionWhenEncryptData() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
- try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 150))) {
+ try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(150))) {
clientEncryption.createDataKey("local");
@@ -138,7 +136,7 @@ void shouldThrowOperationTimeoutExceptionWhenEncryptData() {
+ " data: {"
+ " failCommands: [\"find\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 150)
+ + " blockTimeMS: " + 150
+ " }"
+ "}");
@@ -160,10 +158,9 @@ void shouldThrowOperationTimeoutExceptionWhenEncryptData() {
@Test
void shouldThrowOperationTimeoutExceptionWhenDecryptData() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
BsonBinary encrypted;
- try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 400))) {
+ try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(400))) {
clientEncryption.createDataKey("local");
BsonBinary dataKey = clientEncryption.createDataKey("local");
EncryptOptions encryptOptions = new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic");
@@ -171,14 +168,14 @@ void shouldThrowOperationTimeoutExceptionWhenDecryptData() {
encrypted = clientEncryption.encrypt(new BsonString("hello"), encryptOptions);
}
- try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(rtt + 400))) {
+ try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder(400))) {
keyVaultCollectionHelper.runAdminCommand("{"
- + " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\","
+ + " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\","
+ " mode: { times: 1 },"
+ " data: {"
+ " failCommands: [\"find\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 500)
+ + " blockTimeMS: " + 500
+ " }"
+ "}");
commandListener.reset();
@@ -197,8 +194,7 @@ void shouldThrowOperationTimeoutExceptionWhenDecryptData() {
@Test
void shouldDecreaseOperationTimeoutForSubsequentOperations() {
assumeTrue(serverVersionAtLeast(4, 4));
- long rtt = ClusterFixture.getPrimaryRTT();
- long initialTimeoutMS = rtt + 2500;
+ long initialTimeoutMS = 2500;
keyVaultCollectionHelper.runAdminCommand("{"
+ " configureFailPoint: \"" + FAIL_COMMAND_NAME + "\","
@@ -206,7 +202,7 @@ void shouldDecreaseOperationTimeoutForSubsequentOperations() {
+ " data: {"
+ " failCommands: [\"insert\", \"find\", \"listCollections\"],"
+ " blockConnection: true,"
- + " blockTimeMS: " + (rtt + 10)
+ + " blockTimeMS: " + 10
+ " }"
+ "}");
@@ -272,8 +268,7 @@ void shouldDecreaseOperationTimeoutForSubsequentOperations() {
void shouldThrowTimeoutExceptionWhenCreateEncryptedCollection(final String commandToTimeout) {
assumeTrue(serverVersionAtLeast(7, 0));
//given
- long rtt = ClusterFixture.getPrimaryRTT();
- long initialTimeoutMS = rtt + 200;
+ long initialTimeoutMS = 200;
try (ClientEncryption clientEncryption = createClientEncryption(getClientEncryptionSettingsBuilder()
.timeout(initialTimeoutMS, MILLISECONDS))) {
diff --git a/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java
new file mode 100644
index 00000000000..38bd4350b1d
--- /dev/null
+++ b/driver-sync/src/test/functional/com/mongodb/client/observability/MicrometerProseTest.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.client.observability;
+
+import com.mongodb.MongoClientSettings;
+import com.mongodb.client.AbstractMicrometerProseTest;
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoClients;
+
+/**
+ * Sync driver implementation of the Micrometer prose tests.
+ */
+public class MicrometerProseTest extends AbstractMicrometerProseTest {
+ @Override
+ protected MongoClient createMongoClient(final MongoClientSettings settings) {
+ return MongoClients.create(settings);
+ }
+}
diff --git a/driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java b/driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java
similarity index 98%
rename from driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java
rename to driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java
index aa6697bf3ad..7d3bff3224d 100644
--- a/driver-sync/src/test/functional/com/mongodb/observability/SpanTree.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/observability/SpanTree.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package com.mongodb.observability;
+package com.mongodb.client.observability;
import com.mongodb.lang.Nullable;
import io.micrometer.tracing.exporter.FinishedSpan;
@@ -204,6 +204,10 @@ private static void assertValid(final SpanNode reportedNode, final SpanNode expe
}
}
+ public List<SpanNode> getRoots() {
+ return Collections.unmodifiableList(roots);
+ }
+
@Override
public String toString() {
return "SpanTree{"
diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
index cf003078f04..602838cff0c 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
@@ -28,7 +28,7 @@
import com.mongodb.client.gridfs.GridFSBucket;
import com.mongodb.client.model.Filters;
import com.mongodb.client.test.CollectionHelper;
-import com.mongodb.observability.SpanTree;
+import com.mongodb.client.observability.SpanTree;
import com.mongodb.client.unified.UnifiedTestModifications.TestDef;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.connection.ClusterDescription;
@@ -311,6 +311,9 @@ public void cleanUp() {
if (testDef != null) {
postCleanUp(testDef);
}
+ // Ask the JVM to run garbage collection.
+ // This should help with Netty's leak detection
+ System.gc();
}
protected void postCleanUp(final TestDef testDef) {
diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java
index 2225f837ec5..328c8298b6c 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestModifications.java
@@ -63,6 +63,25 @@ public static void applyCustomizations(final TestDef def) {
.file("client-side-encryption/tests/unified", "client bulkWrite with queryable encryption");
// client-side-operation-timeout (CSOT)
+ def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs")
+ .whenFailureContains("timeout")
+ .test("client-side-operations-timeout",
+ "timeoutMS behaves correctly for non-tailable cursors",
+ "timeoutMS is refreshed for getMore if timeoutMode is iteration - success");
+
+ def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs")
+ .whenFailureContains("timeout")
+ .test("client-side-operations-timeout",
+ "timeoutMS behaves correctly for tailable non-awaitData cursors",
+ "timeoutMS is refreshed for getMore - success");
+
+ def.retry("Unified CSOT tests do not account for RTT which varies in TLS vs non-TLS runs")
+ .whenFailureContains("timeout")
+ .test("client-side-operations-timeout",
+ "timeoutMS behaves correctly for tailable non-awaitData cursors",
+ "timeoutMS is refreshed for getMore - success");
+
+ // TODO: investigate
/*
As to the background connection pooling section:
timeoutMS set at the MongoClient level MUST be used as the timeout for all commands sent as part of the handshake.
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 8a08c34f213..b5e561c7f7e 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -18,10 +18,10 @@ aws-sdk-v2 = "2.30.31"
graal-sdk = "24.0.0"
jna = "5.11.0"
jnr-unixsocket = "0.38.17"
-netty-bom = "4.1.87.Final"
+netty-bom = "4.2.9.Final"
project-reactor-bom = "2022.0.0"
reactive-streams = "1.0.4"
-snappy = "1.1.10.3"
+snappy = "1.1.10.4"
zstd = "1.5.5-3"
jetbrains-annotations = "26.0.2"
micrometer-tracing = "1.6.0-M3" # This version has a fix for https://github.com/micrometer-metrics/tracing/issues/1092
diff --git a/testing/resources/specifications b/testing/resources/specifications
index de684cf1ef9..bb9dddd8176 160000
--- a/testing/resources/specifications
+++ b/testing/resources/specifications
@@ -1 +1 @@
-Subproject commit de684cf1ef9feede71d358cbb7d253840f1a8647
+Subproject commit bb9dddd8176eddbb9424f9bebedfe8c6bbf28c3a