diff --git a/.agents/references/testing-guide.md b/.agents/references/testing-guide.md index a7b5c128a2c..4140bfc6e33 100644 --- a/.agents/references/testing-guide.md +++ b/.agents/references/testing-guide.md @@ -50,6 +50,9 @@ description: Testing frameworks, conventions, and commands for the MongoDB Java # Scala tests (all versions) ./gradlew scalaCheck + +# Custom test JVM heap size (default: 4g) +./gradlew :driver-core:test -PtestMaxHeapSize=1g ``` ## Module-Specific Notes diff --git a/.evergreen/run-fle-on-demand-credential-test.sh b/.evergreen/run-fle-on-demand-credential-test.sh index 6445b53c666..2a45db4d503 100755 --- a/.evergreen/run-fle-on-demand-credential-test.sh +++ b/.evergreen/run-fle-on-demand-credential-test.sh @@ -20,6 +20,11 @@ if ! which java ; then sudo apt install openjdk-17-jdk -y fi +if ! which gpg ; then + echo "Installing gpg..." + sudo apt install gnupg -y +fi + export PROVIDER=${PROVIDER} echo "Running gradle version" diff --git a/.evergreen/run-mongodb-aws-ecs-test.sh b/.evergreen/run-mongodb-aws-ecs-test.sh index 63e4232839b..eee2999813b 100755 --- a/.evergreen/run-mongodb-aws-ecs-test.sh +++ b/.evergreen/run-mongodb-aws-ecs-test.sh @@ -36,6 +36,11 @@ if ! which git ; then apt install git -y fi +if ! which gpg ; then + echo "Installing gpg..." + sudo apt install gnupg -y +fi + cd src RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")" diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 778b8962c09..e96f05a6f4a 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -41,6 +41,11 @@ if ! which java ; then echo "Installed java." fi +if ! which gpg ; then + echo "Installing gpg..." + sudo apt install gnupg -y +fi + which java export OIDC_TESTS_ENABLED=true @@ -49,18 +54,21 @@ TO_REPLACE="mongodb://" REPLACEMENT="mongodb://$OIDC_ADMIN_USER:$OIDC_ADMIN_PWD@" ADMIN_URI=${MONGODB_URI/$TO_REPLACE/$REPLACEMENT} +# Limit memory for the containers +GRADLE_EXTRA_VARS="--no-daemon -Dorg.gradle.jvmargs=-Xmx512m -PtestMaxHeapSize=1g" + echo "Running gradle version" -./gradlew -version +./gradlew $GRADLE_EXTRA_VARS -version echo "Running gradle classes compile for driver-sync and driver-reactive-streams: ${FULL_DESCRIPTION}" -./gradlew --parallel --stacktrace --info \ +./gradlew $GRADLE_EXTRA_VARS --parallel --stacktrace --info \ driver-sync:classes driver-reactive-streams:classes echo "Running OIDC authentication tests against driver-sync: ${FULL_DESCRIPTION}" -./gradlew -Dorg.mongodb.test.uri="$ADMIN_URI" \ +./gradlew $GRADLE_EXTRA_VARS -Dorg.mongodb.test.uri="$ADMIN_URI" \ --stacktrace --debug --info \ driver-sync:test --tests OidcAuthenticationProseTests --tests UnifiedAuthTest echo "Running OIDC authentication tests against driver-reactive-streams: ${FULL_DESCRIPTION}" -./gradlew -Dorg.mongodb.test.uri="$ADMIN_URI" \ +./gradlew $GRADLE_EXTRA_VARS -Dorg.mongodb.test.uri="$ADMIN_URI" \ --stacktrace --debug --info driver-reactive-streams:test --tests OidcAuthenticationAsyncProseTests diff --git a/.evergreen/static-checks.sh b/.evergreen/static-checks.sh index 1accf5c1684..12522a41661 100755 --- a/.evergreen/static-checks.sh +++ b/.evergreen/static-checks.sh @@ -13,3 +13,6 @@ echo "Compiling JVM drivers" ./gradlew -version ./gradlew -PxmlReports.enabled=true --info -x test -x integrationTest -x spotlessApply clean check scalaCheck jar testClasses docs + +echo "Running OSGi bundle resolution tests" +./gradlew -PxmlReports.enabled=true --info :testing:osgi-test:check diff --git a/.gitignore b/.gitignore index 8cc6afec7da..02e9bdf0e21 100644 --- a/.gitignore +++ b/.gitignore @@ -59,8 +59,12 @@ local.properties # per-developer Agent overrides .AGENTS.md .claude/settings.local.json +.claude/docs CLAUDE.local.md +# Eclipse annotation processing +driver-benchmarks/.factorypath + # bin build directories **/bin diff --git a/bson/src/main/org/bson/BsonBinaryWriter.java b/bson/src/main/org/bson/BsonBinaryWriter.java index 20e73d97d44..54b62be8bd0 100644 --- a/bson/src/main/org/bson/BsonBinaryWriter.java +++ b/bson/src/main/org/bson/BsonBinaryWriter.java @@ -334,6 +334,26 @@ public void pipe(final BsonReader reader) { pipeDocument(reader, null); } + /** + * Pipes an encoded BSON document from the given byte array to this writer. + * + * @param bytes the byte array containing the encoded BSON document + * @param offset the offset into the byte array + * @param length the length of the encoded BSON document + * @since 5.8 + */ + public void pipe(final byte[] bytes, final int offset, final int length) { + notNull("bytes", bytes); + checkMinDocumentSize(length); + if (getState() == State.VALUE) { + bsonOutput.writeByte(BsonType.DOCUMENT.getValue()); + writeCurrentName(); + } + int pipedDocumentStartPosition = bsonOutput.getPosition(); + bsonOutput.writeBytes(bytes, offset, length); + completePipeDocument(pipedDocumentStartPosition); + } + @Override public void pipe(final BsonReader reader, final List extraElements) { notNull("reader", reader); @@ -350,14 +370,10 @@ private void pipeDocument(final BsonReader reader, final List extra } BsonInput bsonInput = binaryReader.getBsonInput(); int size = bsonInput.readInt32(); - if (size < 5) { - throw new BsonSerializationException("Document size must be at least 5"); - } + checkMinDocumentSize(size); int pipedDocumentStartPosition = bsonOutput.getPosition(); bsonOutput.writeInt32(size); - byte[] bytes = new byte[size - 4]; - bsonInput.readBytes(bytes); - bsonOutput.writeBytes(bytes); + bsonInput.pipe(bsonOutput, size - 4); binaryReader.setState(AbstractBsonReader.State.TYPE); @@ -371,17 +387,7 @@ private void pipeDocument(final BsonReader reader, final List extra setContext(getContext().getParentContext()); } - if (getContext() == null) { - setState(State.DONE); - } else { - if (getContext().getContextType() == BsonContextType.JAVASCRIPT_WITH_SCOPE) { - backpatchSize(); // size of the JavaScript with scope value - setContext(getContext().getParentContext()); - } - setState(getNextState()); - } - - validateSize(bsonOutput.getPosition() - pipedDocumentStartPosition); + completePipeDocument(pipedDocumentStartPosition); } else if (extraElements != null) { super.pipe(reader, extraElements); } else { @@ -389,6 +395,19 @@ private void pipeDocument(final BsonReader reader, final List extra } } + private void completePipeDocument(final int pipedDocumentStartPosition) { + if (getContext() == null) { + setState(State.DONE); + } else { + if (getContext().getContextType() == BsonContextType.JAVASCRIPT_WITH_SCOPE) { + backpatchSize(); // size of the JavaScript with scope value + setContext(getContext().getParentContext()); + } + setState(getNextState()); + } + validateSize(bsonOutput.getPosition() - pipedDocumentStartPosition); + } + /** * Sets a maximum size for documents from this point. * @@ -426,6 +445,12 @@ public void reset() { mark = null; } + private static void checkMinDocumentSize(final int size) { + if (size < 5) { + throw new BsonSerializationException("Document size must be at least 5"); + } + } + private void writeCurrentName() { if (getContext().getContextType() == BsonContextType.ARRAY) { int index = getContext().index++; diff --git a/bson/src/main/org/bson/RawBsonDocument.java b/bson/src/main/org/bson/RawBsonDocument.java index eb672bcef8d..7a9cbbd3b3c 100644 --- a/bson/src/main/org/bson/RawBsonDocument.java +++ b/bson/src/main/org/bson/RawBsonDocument.java @@ -44,7 +44,7 @@ import static org.bson.assertions.Assertions.notNull; /** - * An immutable BSON document that is represented using only the raw bytes. + * A BSON document that is represented using only the raw bytes. * * @since 3.0 */ @@ -144,6 +144,40 @@ public ByteBuf getByteBuffer() { return new ByteBufNIO(buffer); } + /** + * Returns the byte array backing this document. The returned array may be larger than the BSON document itself; + * only the range from {@link #getByteOffset()} to {@code getByteOffset() + }{@link #getByteLength()} contains + * valid document bytes. Changes to the returned array will be reflected in this document. + * + * @return the backing byte array + * @since 5.8 + * @see #getByteOffset() + * @see #getByteLength() + */ + public byte[] getBackingArray() { + return bytes; + } + + /** + * Returns the offset into the {@linkplain #getBackingArray() backing byte array} where this document starts. + * + * @return the offset + * @since 5.8 + */ + public int getByteOffset() { + return offset; + } + + /** + * Returns the length of this document within the {@linkplain #getBackingArray() backing byte array}. + * + * @return the length + * @since 5.8 + */ + public int getByteLength() { + return length; + } + /** * Decode this into a document. * diff --git a/bson/src/main/org/bson/codecs/RawBsonDocumentCodec.java b/bson/src/main/org/bson/codecs/RawBsonDocumentCodec.java index 4d81b7f97aa..a0d5947429f 100644 --- a/bson/src/main/org/bson/codecs/RawBsonDocumentCodec.java +++ b/bson/src/main/org/bson/codecs/RawBsonDocumentCodec.java @@ -40,8 +40,17 @@ public RawBsonDocumentCodec() { @Override public void encode(final BsonWriter writer, final RawBsonDocument value, final EncoderContext encoderContext) { - try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(value.getByteBuffer()))) { - writer.pipe(reader); + if (writer instanceof BsonBinaryWriter) { + // Fast path. The pipe method should ideally exist on BsonWriter, but adding it as + // abstract would be a breaking change, and adding it as a default method would force + // BsonWriter to depend on BsonBinaryReader/ByteBufferBsonInput, violating the + // interface's abstraction. + // TODO JAVA-6211 move pipe(byte[], int, int) to BsonWriter to remove this instanceof. + ((BsonBinaryWriter) writer).pipe(value.getBackingArray(), value.getByteOffset(), value.getByteLength()); + } else { + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(value.getByteBuffer()))) { + writer.pipe(reader); + } } } diff --git a/bson/src/main/org/bson/io/BsonInput.java b/bson/src/main/org/bson/io/BsonInput.java index 823355fe3ee..250cddab0e4 100644 --- a/bson/src/main/org/bson/io/BsonInput.java +++ b/bson/src/main/org/bson/io/BsonInput.java @@ -127,6 +127,19 @@ public interface BsonInput extends Closeable { */ boolean hasRemaining(); + /** + * Pipes the specified number of bytes from {@linkplain BsonInput this} input to the given {@linkplain BsonOutput output}. + * + * @param output the output to pipe to + * @param numBytes the number of bytes to pipe + * @since 5.8 + */ + default void pipe(BsonOutput output, int numBytes) { + byte[] bytes = new byte[numBytes]; + readBytes(bytes); + output.writeBytes(bytes); + } + @Override void close(); } diff --git a/bson/src/main/org/bson/io/ByteBufferBsonInput.java b/bson/src/main/org/bson/io/ByteBufferBsonInput.java index 2819bdcb091..1ab5ac9f5b3 100644 --- a/bson/src/main/org/bson/io/ByteBufferBsonInput.java +++ b/bson/src/main/org/bson/io/ByteBufferBsonInput.java @@ -275,6 +275,24 @@ public boolean hasRemaining() { return buffer.hasRemaining(); } + @Override + public void pipe(final BsonOutput output, final int numBytes) { + ensureOpen(); + ensureAvailable(numBytes); + + if (buffer.isBackedByArray()) { + int position = buffer.position(); + int arrayOffset = buffer.arrayOffset(); + output.writeBytes(buffer.array(), arrayOffset + position, numBytes); + buffer.position(position + numBytes); + } else { + // Fallback: use temporary buffer for non-array-backed buffers + byte[] temp = new byte[numBytes]; + buffer.get(temp); + output.writeBytes(temp); + } + } + @Override public void close() { buffer.release(); diff --git a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java index 0b067fc816f..4f589a42263 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java @@ -802,7 +802,34 @@ public void testPipeOfDocumentWithInvalidSize() { // expected } } + } + + @Test + public void testPipeOfRawBytes() { + BasicOutputBuffer sourceBuffer = new BasicOutputBuffer(); + try (BsonBinaryWriter sourceWriter = new BsonBinaryWriter(sourceBuffer)) { + sourceWriter.writeStartDocument(); + sourceWriter.writeBoolean("a", true); + sourceWriter.writeEndDocument(); + } + byte[] documentBytes = sourceBuffer.toByteArray(); + BasicOutputBuffer destBuffer = new BasicOutputBuffer(); + try (BsonBinaryWriter destWriter = new BsonBinaryWriter(destBuffer)) { + destWriter.pipe(documentBytes, 0, documentBytes.length); + } + + assertArrayEquals(documentBytes, destBuffer.toByteArray()); + } + + @Test + public void testPipeOfRawBytesWithInvalidSize() { + byte[] bytes = {4, 0, 0, 0}; // minimum document size is 5 + + BasicOutputBuffer newBuffer = new BasicOutputBuffer(); + try (BsonBinaryWriter newWriter = new BsonBinaryWriter(newBuffer)) { + assertThrows(BsonSerializationException.class, () -> newWriter.pipe(bytes, 0, bytes.length)); + } } // CHECKSTYLE:OFF diff --git a/bson/src/test/unit/org/bson/RawBsonDocumentTest.java b/bson/src/test/unit/org/bson/RawBsonDocumentTest.java new file mode 100644 index 00000000000..866dc101799 --- /dev/null +++ b/bson/src/test/unit/org/bson/RawBsonDocumentTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Arrays; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class RawBsonDocumentTest { + + private static final BsonDocument DOCUMENT = new BsonDocument() + .append("a", new BsonInt32(1)) + .append("b", new BsonInt32(2)) + .append("c", new BsonDocument("x", BsonBoolean.TRUE)) + .append("d", new BsonArray(Arrays.asList( + new BsonDocument("y", BsonBoolean.FALSE), + new BsonArray(Arrays.asList(new BsonInt32(1)))))); + + private static final byte[] DOCUMENT_BYTES = encodeDocument(); + + static Stream backingArrayAccessors() { + int documentLength = DOCUMENT_BYTES.length; + + Stream.Builder builder = Stream.builder(); + builder.add(Arguments.of(createFromDocument(), 0, documentLength)); + builder.add(Arguments.of(createFromByteArray(), 0, documentLength)); + + for (int padding = 1; padding <= 2; padding++) { + builder.add(Arguments.of(createPaddedBefore(padding), padding, documentLength)); + builder.add(Arguments.of(createPaddedAfter(padding), 0, documentLength)); + builder.add(Arguments.of(createPaddedBoth(padding), padding, documentLength)); + } + + return builder.build(); + } + + @ParameterizedTest(name = "{0}, expectedOffset={1}, expectedLength={2}") + @MethodSource("backingArrayAccessors") + void shouldExposeBackingArrayOffsetAndLength(final RawBsonDocument rawDocument, + final int expectedOffset, + final int expectedLength) { + assertEquals(expectedOffset, rawDocument.getByteOffset()); + assertEquals(expectedLength, rawDocument.getByteLength()); + assertArrayEquals(DOCUMENT_BYTES, + Arrays.copyOfRange( + rawDocument.getBackingArray(), + rawDocument.getByteOffset(), + rawDocument.getByteOffset() + rawDocument.getByteLength())); + } + + private static Named createFromDocument() { + return Named.of("from document", new RawBsonDocument(DOCUMENT, new BsonDocumentCodec())); + } + + private static Named createFromByteArray() { + return Named.of("from byte array", new RawBsonDocument(DOCUMENT_BYTES)); + } + + private static Named createPaddedBefore(final int padding) { + byte[] padded = new byte[DOCUMENT_BYTES.length + padding]; + System.arraycopy(DOCUMENT_BYTES, 0, padded, padding, DOCUMENT_BYTES.length); + return Named.of("padded before " + padding, new RawBsonDocument(padded, padding, DOCUMENT_BYTES.length)); + } + + private static Named createPaddedAfter(final int padding) { + byte[] padded = new byte[DOCUMENT_BYTES.length + padding]; + System.arraycopy(DOCUMENT_BYTES, 0, padded, 0, DOCUMENT_BYTES.length); + return Named.of("padded after " + padding, new RawBsonDocument(padded, 0, DOCUMENT_BYTES.length)); + } + + private static Named createPaddedBoth(final int padding) { + byte[] padded = new byte[DOCUMENT_BYTES.length + padding * 2]; + System.arraycopy(DOCUMENT_BYTES, 0, padded, padding, DOCUMENT_BYTES.length); + return Named.of("padded both " + padding, new RawBsonDocument(padded, padding, DOCUMENT_BYTES.length)); + } + + private static byte[] encodeDocument() { + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), DOCUMENT, EncoderContext.builder().build()); + return Arrays.copyOf(buffer.getInternalBuffer(), buffer.getPosition()); + } +} diff --git a/bson/src/test/unit/org/bson/io/BsonInputTest.java b/bson/src/test/unit/org/bson/io/BsonInputTest.java new file mode 100644 index 00000000000..f9676f6a8de --- /dev/null +++ b/bson/src/test/unit/org/bson/io/BsonInputTest.java @@ -0,0 +1,153 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.io; + +import org.bson.ByteBufNIO; +import org.bson.types.ObjectId; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class BsonInputTest { + + @Test + void defaultPipeShouldCopyBytesFromInputToOutput() { + // given + byte[] inputBytes = "Java!".getBytes(StandardCharsets.UTF_8); + + try (BsonInput bsonInput = new ForwardingBsonInput( + new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(inputBytes)))); + BasicOutputBuffer output = new BasicOutputBuffer()) { + // when + bsonInput.pipe(output, inputBytes.length); + + // then + assertEquals(inputBytes.length, bsonInput.getPosition()); + assertEquals(inputBytes.length, output.getPosition()); + assertArrayEquals(inputBytes, output.toByteArray()); + } + } + + @Test + void defaultPipeShouldCopyPartialBytesFromInputToOutput() { + // given + byte[] inputBytes = "Java!".getBytes(StandardCharsets.UTF_8); + + try (BsonInput bsonInput = new ForwardingBsonInput( + new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(inputBytes)))); + BasicOutputBuffer output = new BasicOutputBuffer()) { + // when + bsonInput.pipe(output, 3); + + // then + assertEquals(3, bsonInput.getPosition()); + assertEquals(3, output.getPosition()); + assertArrayEquals("Jav".getBytes(StandardCharsets.UTF_8), output.toByteArray()); + } + } + + /** + * Delegates all abstract methods but does NOT override pipe, + * so the default implementation is exercised. + */ + private static class ForwardingBsonInput implements BsonInput { + private final ByteBufferBsonInput delegate; + + ForwardingBsonInput(final ByteBufferBsonInput delegate) { + this.delegate = delegate; + } + + @Override + public int getPosition() { + return delegate.getPosition(); + } + + @Override + public byte readByte() { + return delegate.readByte(); + } + + @Override + public void readBytes(final byte[] bytes) { + delegate.readBytes(bytes); + } + + @Override + public void readBytes(final byte[] bytes, final int offset, final int length) { + delegate.readBytes(bytes, offset, length); + } + + @Override + public long readInt64() { + return delegate.readInt64(); + } + + @Override + public double readDouble() { + return delegate.readDouble(); + } + + @Override + public int readInt32() { + return delegate.readInt32(); + } + + @Override + public String readString() { + return delegate.readString(); + } + + @Override + public ObjectId readObjectId() { + return delegate.readObjectId(); + } + + @Override + public String readCString() { + return delegate.readCString(); + } + + @Override + public void skipCString() { + delegate.skipCString(); + } + + @Override + public void skip(final int numBytes) { + delegate.skip(numBytes); + } + + @Override + public BsonInputMark getMark(final int readLimit) { + return delegate.getMark(readLimit); + } + + @Override + public boolean hasRemaining() { + return delegate.hasRemaining(); + } + + @Override + public void close() { + delegate.close(); + } + } +} diff --git a/buildSrc/src/main/kotlin/conventions/testing-base.gradle.kts b/buildSrc/src/main/kotlin/conventions/testing-base.gradle.kts index 4708c742d40..e368bab0a67 100644 --- a/buildSrc/src/main/kotlin/conventions/testing-base.gradle.kts +++ b/buildSrc/src/main/kotlin/conventions/testing-base.gradle.kts @@ -29,7 +29,8 @@ plugins { } tasks.withType { - maxHeapSize = "4g" + // Override with -PtestMaxHeapSize= (e.g. "1g", "512m"). Defaults to 4g. + maxHeapSize = findProperty("testMaxHeapSize")?.toString() ?: "4g" maxParallelForks = 1 useJUnitPlatform() diff --git a/driver-core/build.gradle.kts b/driver-core/build.gradle.kts index 282c478858d..047b3a43a63 100644 --- a/driver-core/build.gradle.kts +++ b/driver-core/build.gradle.kts @@ -102,6 +102,7 @@ configureJarManifest { "org.bson.codecs.record.*;resolution:=optional", // Depends on JDK version "org.bson.codecs.kotlin.*;resolution:=optional", "org.bson.codecs.kotlinx.*;resolution:=optional", + "io.micrometer.*;resolution:=optional", "*" // import all that is not excluded or modified before ) .joinToString(",") diff --git a/driver-core/src/main/com/mongodb/client/model/Aggregates.java b/driver-core/src/main/com/mongodb/client/model/Aggregates.java index 6a5950ab560..29531e76e16 100644 --- a/driver-core/src/main/com/mongodb/client/model/Aggregates.java +++ b/driver-core/src/main/com/mongodb/client/model/Aggregates.java @@ -59,6 +59,7 @@ import static com.mongodb.internal.Iterables.concat; import static com.mongodb.internal.client.model.Util.sizeAtLeast; import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; /** * Builders for aggregation pipeline stages. @@ -1040,6 +1041,58 @@ public static Bson vectorSearch( return new VectorSearchBson(path, queryVector, index, limit, options); } + /** + * Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "score"} expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)} + * or {@link RerankQuery#rerankQuery(Bson)}. + * @param path The document field to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (currently 1-1000). + * @param model The reranking model name. Currently accepted: + * {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}. + * @return The {@code $rerank} pipeline stage. + * @mongodb.server.release 8.3 + * @since 5.8 + */ + @Beta(Reason.SERVER) + public static Bson rerank( + final RerankQuery query, + final String path, + final int numDocsToRerank, + final String model) { + notNull("path", path); + return rerank(query, singletonList(path), numDocsToRerank, model); + } + + /** + * Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "score"} expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}. + * @param paths The document field(s) to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. Accepted values: + * {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}. + * @return The {@code $rerank} pipeline stage. + * @mongodb.server.release 8.3 + * @since 5.8 + */ + @Beta(Reason.SERVER) + public static Bson rerank( + final RerankQuery query, + final List paths, + final int numDocsToRerank, + final String model) { + notNull("query", query); + notNull("paths", paths); + isTrueArgument("paths must not be empty", !paths.isEmpty()); + notNull("model", model); + return new RerankBson(query, paths, numDocsToRerank, model); + } + /** * Creates an $unset pipeline stage that removes/excludes fields from documents * @@ -2290,4 +2343,38 @@ public String toString() { + '}'; } } + + private static class RerankBson implements Bson { + private final RerankQuery query; + private final List paths; + private final int numDocsToRerank; + private final String model; + + RerankBson(final RerankQuery query, final List paths, final int numDocsToRerank, + final String model) { + this.query = query; + this.paths = paths; + this.numDocsToRerank = numDocsToRerank; + this.model = model; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + Document specificationDoc = new Document("query", query) + .append("path", paths.size() == 1 ? paths.get(0) : paths) + .append("numDocsToRerank", numDocsToRerank) + .append("model", model); + return new Document("$rerank", specificationDoc).toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "Stage{name=$rerank" + + ", query=" + query + + ", paths=" + paths + + ", numDocsToRerank=" + numDocsToRerank + + ", model=" + model + + '}'; + } + } } diff --git a/driver-core/src/main/com/mongodb/client/model/HnswSearchIndexOptions.java b/driver-core/src/main/com/mongodb/client/model/HnswSearchIndexOptions.java new file mode 100644 index 00000000000..73d36496838 --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/HnswSearchIndexOptions.java @@ -0,0 +1,106 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.conversions.Bson; + +import static com.mongodb.assertions.Assertions.isTrueArgument; + +/** + * Options for the HNSW (Hierarchical Navigable Small World) indexing method in a vector search index. + * + *

This class provides a fluent builder for specifying HNSW-specific parameters when creating + * a vector search index with {@code indexingMethod("hnsw")}.

+ * + *

Since {@link VectorSearchIndexFields.VectorField#hnswOptions(Bson)} accepts any {@link Bson}, + * a raw {@link org.bson.Document} may also be passed directly for forward compatibility.

+ * + *
{@code
+ *    vectorField("embedding")
+ *        .indexingMethod("hnsw")
+ *        .hnswOptions(new HnswSearchIndexOptions().maxEdges(16).numEdgeCandidates(200))
+ * }
+ * + * @see VectorSearchIndexFields.VectorField#hnswOptions(Bson) + * @since 5.8 + */ +@NotThreadSafe +public final class HnswSearchIndexOptions implements Bson { + @Nullable + private Integer maxEdges; + @Nullable + private Integer numEdgeCandidates; + + /** + * Creates a new instance with default settings. + * + * @since 5.8 + */ + public HnswSearchIndexOptions() { + } + + /** + * Sets the maximum number of connected neighbors for each node in the HNSW graph. + * + * @param maxEdges the maximum number of edges (connected neighbors) + * @return this + * @since 5.8 + */ + public HnswSearchIndexOptions maxEdges(final int maxEdges) { + isTrueArgument("maxEdges > 0", maxEdges > 0); + this.maxEdges = maxEdges; + return this; + } + + /** + * Sets the number of nearest neighbor candidates to consider when building the HNSW graph. + * + * @param numEdgeCandidates the number of nearest neighbor candidates + * @return this + * @since 5.8 + */ + public HnswSearchIndexOptions numEdgeCandidates(final int numEdgeCandidates) { + isTrueArgument("numEdgeCandidates > 0", numEdgeCandidates > 0); + this.numEdgeCandidates = numEdgeCandidates; + return this; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + BsonDocument doc = new BsonDocument(); + if (maxEdges != null) { + doc.append("maxEdges", new BsonInt32(maxEdges)); + } + if (numEdgeCandidates != null) { + doc.append("numEdgeCandidates", new BsonInt32(numEdgeCandidates)); + } + return doc; + } + + @Override + public String toString() { + return "HnswSearchIndexOptions{" + + "maxEdges=" + maxEdges + + ", numEdgeCandidates=" + numEdgeCandidates + + '}'; + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/RerankQuery.java b/driver-core/src/main/com/mongodb/client/model/RerankQuery.java new file mode 100644 index 00000000000..eeff6ecb84e --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/RerankQuery.java @@ -0,0 +1,84 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.annotations.Beta; +import org.bson.annotations.Reason; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.conversions.Bson; + +import static com.mongodb.assertions.Assertions.notNull; + +/** + * Represents a query for the {@code $rerank} aggregation pipeline stage. + *

+ * The {@code $rerank} stage is available only in MongoDB Atlas. + *

+ * Use {@link #rerankQuery(String)} for a simple text query, or + * {@link #rerankQuery(Bson)} to specify the full query document directly + * (e.g., for future modalities like imageURL or videoURL). + * + * @mongodb.server.release 8.3 + * @since 5.8 + */ +@Beta(Reason.SERVER) +public final class RerankQuery implements Bson { + private final Bson query; + + private RerankQuery(final Bson query) { + this.query = query; + } + + /** + * Creates a rerank query with the specified text. + *

+ * This is a convenience for {@code rerankQuery(new Document("text", text))}. + * + * @param text the query text to rerank against. + * @return a new {@link RerankQuery} + */ + public static RerankQuery rerankQuery(final String text) { + notNull("text", text); + return new RerankQuery(new BsonDocument("text", new BsonString(text))); + } + + /** + * Creates a rerank query from a full query document. + *

+ * Use this overload for future query modalities (e.g., imageURL, videoURL) + * or to pass additional fields alongside text. + * + * @param query the query document. + * @return a new {@link RerankQuery} + */ + public static RerankQuery rerankQuery(final Bson query) { + notNull("query", query); + return new RerankQuery(query); + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + return query.toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "RerankQuery{" + query + '}'; + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/SearchIndexDefinition.java b/driver-core/src/main/com/mongodb/client/model/SearchIndexDefinition.java new file mode 100644 index 00000000000..ac6bf9b3403 --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/SearchIndexDefinition.java @@ -0,0 +1,82 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import com.mongodb.annotations.Sealed; +import org.bson.conversions.Bson; + +import java.util.List; + +import static com.mongodb.assertions.Assertions.isTrueArgument; +import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.assertions.Assertions.notNullElements; +import static java.util.Arrays.asList; + +/** + * A definition for an Atlas Search index. + * + *

This interface provides factory methods for creating search index definitions + * that can be passed to {@link SearchIndexModel}.

+ * + * @see SearchIndexModel + * @see VectorSearchIndexDefinition + * @since 5.8 + */ +@Sealed +public interface SearchIndexDefinition extends Bson { + + /** + * Creates a vector search index definition with the specified fields. + * + *

The resulting definition produces a document of the form {@code {"fields": [...]}}, + * suitable for use with {@link SearchIndexType#vectorSearch()}.

+ * + * @param fields the fields for the vector search index. Each field should be created using + * {@link VectorSearchIndexFields} factory methods, or may be a raw {@link Bson} document. + * @return a new {@link VectorSearchIndexDefinition} + * @see VectorSearchIndexFields#vectorField(String) + * @see VectorSearchIndexFields#filterField(String) + * @see VectorSearchIndexFields#autoEmbedField(String) + * @since 5.8 + */ + static VectorSearchIndexDefinition vectorSearch(final Bson... fields) { + List fieldList = asList(notNull("fields", fields)); + isTrueArgument("fields must not be empty", !fieldList.isEmpty()); + notNullElements("fields", fieldList); + return new VectorSearchIndexDefinition(fieldList); + } + + /** + * Creates a vector search index definition with the specified fields. + * + *

The resulting definition produces a document of the form {@code {"fields": [...]}}, + * suitable for use with {@link SearchIndexType#vectorSearch()}.

+ * + * @param fields the fields for the vector search index. Each field should be created using + * {@link VectorSearchIndexFields} factory methods, or may be a raw {@link Bson} document. + * @return a new {@link VectorSearchIndexDefinition} + * @see VectorSearchIndexFields#vectorField(String) + * @see VectorSearchIndexFields#filterField(String) + * @see VectorSearchIndexFields#autoEmbedField(String) + * @since 5.8 + */ + static VectorSearchIndexDefinition vectorSearch(final List fields) { + notNullElements("fields", fields); + isTrueArgument("fields must not be empty", !fields.isEmpty()); + return new VectorSearchIndexDefinition(fields); + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/SearchIndexModel.java b/driver-core/src/main/com/mongodb/client/model/SearchIndexModel.java index 2a229e1a579..28e9a3a56ca 100644 --- a/driver-core/src/main/com/mongodb/client/model/SearchIndexModel.java +++ b/driver-core/src/main/com/mongodb/client/model/SearchIndexModel.java @@ -24,6 +24,14 @@ /** * A model describing the creation of a single Atlas Search index. * + *

The {@code definition} parameter accepts any {@link org.bson.conversions.Bson} instance. + * For vector search indexes, use the builders provided by {@link SearchIndexDefinition#vectorSearch(Bson...)} + * and {@link VectorSearchIndexFields} to construct the definition, and pass it to the + * {@linkplain #SearchIndexModel(String, VectorSearchIndexDefinition) vector search constructor} + * which automatically sets the index type to {@link SearchIndexType#vectorSearch()}.

+ * + * @see SearchIndexDefinition + * @see VectorSearchIndexFields * @since 4.11 * @mongodb.server.release 6.0 */ @@ -42,6 +50,7 @@ public final class SearchIndexModel { * will be used to create the search index.

* * @param definition the search index mapping definition. + * @see SearchIndexDefinition#vectorSearch(Bson...) */ public SearchIndexModel(final Bson definition) { this(null, definition, null); @@ -52,17 +61,33 @@ public SearchIndexModel(final Bson definition) { * * @param name the search index name. * @param definition the search index mapping definition. + * @see SearchIndexDefinition#vectorSearch(Bson...) */ public SearchIndexModel(final String name, final Bson definition) { this(name, definition, null); } + /** + * Construct a vector search index instance with the given name and definition. + * + *

The index type is automatically set to {@link SearchIndexType#vectorSearch()}.

+ * + * @param name the search index name. + * @param definition the vector search index definition. + * @see SearchIndexDefinition#vectorSearch(Bson...) + * @since 5.8 + */ + public SearchIndexModel(final String name, final VectorSearchIndexDefinition definition) { + this(name, definition, SearchIndexType.vectorSearch()); + } + /** * Construct an instance with the given Atlas Search name, index definition, and type. * * @param name the search index name. * @param definition the search index mapping definition. * @param type the search index type. + * @see SearchIndexDefinition#vectorSearch(Bson...) * @since 5.2 */ public SearchIndexModel(@Nullable final String name, final Bson definition, @Nullable final SearchIndexType type) { diff --git a/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexDefinition.java b/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexDefinition.java new file mode 100644 index 00000000000..12b90751e0d --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexDefinition.java @@ -0,0 +1,90 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import com.mongodb.lang.Nullable; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.conversions.Bson; + +import java.util.ArrayList; +import java.util.List; + +import static com.mongodb.assertions.Assertions.doesNotContainNull; +import static com.mongodb.assertions.Assertions.notNull; + +/** + * A vector search index definition, producing a document of the form {@code {"fields": [...]}}. + * + *

Instances are created via {@link SearchIndexDefinition#vectorSearch(Bson...)}.

+ * + * @see SearchIndexDefinition + * @see SearchIndexType#vectorSearch() + * @since 5.8 + */ +public final class VectorSearchIndexDefinition implements SearchIndexDefinition { + private final List fields; + @Nullable + private final Bson storedSource; + + VectorSearchIndexDefinition(final List fields) { + this(fields, null); + } + + VectorSearchIndexDefinition(final List fields, @Nullable final Bson storedSource) { + doesNotContainNull("fields", notNull("fields", fields)); + this.fields = new ArrayList<>(fields); + this.storedSource = storedSource; + } + + /** + * Creates a new {@link VectorSearchIndexDefinition} with the specified stored source configuration. + * + *

The stored source configuration controls which fields are stored in the index + * and can be returned without reading the full document from the collection.

+ * + * @param storedSource a document specifying the stored source configuration, + * e.g., {@code {"include": ["field1", "field2"]}} or {@code {"exclude": ["field3"]}} + * @return a new {@link VectorSearchIndexDefinition} with the stored source configuration + * @since 5.8 + */ + public VectorSearchIndexDefinition storedSource(final Bson storedSource) { + return new VectorSearchIndexDefinition(this.fields, notNull("storedSource", storedSource)); + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + BsonArray fieldArray = new BsonArray(); + for (Bson field : fields) { + fieldArray.add(field.toBsonDocument(documentClass, codecRegistry)); + } + BsonDocument document = new BsonDocument("fields", fieldArray); + if (storedSource != null) { + document.append("storedSource", storedSource.toBsonDocument(documentClass, codecRegistry)); + } + return document; + } + + @Override + public String toString() { + return "VectorSearchIndexDefinition{" + + "fields=" + fields + + ", storedSource=" + storedSource + + '}'; + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexFields.java b/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexFields.java new file mode 100644 index 00000000000..30d35a0a0a4 --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/VectorSearchIndexFields.java @@ -0,0 +1,424 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.conversions.Bson; + +import static com.mongodb.assertions.Assertions.isTrueArgument; +import static com.mongodb.assertions.Assertions.notNull; + +/** + * A factory for defining fields within a vector search index definition. + * + *

A convenient way to use this class is to statically import all of its methods, which allows usage like:

+ *
{@code
+ *    SearchIndexDefinition.vectorSearch(
+ *        vectorField("plot_embedding")
+ *            .numDimensions(1536)
+ *            .similarity("euclidean")
+ *            .indexingMethod("flat"),
+ *        filterField("genre")
+ *    );
+ * }
+ * + * @see SearchIndexDefinition#vectorSearch(Bson...) + * @since 5.8 + */ +public final class VectorSearchIndexFields { + + private VectorSearchIndexFields() { + } + + /** + * Creates a vector field definition for a vector search index. + * + * @param path the field path in the document + * @return a new {@link VectorField} + * @since 5.8 + */ + public static VectorField vectorField(final String path) { + return new VectorField(notNull("path", path)); + } + + /** + * Creates a filter field definition for a vector search index. + * + * @param path the field path in the document + * @return a new {@link FilterField} + * @since 5.8 + */ + public static FilterField filterField(final String path) { + return new FilterField(notNull("path", path)); + } + + /** + * Creates an auto-embed field definition for a vector search index. + * + * @param path the field path in the document containing the content to embed + * @return a new {@link AutoEmbedField} + * @since 5.8 + */ + public static AutoEmbedField autoEmbedField(final String path) { + return new AutoEmbedField(notNull("path", path)); + } + + /** + * A vector field definition for a vector search index. + * + *

Instances are created via {@link #vectorField(String)}.

+ * + * @since 5.8 + */ + @NotThreadSafe + public static final class VectorField implements Bson { + private final String path; + @Nullable + private Integer numDimensions; + @Nullable + private String similarity; + @Nullable + private String indexingMethod; + @Nullable + private Bson hnswOptions; + + private VectorField(final String path) { + this.path = path; + } + + /** + * Sets the number of dimensions for the vector field. + * + * @param numDimensions the number of vector dimensions + * @return this + * @since 5.8 + */ + public VectorField numDimensions(final int numDimensions) { + isTrueArgument("numDimensions > 0", numDimensions > 0); + this.numDimensions = numDimensions; + return this; + } + + /** + * Sets the similarity function used to compare vectors. + * + *

Supported values:

+ *
    + *
  • {@code "euclidean"} — measures the distance between ends of vectors
  • + *
  • {@code "cosine"} — measures the angle between vectors
  • + *
  • {@code "dotProduct"} — measures both the magnitude and direction of vectors
  • + *
+ * + * @param similarity the similarity function name + * @return this + * @since 5.8 + */ + public VectorField similarity(final String similarity) { + this.similarity = notNull("similarity", similarity); + return this; + } + + /** + * Sets the indexing method for this vector field. + * + *

Supported values:

+ *
    + *
  • {@code "flat"} — optimized for multi-tenant use cases with singular, static filters
  • + *
  • {@code "hnsw"} — Hierarchical Navigable Small World graph
  • + *
+ * + * @param indexingMethod the indexing method name + * @return this + * @since 5.8 + */ + public VectorField indexingMethod(final String indexingMethod) { + this.indexingMethod = notNull("indexingMethod", indexingMethod); + return this; + } + + /** + * Sets the HNSW options for this vector field. + * + *

This is only applicable when the indexing method is {@code "hnsw"}. + * A convenience builder is available via {@link HnswSearchIndexOptions}, or a raw + * {@link org.bson.Document} may be passed directly.

+ * + * @param hnswOptions the HNSW options + * @return this + * @see HnswSearchIndexOptions + * @since 5.8 + */ + public VectorField hnswOptions(final Bson hnswOptions) { + this.hnswOptions = notNull("hnswOptions", hnswOptions); + return this; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + BsonDocument doc = new BsonDocument(); + doc.append("type", new BsonString("vector")); + doc.append("path", new BsonString(path)); + if (numDimensions != null) { + doc.append("numDimensions", new BsonInt32(numDimensions)); + } + if (similarity != null) { + doc.append("similarity", new BsonString(similarity)); + } + if (indexingMethod != null) { + doc.append("indexingMethod", new BsonString(indexingMethod)); + } + if (hnswOptions != null) { + doc.append("hnswOptions", hnswOptions.toBsonDocument(documentClass, codecRegistry)); + } + return doc; + } + + @Override + public String toString() { + return "VectorField{" + + "path='" + path + '\'' + + ", numDimensions=" + numDimensions + + ", similarity='" + similarity + '\'' + + ", indexingMethod='" + indexingMethod + '\'' + + ", hnswOptions=" + hnswOptions + + '}'; + } + } + + /** + * A filter field definition for a vector search index. + * + *

Instances are created via {@link #filterField(String)}.

+ * + * @since 5.8 + */ + public static final class FilterField implements Bson { + private final String path; + + private FilterField(final String path) { + this.path = path; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + BsonDocument doc = new BsonDocument(); + doc.append("type", new BsonString("filter")); + doc.append("path", new BsonString(path)); + return doc; + } + + @Override + public String toString() { + return "FilterField{" + + "path='" + path + '\'' + + '}'; + } + } + + /** + * An auto-embed field definition for a vector search index. + * + *

Instances are created via {@link #autoEmbedField(String)}.

+ * + * @since 5.8 + */ + @NotThreadSafe + public static final class AutoEmbedField implements Bson { + private final String path; + @Nullable + private String modality; + @Nullable + private String model; + @Nullable + private Integer numDimensions; + @Nullable + private String quantization; + @Nullable + private String similarity; + @Nullable + private String indexingMethod; + @Nullable + private Bson hnswOptions; + + private AutoEmbedField(final String path) { + this.path = path; + } + + /** + * Sets the modality for auto-embedding. This is a required field. + * + *

The initially supported type is {@code "text"}.

+ * + * @param modality the modality (e.g., {@code "text"}) + * @return this + * @since 5.8 + */ + public AutoEmbedField modality(final String modality) { + this.modality = notNull("modality", modality); + return this; + } + + /** + * Sets the embedding model to use. This is a required field. + * + *

Only one model can be used across all fields in a single vector index definition.

+ * + * @param model the model name (e.g., {@code "voyage-4"}, {@code "voyage-4-large"}, {@code "voyage-4-lite"}, {@code "voyage-code-3"}) + * @return this + * @since 5.8 + */ + public AutoEmbedField model(final String model) { + this.model = notNull("model", model); + return this; + } + + /** + * Sets the number of dimensions for the auto-embedded vector. This is an optional field. + * + *

These map to the number of dimensions supported by the API endpoint (currently 256, 512, 1024, 2048).

+ * + * @param numDimensions the number of vector dimensions + * @return this + * @since 5.8 + */ + public AutoEmbedField numDimensions(final int numDimensions) { + isTrueArgument("numDimensions > 0", numDimensions > 0); + this.numDimensions = numDimensions; + return this; + } + + /** + * Sets the quantization type for the auto-embedded vector. This is an optional field. + * + *

Supported values:

+ *
    + *
  • {@code "float"}
  • + *
  • {@code "scalar"}
  • + *
  • {@code "binary"}
  • + *
  • {@code "binaryNoRescore"}
  • + *
+ * + * @param quantization the quantization type + * @return this + * @since 5.8 + */ + public AutoEmbedField quantization(final String quantization) { + this.quantization = notNull("quantization", quantization); + return this; + } + + /** + * Sets the similarity function used to compare vectors. This is an optional field. + * + *

Supported values:

+ *
    + *
  • {@code "dotProduct"}
  • + *
  • {@code "cosine"}
  • + *
  • {@code "euclidean"}
  • + *
+ * + * @param similarity the similarity function name + * @return this + * @since 5.8 + */ + public AutoEmbedField similarity(final String similarity) { + this.similarity = notNull("similarity", similarity); + return this; + } + + /** + * Sets the indexing method for this auto-embed field. This is an optional field. + * + *

Supported values:

+ *
    + *
  • {@code "flat"} — optimized for multi-tenant use cases with singular, static filters
  • + *
  • {@code "hnsw"} — Hierarchical Navigable Small World graph
  • + *
+ * + * @param indexingMethod the indexing method name + * @return this + * @since 5.8 + */ + public AutoEmbedField indexingMethod(final String indexingMethod) { + this.indexingMethod = notNull("indexingMethod", indexingMethod); + return this; + } + + /** + * Sets the HNSW options for this auto-embed field. This is an optional field. + * + *

This is only applicable when the indexing method is {@code "hnsw"}. + * A convenience builder is available via {@link HnswSearchIndexOptions}, or a raw + * {@link org.bson.Document} may be passed directly.

+ * + * @param hnswOptions the HNSW options + * @return this + * @see HnswSearchIndexOptions + * @since 5.8 + */ + public AutoEmbedField hnswOptions(final Bson hnswOptions) { + this.hnswOptions = notNull("hnswOptions", hnswOptions); + return this; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + isTrueArgument("modality is required for autoEmbed fields", modality != null); + isTrueArgument("model is required for autoEmbed fields", model != null); + BsonDocument doc = new BsonDocument(); + doc.append("type", new BsonString("autoEmbed")); + doc.append("path", new BsonString(path)); + doc.append("modality", new BsonString(modality)); + doc.append("model", new BsonString(model)); + if (numDimensions != null) { + doc.append("numDimensions", new BsonInt32(numDimensions)); + } + if (quantization != null) { + doc.append("quantization", new BsonString(quantization)); + } + if (similarity != null) { + doc.append("similarity", new BsonString(similarity)); + } + if (indexingMethod != null) { + doc.append("indexingMethod", new BsonString(indexingMethod)); + } + if (hnswOptions != null) { + doc.append("hnswOptions", hnswOptions.toBsonDocument(documentClass, codecRegistry)); + } + return doc; + } + + @Override + public String toString() { + return "AutoEmbedField{" + + "path='" + path + '\'' + + ", modality='" + modality + '\'' + + ", model='" + model + '\'' + + ", numDimensions=" + numDimensions + + ", quantization='" + quantization + '\'' + + ", similarity='" + similarity + '\'' + + ", indexingMethod='" + indexingMethod + '\'' + + ", hnswOptions=" + hnswOptions + + '}'; + } + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/search/SearchOperator.java b/driver-core/src/main/com/mongodb/client/model/search/SearchOperator.java index aa8b01b29d4..6e2e2df6b32 100644 --- a/driver-core/src/main/com/mongodb/client/model/search/SearchOperator.java +++ b/driver-core/src/main/com/mongodb/client/model/search/SearchOperator.java @@ -633,6 +633,57 @@ static RegexSearchOperator regex(final Iterable paths, fin .append("query", queryIterator.hasNext() ? queries : firstQuery)); } + /** + * Returns a {@link SearchOperator} that performs vector search within the {@code $search} pipeline stage. + * This is the approximate (ANN) variant with {@code numCandidates}. + * + * @param path The indexed vector field to search. + * @param queryVector The query vector. The number of dimensions must match the index field. + * @param limit The number of results to return. + * @param numCandidates The number of nearest neighbors to consider during ANN search. + * Must be greater than or equal to {@code limit}. The server may impose an upper bound. + * @return The requested {@link VectorSearchOperator}. + * @mongodb.atlas.manual atlas-search/vector-search/ vectorSearch operator + * @since 5.8 + */ + static VectorSearchOperator vectorSearch( + final FieldSearchPath path, + final Iterable queryVector, + final int limit, + final int numCandidates) { + notNull("path", path); + notNull("queryVector", queryVector); + isTrueArgument("numCandidates must be >= limit", numCandidates >= limit); + return new VectorSearchOperatorConstructibleBsonElement("vectorSearch", + new Document("path", path.toValue()) + .append("queryVector", queryVector) + .append("limit", limit) + .append("numCandidates", numCandidates)); + } + + /** + * Returns a {@link SearchOperator} that performs exact (ENN) vector search within the {@code $search} pipeline stage. + * + * @param path The indexed vector field to search. + * @param queryVector The query vector. The number of dimensions must match the index field. + * @param limit The number of results to return. + * @return The requested {@link VectorSearchOperator}. + * @mongodb.atlas.manual atlas-search/vector-search/ vectorSearch operator + * @since 5.8 + */ + static VectorSearchOperator vectorSearchExact( + final FieldSearchPath path, + final Iterable queryVector, + final int limit) { + notNull("path", path); + notNull("queryVector", queryVector); + return new VectorSearchOperatorConstructibleBsonElement("vectorSearch", + new Document("path", path.toValue()) + .append("queryVector", queryVector) + .append("limit", limit) + .append("exact", true)); + } + /** * Creates a {@link SearchOperator} from a {@link Bson} in situations when there is no builder method that better satisfies your needs. * This method cannot be used to validate the syntax. diff --git a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchConstructibleBson.java b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchConstructibleBson.java index 3e281890822..39a043ca82f 100644 --- a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchConstructibleBson.java +++ b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchConstructibleBson.java @@ -17,6 +17,7 @@ import com.mongodb.annotations.Immutable; import com.mongodb.internal.client.model.AbstractConstructibleBson; +import org.bson.BsonBoolean; import org.bson.BsonDocument; import org.bson.Document; import org.bson.conversions.Bson; @@ -45,7 +46,12 @@ protected VectorSearchConstructibleBson newSelf(final Bson base, final Document @Override public VectorSearchOptions filter(final Bson filter) { - return newAppended("filter", notNull("name", filter)); + return newAppended("filter", notNull("filter", filter)); + } + + @Override + public VectorSearchOptions returnStoredSource(final boolean returnStoredSource) { + return newAppended("returnStoredSource", new BsonBoolean(returnStoredSource)); } @Override diff --git a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperator.java b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperator.java new file mode 100644 index 00000000000..a50f626a10c --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperator.java @@ -0,0 +1,48 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model.search; + +import com.mongodb.annotations.Beta; +import com.mongodb.annotations.Reason; +import com.mongodb.annotations.Sealed; + +/** + * A {@link SearchOperator} that performs vector search within the {@code $search} pipeline stage. + * + * @mongodb.atlas.manual atlas-search/operators-and-collectors/#operators Search operators + * @since 5.8 + */ +@Sealed +@Beta(Reason.CLIENT) +public interface VectorSearchOperator extends SearchOperator { + + /** + * Creates a new {@link VectorSearchOperator} with the filter specified. + * + * @param filter A search operator to filter documents. + * @return A new {@link VectorSearchOperator}. + */ + VectorSearchOperator filter(SearchOperator filter); + + /** + * Creates a new {@link VectorSearchOperator} with the scoring modifier specified. + * + * @param modifier The scoring modifier. + * @return A new {@link VectorSearchOperator}. + */ + @Override + VectorSearchOperator score(SearchScore modifier); +} diff --git a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperatorConstructibleBsonElement.java b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperatorConstructibleBsonElement.java new file mode 100644 index 00000000000..ae45cebfcb4 --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOperatorConstructibleBsonElement.java @@ -0,0 +1,49 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model.search; + +import com.mongodb.internal.client.model.AbstractConstructibleBsonElement; +import org.bson.conversions.Bson; + +import static com.mongodb.assertions.Assertions.notNull; + +final class VectorSearchOperatorConstructibleBsonElement + extends AbstractConstructibleBsonElement + implements VectorSearchOperator { + + VectorSearchOperatorConstructibleBsonElement(final String name, final Bson value) { + super(name, value); + } + + private VectorSearchOperatorConstructibleBsonElement(final Bson baseElement, final Bson appendedElementValue) { + super(baseElement, appendedElementValue); + } + + @Override + protected VectorSearchOperatorConstructibleBsonElement newSelf(final Bson baseElement, final Bson appendedElementValue) { + return new VectorSearchOperatorConstructibleBsonElement(baseElement, appendedElementValue); + } + + @Override + public VectorSearchOperator filter(final SearchOperator filter) { + return newWithAppendedValue("filter", notNull("filter", filter)); + } + + @Override + public VectorSearchOperatorConstructibleBsonElement score(final SearchScore modifier) { + return newWithAppendedValue("score", notNull("modifier", modifier)); + } +} diff --git a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOptions.java b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOptions.java index 073c05b2371..d3bcf3aea46 100644 --- a/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOptions.java +++ b/driver-core/src/main/com/mongodb/client/model/search/VectorSearchOptions.java @@ -41,6 +41,16 @@ public interface VectorSearchOptions extends Bson { */ VectorSearchOptions filter(Bson filter); + /** + * Creates a new {@link VectorSearchOptions} that instructs to return only stored source fields. + * + * @param returnStoredSource The option to return only stored source fields. + * @return A new {@link VectorSearchOptions}. + * @mongodb.atlas.manual atlas-vector-search/vector-search-stage/ $vectorSearch + * @since 5.8 + */ + VectorSearchOptions returnStoredSource(boolean returnStoredSource); + /** * Creates a new {@link VectorSearchOptions} with the specified option in situations when there is no builder method * that better satisfies your needs. diff --git a/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java b/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java index 7fd01712ea3..5cb70d4e2ef 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java @@ -33,7 +33,6 @@ import org.junit.jupiter.params.provider.MethodSource; import java.math.RoundingMode; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -43,8 +42,10 @@ import static com.mongodb.client.model.Accumulators.percentile; import static com.mongodb.client.model.Aggregates.geoNear; import static com.mongodb.client.model.Aggregates.group; +import static com.mongodb.client.model.Aggregates.rerank; import static com.mongodb.client.model.Aggregates.unset; import static com.mongodb.client.model.Aggregates.vectorSearch; +import static com.mongodb.client.model.RerankQuery.rerankQuery; import static com.mongodb.client.model.GeoNearOptions.geoNearOptions; import static com.mongodb.client.model.Sorts.ascending; import static com.mongodb.client.model.Windows.Bound.UNBOUNDED; @@ -260,7 +261,7 @@ public void testDocuments() { "{$documents: [{a: 1, b: {$add: [1, 1]}}, {a: 3, b: 4}]}", stage); - List pipeline = Arrays.asList(stage); + List pipeline = asList(stage); getCollectionHelper().aggregateDb(pipeline); assertEquals( @@ -268,9 +269,9 @@ public void testDocuments() { getCollectionHelper().aggregateDb(pipeline)); // accepts lists of Documents and BsonDocuments - List documents = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}")); + List documents = asList(BsonDocument.parse("{a: 1, b: 2}")); assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(documents)); - List bsonDocuments = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}")); + List bsonDocuments = asList(BsonDocument.parse("{a: 1, b: 2}")); assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(bsonDocuments)); } @@ -281,13 +282,13 @@ public void testDocumentsLookup() { getCollectionHelper().insertDocuments("[{_id: 1, a: 8}, {_id: 2, a: 9}]"); Bson documentsStage = Aggregates.documents(asList(Document.parse("{a: 5}"))); - Bson lookupStage = Aggregates.lookup(null, Arrays.asList(documentsStage), "added"); + Bson lookupStage = Aggregates.lookup(null, asList(documentsStage), "added"); assertPipeline( "{'$lookup': {'pipeline': [{'$documents': [{'a': 5}]}], 'as': 'added'}}", lookupStage); assertEquals( parseToList("[{_id:1, a:8, added: [{a: 5}]}, {_id:2, a:9, added: [{a: 5}]}]"), - getCollectionHelper().aggregate(Arrays.asList(lookupStage))); + getCollectionHelper().aggregate(asList(lookupStage))); } @Test @@ -374,4 +375,82 @@ public void testExactVectorSearchWithQueryObject() { exactVectorSearchOptions() )); } + + @Test + public void testRerankWithSinglePath() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials'}," + + " 'path': 'content'," + + " 'numDocsToRerank': 25," + + " 'model': 'rerank-2.5'" + + " }" + + "}", + rerank( + rerankQuery("machine learning tutorials"), + "content", + 25, + "rerank-2.5" + )); + } + + @Test + public void testRerankWithMultiplePaths() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials'}," + + " 'path': ['content', 'title']," + + " 'numDocsToRerank': 50," + + " 'model': 'rerank-2.5-lite'" + + " }" + + "}", + rerank( + rerankQuery("machine learning tutorials"), + asList("content", "title"), + 50, + "rerank-2.5-lite" + )); + } + + @Test + public void testRerankWithBsonQuery() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'}," + + " 'path': 'content'," + + " 'numDocsToRerank': 25," + + " 'model': 'rerank-2.5'" + + " }" + + "}", + rerank( + rerankQuery(new Document("text", "machine learning tutorials") + .append("imageURL", "https://example.com/img.png")), + "content", + 25, + "rerank-2.5" + )); + } + + @Test + public void testRerankWithMultiplePathsAndBsonQuery() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'}," + + " 'path': ['content', 'title']," + + " 'numDocsToRerank': 100," + + " 'model': 'rerank-2'" + + " }" + + "}", + rerank( + rerankQuery(new Document("text", "machine learning tutorials") + .append("imageURL", "https://example.com/img.png")), + asList("content", "title"), + 100, + "rerank-2" + )); + } } diff --git a/driver-core/src/test/unit/com/mongodb/client/model/HnswSearchIndexOptionsTest.java b/driver-core/src/test/unit/com/mongodb/client/model/HnswSearchIndexOptionsTest.java new file mode 100644 index 00000000000..dc55e1c2e80 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/client/model/HnswSearchIndexOptionsTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model; + +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +final class HnswSearchIndexOptionsTest { + + @Test + void emptyOptions() { + assertEquals( + new BsonDocument(), + new HnswSearchIndexOptions().toBsonDocument() + ); + } + + @Test + void maxEdgesOnly() { + assertEquals( + new BsonDocument("maxEdges", new BsonInt32(24)), + new HnswSearchIndexOptions().maxEdges(24).toBsonDocument() + ); + } + + @Test + void numEdgeCandidatesOnly() { + assertEquals( + new BsonDocument("numEdgeCandidates", new BsonInt32(150)), + new HnswSearchIndexOptions().numEdgeCandidates(150).toBsonDocument() + ); + } + + @Test + void allOptions() { + assertEquals( + new BsonDocument("maxEdges", new BsonInt32(16)) + .append("numEdgeCandidates", new BsonInt32(200)), + new HnswSearchIndexOptions().maxEdges(16).numEdgeCandidates(200).toBsonDocument() + ); + } + + @Test + void maxEdgesRejectsZero() { + assertThrows(IllegalArgumentException.class, () -> new HnswSearchIndexOptions().maxEdges(0)); + } + + @Test + void maxEdgesRejectsNegative() { + assertThrows(IllegalArgumentException.class, () -> new HnswSearchIndexOptions().maxEdges(-1)); + } + + @Test + void numEdgeCandidatesRejectsZero() { + assertThrows(IllegalArgumentException.class, () -> new HnswSearchIndexOptions().numEdgeCandidates(0)); + } + + @Test + void numEdgeCandidatesRejectsNegative() { + assertThrows(IllegalArgumentException.class, () -> new HnswSearchIndexOptions().numEdgeCandidates(-1)); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexDefinitionTest.java b/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexDefinitionTest.java new file mode 100644 index 00000000000..31c34d82aa8 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexDefinitionTest.java @@ -0,0 +1,176 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model; + +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.bson.Document; +import org.junit.jupiter.api.Test; + +import static com.mongodb.client.model.VectorSearchIndexFields.filterField; +import static com.mongodb.client.model.VectorSearchIndexFields.vectorField; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +final class SearchIndexDefinitionTest { + + @Test + void vectorSearchWithSingleVectorField() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("plot_embedding") + .numDimensions(1536) + .similarity("euclidean") + ); + + assertEquals( + new BsonDocument("fields", new BsonArray(asList( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("plot_embedding")) + .append("numDimensions", new BsonInt32(1536)) + .append("similarity", new BsonString("euclidean")) + ))), + definition.toBsonDocument() + ); + } + + @Test + void vectorSearchWithMultipleFields() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("embedding") + .numDimensions(1536) + .similarity("euclidean") + .indexingMethod("flat"), + filterField("tenantId") + ); + + assertEquals( + new BsonDocument("fields", new BsonArray(asList( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("embedding")) + .append("numDimensions", new BsonInt32(1536)) + .append("similarity", new BsonString("euclidean")) + .append("indexingMethod", new BsonString("flat")), + new BsonDocument("type", new BsonString("filter")) + .append("path", new BsonString("tenantId")) + ))), + definition.toBsonDocument() + ); + } + + @Test + void vectorSearchWithListOfFields() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch(asList( + vectorField("embedding") + .numDimensions(1536) + .similarity("euclidean"), + filterField("category") + )); + + assertEquals( + new BsonDocument("fields", new BsonArray(asList( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("embedding")) + .append("numDimensions", new BsonInt32(1536)) + .append("similarity", new BsonString("euclidean")), + new BsonDocument("type", new BsonString("filter")) + .append("path", new BsonString("category")) + ))), + definition.toBsonDocument() + ); + } + + @Test + void vectorSearchWithRawBsonField() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + new Document("type", "vector") + .append("path", "raw_field") + .append("numDimensions", 512) + .append("similarity", "cosine") + ); + + assertEquals( + new BsonDocument("fields", new BsonArray(asList( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("raw_field")) + .append("numDimensions", new BsonInt32(512)) + .append("similarity", new BsonString("cosine")) + ))), + definition.toBsonDocument() + ); + } + + @Test + void vectorSearchRejectsNullVarargs() { + assertThrows(IllegalArgumentException.class, () -> SearchIndexDefinition.vectorSearch((org.bson.conversions.Bson[]) null)); + } + + @Test + void vectorSearchRejectsNullList() { + assertThrows(IllegalArgumentException.class, () -> SearchIndexDefinition.vectorSearch((java.util.List) null)); + } + + @Test + void vectorSearchRejectsNullElement() { + assertThrows(IllegalArgumentException.class, () -> SearchIndexDefinition.vectorSearch( + vectorField("embedding"), null)); + } + + @Test + void vectorSearchRejectsEmptyVarargs() { + assertThrows(IllegalArgumentException.class, SearchIndexDefinition::vectorSearch); + } + + @Test + void vectorSearchRejectsEmptyList() { + assertThrows(IllegalArgumentException.class, () -> SearchIndexDefinition.vectorSearch(emptyList())); + } + + @Test + void vectorSearchWithStoredSource() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("embedding") + .numDimensions(1536) + .similarity("cosine") + ).storedSource(new Document("include", asList("plot", "title"))); + + assertEquals( + new BsonDocument("fields", new BsonArray(asList( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("embedding")) + .append("numDimensions", new BsonInt32(1536)) + .append("similarity", new BsonString("cosine")) + ))).append("storedSource", new BsonDocument("include", new BsonArray(asList( + new BsonString("plot"), + new BsonString("title") + )))), + definition.toBsonDocument() + ); + } + + @Test + void vectorSearchStoredSourceRejectsNull() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("embedding") + .numDimensions(1536) + .similarity("cosine") + ); + assertThrows(IllegalArgumentException.class, () -> definition.storedSource(null)); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexModelTest.java b/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexModelTest.java new file mode 100644 index 00000000000..efd82bd2209 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/client/model/SearchIndexModelTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model; + +import org.bson.BsonString; +import org.junit.jupiter.api.Test; + +import static com.mongodb.client.model.VectorSearchIndexFields.vectorField; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +final class SearchIndexModelTest { + + @Test + void vectorSearchConstructorSetsType() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("embedding").numDimensions(1536).similarity("cosine") + ); + + SearchIndexModel model = new SearchIndexModel("my_index", definition); + + assertEquals("my_index", model.getName()); + assertEquals(definition, model.getDefinition()); + assertNotNull(model.getType()); + assertEquals(new BsonString("vectorSearch"), model.getType().toBsonValue()); + } + + @Test + void vectorSearchConstructorWithMultipleFields() { + VectorSearchIndexDefinition definition = SearchIndexDefinition.vectorSearch( + vectorField("embedding").numDimensions(768).similarity("euclidean"), + VectorSearchIndexFields.filterField("category") + ); + + SearchIndexModel model = new SearchIndexModel("vector_idx", definition); + + assertEquals("vector_idx", model.getName()); + assertEquals(definition, model.getDefinition()); + assertEquals(new BsonString("vectorSearch"), model.getType().toBsonValue()); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/client/model/VectorSearchIndexFieldsTest.java b/driver-core/src/test/unit/com/mongodb/client/model/VectorSearchIndexFieldsTest.java new file mode 100644 index 00000000000..d24b656b6cc --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/client/model/VectorSearchIndexFieldsTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.client.model; + +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.bson.Document; +import org.junit.jupiter.api.Test; + +import static com.mongodb.client.model.VectorSearchIndexFields.autoEmbedField; +import static com.mongodb.client.model.VectorSearchIndexFields.filterField; +import static com.mongodb.client.model.VectorSearchIndexFields.vectorField; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +final class VectorSearchIndexFieldsTest { + + @Test + void vectorFieldMinimal() { + assertEquals( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("vec")), + vectorField("vec").toBsonDocument() + ); + } + + @Test + void vectorFieldAllOptions() { + assertEquals( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("embedding")) + .append("numDimensions", new BsonInt32(1536)) + .append("similarity", new BsonString("cosine")) + .append("indexingMethod", new BsonString("hnsw")) + .append("hnswOptions", new BsonDocument("maxEdges", new BsonInt32(16))), + vectorField("embedding") + .numDimensions(1536) + .similarity("cosine") + .indexingMethod("hnsw") + .hnswOptions(new HnswSearchIndexOptions().maxEdges(16)) + .toBsonDocument() + ); + } + + @Test + void vectorFieldWithRawBsonHnswOptions() { + assertEquals( + new BsonDocument("type", new BsonString("vector")) + .append("path", new BsonString("vec")) + .append("indexingMethod", new BsonString("hnsw")) + .append("hnswOptions", new BsonDocument("maxEdges", new BsonInt32(32))), + vectorField("vec") + .indexingMethod("hnsw") + .hnswOptions(new Document("maxEdges", 32)) + .toBsonDocument() + ); + } + + @Test + void vectorFieldNumDimensionsRejectsZero() { + assertThrows(IllegalArgumentException.class, () -> vectorField("vec").numDimensions(0)); + } + + @Test + void vectorFieldNumDimensionsRejectsNegative() { + assertThrows(IllegalArgumentException.class, () -> vectorField("vec").numDimensions(-1)); + } + + @Test + void vectorFieldRejectsNullPath() { + assertThrows(IllegalArgumentException.class, () -> vectorField(null)); + } + + @Test + void vectorFieldRejectsNullSimilarity() { + assertThrows(IllegalArgumentException.class, () -> vectorField("vec").similarity(null)); + } + + @Test + void vectorFieldRejectsNullIndexingMethod() { + assertThrows(IllegalArgumentException.class, () -> vectorField("vec").indexingMethod(null)); + } + + @Test + void vectorFieldRejectsNullHnswOptions() { + assertThrows(IllegalArgumentException.class, () -> vectorField("vec").hnswOptions(null)); + } + + @Test + void filterFieldProducesCorrectBson() { + assertEquals( + new BsonDocument("type", new BsonString("filter")) + .append("path", new BsonString("status")), + filterField("status").toBsonDocument() + ); + } + + @Test + void filterFieldRejectsNullPath() { + assertThrows(IllegalArgumentException.class, () -> filterField(null)); + } + + @Test + void autoEmbedFieldMinimal() { + assertEquals( + new BsonDocument("type", new BsonString("autoEmbed")) + .append("path", new BsonString("content")) + .append("modality", new BsonString("text")) + .append("model", new BsonString("voyage-4")), + autoEmbedField("content").modality("text").model("voyage-4").toBsonDocument() + ); + } + + @Test + void autoEmbedFieldRejectsMissingModality() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("content").model("voyage-4").toBsonDocument()); + } + + @Test + void autoEmbedFieldRejectsMissingModel() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("content").modality("text").toBsonDocument()); + } + + @Test + void autoEmbedFieldAllOptions() { + assertEquals( + new BsonDocument("type", new BsonString("autoEmbed")) + .append("path", new BsonString("product.description")) + .append("modality", new BsonString("text")) + .append("model", new BsonString("voyage-4-large")) + .append("numDimensions", new BsonInt32(256)) + .append("quantization", new BsonString("binary")) + .append("similarity", new BsonString("euclidean")) + .append("indexingMethod", new BsonString("hnsw")) + .append("hnswOptions", new BsonDocument("maxEdges", new BsonInt32(16))), + autoEmbedField("product.description") + .modality("text") + .model("voyage-4-large") + .numDimensions(256) + .quantization("binary") + .similarity("euclidean") + .indexingMethod("hnsw") + .hnswOptions(new HnswSearchIndexOptions().maxEdges(16)) + .toBsonDocument() + ); + } + + @Test + void autoEmbedFieldNumDimensionsRejectsZero() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").numDimensions(0)); + } + + @Test + void autoEmbedFieldNumDimensionsRejectsNegative() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").numDimensions(-1)); + } + + @Test + void autoEmbedFieldRejectsNullPath() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField(null)); + } + + @Test + void autoEmbedFieldRejectsNullModality() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").modality(null)); + } + + @Test + void autoEmbedFieldRejectsNullModel() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").model(null)); + } + + @Test + void autoEmbedFieldRejectsNullQuantization() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").quantization(null)); + } + + @Test + void autoEmbedFieldRejectsNullSimilarity() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").similarity(null)); + } + + @Test + void autoEmbedFieldRejectsNullIndexingMethod() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").indexingMethod(null)); + } + + @Test + void autoEmbedFieldRejectsNullHnswOptions() { + assertThrows(IllegalArgumentException.class, () -> autoEmbedField("text").hnswOptions(null)); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/client/model/search/BinaryVectorSearchOptionsTest.java b/driver-core/src/test/unit/com/mongodb/client/model/search/BinaryVectorSearchOptionsTest.java index 1fde037dbef..952974b8edd 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/search/BinaryVectorSearchOptionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/client/model/search/BinaryVectorSearchOptionsTest.java @@ -110,6 +110,30 @@ void optionsExact() { ); } + @Test + void returnStoredSourceApproximate() { + assertEquals( + new BsonDocument() + .append("returnStoredSource", new BsonBoolean(true)) + .append("numCandidates", new BsonInt64(1)), + VectorSearchOptions.approximateVectorSearchOptions(1) + .returnStoredSource(true) + .toBsonDocument() + ); + } + + @Test + void returnStoredSourceExact() { + assertEquals( + new BsonDocument() + .append("returnStoredSource", new BsonBoolean(true)) + .append("exact", new BsonBoolean(true)), + VectorSearchOptions.exactVectorSearchOptions() + .returnStoredSource(true) + .toBsonDocument() + ); + } + @Test void approximateVectorSearchOptionsIsUnmodifiable() { String expected = VectorSearchOptions.approximateVectorSearchOptions(1).toBsonDocument().toJson(); diff --git a/driver-core/src/test/unit/com/mongodb/client/model/search/SearchOperatorTest.java b/driver-core/src/test/unit/com/mongodb/client/model/search/SearchOperatorTest.java index ccf5a44cd1f..88cbad0fc42 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/search/SearchOperatorTest.java +++ b/driver-core/src/test/unit/com/mongodb/client/model/search/SearchOperatorTest.java @@ -16,6 +16,7 @@ package com.mongodb.client.model.search; import com.mongodb.MongoClientSettings; +import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.geojson.Point; import com.mongodb.client.model.geojson.Position; import org.bson.BsonArray; @@ -1002,6 +1003,129 @@ void regex() { ); } + @Test + void vectorSearch() { + assertAll( + () -> assertThrows(IllegalArgumentException.class, () -> + // path must not be null + SearchOperator.vectorSearch(null, asList(1.0), 10, 50) + ), + () -> assertThrows(IllegalArgumentException.class, () -> + // queryVector must not be null + SearchOperator.vectorSearch(fieldPath("embedding"), null, 10, 50) + ), + () -> assertThrows(IllegalArgumentException.class, () -> + // numCandidates must be >= limit + SearchOperator.vectorSearch(fieldPath("embedding"), asList(1.0), 100, 50) + ), + () -> assertEquals( + new BsonDocument("vectorSearch", + new BsonDocument("path", new BsonString("embedding")) + .append("queryVector", new BsonArray(asList( + new BsonDouble(1.0), new BsonDouble(2.0), new BsonDouble(3.0)))) + .append("limit", new BsonInt32(10)) + .append("numCandidates", new BsonInt32(100))), + SearchOperator.vectorSearch( + fieldPath("embedding"), + asList(1.0, 2.0, 3.0), + 10, + 100 + ).toBsonDocument() + ), + () -> assertEquals( + new BsonDocument("vectorSearch", + new BsonDocument("path", new BsonString("embedding")) + .append("queryVector", new BsonArray(asList( + new BsonDouble(1.0), new BsonDouble(2.0)))) + .append("limit", new BsonInt32(10)) + .append("numCandidates", new BsonInt32(50)) + .append("filter", new BsonDocument("text", + new BsonDocument("query", new BsonString("hello")) + .append("path", new BsonString("title")))) + .append("score", new BsonDocument("boost", + new BsonDocument("value", new BsonDouble(2.0))))), + SearchOperator.vectorSearch( + fieldPath("embedding"), + asList(1.0, 2.0), + 10, + 50 + ).filter(SearchOperator.text(fieldPath("title"), "hello")) + .score(boost(2f)) + .toBsonDocument() + ) + ); + } + + @Test + void vectorSearchExact() { + assertAll( + () -> assertThrows(IllegalArgumentException.class, () -> + // path must not be null + SearchOperator.vectorSearchExact(null, asList(1.0), 10) + ), + () -> assertThrows(IllegalArgumentException.class, () -> + // queryVector must not be null + SearchOperator.vectorSearchExact(fieldPath("embedding"), null, 10) + ), + () -> assertEquals( + new BsonDocument("vectorSearch", + new BsonDocument("path", new BsonString("embedding")) + .append("queryVector", new BsonArray(asList( + new BsonDouble(1.0), new BsonDouble(2.0), new BsonDouble(3.0)))) + .append("limit", new BsonInt32(5)) + .append("exact", BsonBoolean.TRUE)), + SearchOperator.vectorSearchExact( + fieldPath("embedding"), + asList(1.0, 2.0, 3.0), + 5 + ).toBsonDocument() + ), + () -> assertEquals( + new BsonDocument("vectorSearch", + new BsonDocument("path", new BsonString("embedding")) + .append("queryVector", new BsonArray(asList( + new BsonDouble(1.0), new BsonDouble(2.0)))) + .append("limit", new BsonInt32(10)) + .append("exact", BsonBoolean.TRUE) + .append("filter", new BsonDocument("text", + new BsonDocument("query", new BsonString("hello")) + .append("path", new BsonString("title")))) + .append("score", new BsonDocument("boost", + new BsonDocument("value", new BsonDouble(2.0))))), + SearchOperator.vectorSearchExact( + fieldPath("embedding"), + asList(1.0, 2.0), + 10 + ).filter(SearchOperator.text(fieldPath("title"), "hello")) + .score(boost(2f)) + .toBsonDocument() + ) + ); + } + + @Test + void vectorSearchInsideSearchStage() { + assertEquals( + new BsonDocument("$search", + new BsonDocument("index", new BsonString("myIndex")) + .append("vectorSearch", + new BsonDocument("path", new BsonString("embedding")) + .append("queryVector", new BsonArray(asList( + new BsonDouble(1.0), new BsonDouble(2.0), new BsonDouble(3.0)))) + .append("limit", new BsonInt32(10)) + .append("numCandidates", new BsonInt32(100)))), + Aggregates.search( + SearchOperator.vectorSearch( + fieldPath("embedding"), + asList(1.0, 2.0, 3.0), + 10, + 100 + ), + SearchOptions.searchOptions().index("myIndex") + ).toBsonDocument() + ); + } + private static SearchOperator docExamplePredefined() { return SearchOperator.exists( fieldPath("fieldName")); diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java index b988f1cde1a..d5128be8dc9 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java @@ -22,6 +22,7 @@ import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.ByteBufNIO; +import org.bson.io.BasicOutputBuffer; import org.bson.io.ByteBufferBsonInput; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.params.ParameterizedTest; @@ -45,6 +46,7 @@ import static java.util.stream.Collectors.toList; import static java.util.stream.IntStream.range; import static java.util.stream.IntStream.rangeClosed; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -710,6 +712,58 @@ void shouldReadSkipCStringWhenMultipleNullTerminatorPresentWithinBuffer(final Bu } + @ParameterizedTest(name = "should pipe bytes to output. BufferProvider={0}") + @MethodSource("bufferProviders") + void shouldPipeBytesToOutput(final BufferProvider bufferProvider) { + // given + byte[] input = "Java!".getBytes(StandardCharsets.UTF_8); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, input); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer); + BasicOutputBuffer bufferOutput = new BasicOutputBuffer()) { + // when + bufferInput.pipe(bufferOutput, input.length); + + // then + assertEquals(input.length, bufferInput.getPosition()); + assertEquals(input.length, bufferOutput.getPosition()); + assertArrayEquals(input, bufferOutput.toByteArray()); + } + } + + @ParameterizedTest(name = "should pipe partial bytes to output. BufferProvider={0}") + @MethodSource("bufferProviders") + void shouldPipePartialBytesToOutput(final BufferProvider bufferProvider) { + // given + byte[] input = "Java!".getBytes(StandardCharsets.UTF_8); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, input); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer); + BasicOutputBuffer output = new BasicOutputBuffer()) { + // when + bufferInput.pipe(output, 3); + + // then + assertEquals(3, bufferInput.getPosition()); + assertEquals(3, output.getPosition()); + assertArrayEquals("Jav".getBytes(StandardCharsets.UTF_8), output.toByteArray()); + } + } + + @ParameterizedTest(name = "should throw when piping more bytes than available. BufferProvider={0}") + @MethodSource("bufferProviders") + void shouldThrowWhenPipingMoreBytesThanAvailable(final BufferProvider bufferProvider) { + // given + byte[] input = "Jav".getBytes(StandardCharsets.UTF_8); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, input); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer); + BasicOutputBuffer output = new BasicOutputBuffer()) { + // when & then + assertThrows(BsonSerializationException.class, () -> bufferInput.pipe(output, 10)); + } + } + private static ByteBuf allocateAndWriteToBuffer(final BufferProvider bufferProvider, final byte[] input) { ByteBuf buffer = bufferProvider.getBuffer(input.length); buffer.put(input, 0, input.length); diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java index bf2e6205ad6..1ce8c844459 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/unified/MicrometerTracingTest.java @@ -16,12 +16,21 @@ package com.mongodb.reactivestreams.client.unified; +import com.mongodb.client.unified.UnifiedTestModifications.TestDef; import org.junit.jupiter.params.provider.Arguments; import java.util.Collection; +import static com.mongodb.client.Fixture.getMongoClient; + final class MicrometerTracingTest extends UnifiedReactiveStreamsTest { private static Collection data() { return getTestData("open-telemetry/tests"); } + + @Override + protected void postCleanUp(final TestDef testDef) { + super.postCleanUp(testDef); + getEntities().getDatabaseNames().forEach(name -> getMongoClient().getDatabase(name).drop()); + } } diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala index c7b8d120cf7..31c8c65ec79 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala @@ -18,10 +18,12 @@ package org.mongodb.scala.model import com.mongodb.annotations.{ Beta, Reason } import com.mongodb.client.model.fill.FillOutputField -import com.mongodb.client.model.search.FieldSearchPath +import com.mongodb.client.model.search.{ FieldSearchPath, VectorSearchQuery } +import org.bson.BinaryVector import scala.collection.JavaConverters._ import com.mongodb.client.model.{ Aggregates => JAggregates } +import com.mongodb.client.model.RerankQuery import org.mongodb.scala.MongoNamespace import org.mongodb.scala.bson.conversions.Bson import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange } @@ -739,12 +741,115 @@ object Aggregates { */ def vectorSearch( path: FieldSearchPath, - queryVector: Iterable[java.lang.Double], + queryVector: Iterable[Double], index: String, limit: Long, options: VectorSearchOptions ): Bson = - JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options) + JAggregates.vectorSearch( + path, + queryVector.asInstanceOf[Iterable[java.lang.Double]].asJava, + index, + limit, + options + ) + + /** + * Creates a `\$vectorSearch` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "vectorSearchScore"` expression, e.g., via [[Projections.metaVectorSearchScore]], + * to extract the relevance score assigned to each found document. + * + * @param path The field to be searched. + * @param queryVector The `BinaryVector` query vector. The number of dimensions must match that of the `index`. + * @param index The name of the index to use. + * @param limit The limit on the number of documents produced by the pipeline stage. + * @param options Optional `\$vectorSearch` pipeline stage fields. + * @return The `\$vectorSearch` pipeline stage. + * @see [[https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ \$vectorSearch]] + * @note Requires MongoDB 6.0.10 or greater + * @since 5.8 + */ + def vectorSearch( + path: FieldSearchPath, + queryVector: BinaryVector, + index: String, + limit: Long, + options: VectorSearchOptions + ): Bson = + JAggregates.vectorSearch(path, queryVector, index, limit, options) + + /** + * Creates a `\$vectorSearch` pipeline stage supported by MongoDB Atlas with automated embedding. + * You may use the `\$meta: "vectorSearchScore"` expression, e.g., via [[Projections.metaVectorSearchScore]], + * to extract the relevance score assigned to each found document. + * + * This overload is used for auto-embedding in Atlas. The server will automatically generate embeddings + * for the query using the model specified in the index definition or via + * `TextVectorSearchQuery.model`. + * + * @param path The field to be searched. + * @param query The query specification, typically created via `VectorSearchQuery.textQuery`. + * @param index The name of the index to use. + * @param limit The limit on the number of documents produced by the pipeline stage. + * @param options Optional `\$vectorSearch` pipeline stage fields. + * @return The `\$vectorSearch` pipeline stage. + * @see [[https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ \$vectorSearch]] + * @note Requires MongoDB 6.0.10 or greater + * @since 5.8 + */ + @Beta(Array(Reason.SERVER)) + def vectorSearch( + path: FieldSearchPath, + query: VectorSearchQuery, + index: String, + limit: Long, + options: VectorSearchOptions + ): Bson = + JAggregates.vectorSearch(path, query, index, limit, options) + + /** + * Creates a `\$rerank` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "score"` expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against. + * @param path The document field to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. + * @return The `\$rerank` pipeline stage. + * @note Requires MongoDB on Atlas 8.3 or greater + * @since 5.8 + */ + @Beta(Array(Reason.SERVER)) + def rerank( + query: RerankQuery, + path: String, + numDocsToRerank: Int, + model: String + ): Bson = + JAggregates.rerank(query, path, numDocsToRerank, model) + + /** + * Creates a `\$rerank` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "score"` expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against. + * @param paths The document field(s) to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. + * @return The `\$rerank` pipeline stage. + * @note Requires MongoDB on Atlas 8.3 or greater + * @since 5.8 + */ + @Beta(Array(Reason.SERVER)) + def rerank( + query: RerankQuery, + paths: Seq[String], + numDocsToRerank: Int, + model: String + ): Bson = + JAggregates.rerank(query, paths.toList.asJava, numDocsToRerank, model) /** * Creates an `\$unset` pipeline stage that removes/excludes fields from documents diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala index 0d23a38c2e8..7a920092581 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala @@ -481,6 +481,99 @@ package object model { */ type SearchIndexModel = com.mongodb.client.model.SearchIndexModel + /** + * A definition for an Atlas Search index. + * @since 5.8 + */ + type SearchIndexDefinition = com.mongodb.client.model.SearchIndexDefinition + + /** + * Companion object providing Scala-friendly factories for [[SearchIndexDefinition]]. + * @since 5.8 + */ + object SearchIndexDefinition { + + /** + * Creates a vector search index definition with the specified fields. + * + * @param field the first field for the vector search index. + * @param fields additional fields for the vector search index. + * @return a new [[VectorSearchIndexDefinition]] + */ + def vectorSearch(field: Bson, fields: Bson*): VectorSearchIndexDefinition = + com.mongodb.client.model.SearchIndexDefinition.vectorSearch(field +: fields: _*) + + /** + * Creates a vector search index definition with the specified fields. + * + * @param fields the fields for the vector search index. + * @return a new [[VectorSearchIndexDefinition]] + */ + def vectorSearch(fields: Seq[_ <: Bson]): VectorSearchIndexDefinition = { + com.mongodb.client.model.SearchIndexDefinition.vectorSearch(fields.asJava) + } + } + + /** + * A vector search index definition. + * @since 5.8 + */ + type VectorSearchIndexDefinition = com.mongodb.client.model.VectorSearchIndexDefinition + + /** + * A factory for defining fields within a vector search index definition. + * @since 5.8 + */ + type VectorSearchIndexFields = com.mongodb.client.model.VectorSearchIndexFields + + /** + * Companion object providing Scala-friendly factories for [[VectorSearchIndexFields]]. + * @since 5.8 + */ + object VectorSearchIndexFields { + + /** + * Creates a vector field definition for a vector search index. + * + * @param path the field path in the document + * @return a new `VectorSearchIndexFields.VectorField` + */ + def vectorField(path: String): com.mongodb.client.model.VectorSearchIndexFields.VectorField = + com.mongodb.client.model.VectorSearchIndexFields.vectorField(path) + + /** + * Creates a filter field definition for a vector search index. + * + * @param path the field path in the document + * @return a new `VectorSearchIndexFields.FilterField` + */ + def filterField(path: String): com.mongodb.client.model.VectorSearchIndexFields.FilterField = + com.mongodb.client.model.VectorSearchIndexFields.filterField(path) + + /** + * Creates an auto-embed field definition for a vector search index. + * + * @param path the field path in the document containing the content to embed + * @return a new `VectorSearchIndexFields.AutoEmbedField` + */ + def autoEmbedField(path: String): com.mongodb.client.model.VectorSearchIndexFields.AutoEmbedField = + com.mongodb.client.model.VectorSearchIndexFields.autoEmbedField(path) + } + + /** + * Options for the HNSW indexing method in a vector search index. + * @since 5.8 + */ + type HnswSearchIndexOptions = com.mongodb.client.model.HnswSearchIndexOptions + + /** + * Companion object providing a Scala-friendly factory for [[HnswSearchIndexOptions]]. + * @since 5.8 + */ + object HnswSearchIndexOptions { + def apply(): HnswSearchIndexOptions = new com.mongodb.client.model.HnswSearchIndexOptions() + } + /** * Represents an Atlas Search Index type, which is utilized for creating specific types of indexes. */ @@ -513,6 +606,19 @@ package object model { def apply(indexName: String, definition: Bson): SearchIndexModel = new com.mongodb.client.model.SearchIndexModel(indexName, definition) + /** + * Construct a vector search index instance with the given name and definition. + * + * The index type is automatically set to `vectorSearch`. + * + * @param indexName the name of the search index to create. + * @param definition the vector search index definition. + * @return the SearchIndexModel + * @since 5.8 + */ + def apply(indexName: String, definition: VectorSearchIndexDefinition): SearchIndexModel = + new com.mongodb.client.model.SearchIndexModel(indexName, definition) + /** * Construct an instance with the given search index name and definition. * @@ -987,6 +1093,8 @@ package object model { type GeoNearOptions = com.mongodb.client.model.GeoNearOptions + type RerankQuery = com.mongodb.client.model.RerankQuery + /** * @see `QuantileMethod.approximate()` */ diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/search/SearchOperator.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/search/SearchOperator.scala index 1fa47a54e1b..82b5ea4a05f 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/search/SearchOperator.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/search/SearchOperator.scala @@ -495,6 +495,44 @@ object SearchOperator { def regex(paths: Iterable[_ <: SearchPath], queries: Iterable[String]): RegexSearchOperator = JSearchOperator.regex(paths.asJava, queries.asJava) + /** + * Returns a `SearchOperator` that performs vector search within the `\$search` pipeline stage. + * This is the approximate (ANN) variant with `numCandidates`. + * + * @param path The indexed vector field to search. + * @param queryVector The query vector. The number of dimensions must match the index field. + * @param limit The number of results to return. + * @param numCandidates The number of nearest neighbors to consider during ANN search. + * Must be greater than or equal to `limit`. The server may impose an upper bound. + * @return The requested `VectorSearchOperator`. + * @see [[https://www.mongodb.com/docs/atlas/atlas-search/vector-search/ vectorSearch operator]] + * @since 5.8 + */ + def vectorSearch( + path: FieldSearchPath, + queryVector: Iterable[Double], + limit: Int, + numCandidates: Int + ): VectorSearchOperator = + JSearchOperator.vectorSearch(path, queryVector.map(Double.box).asJava, limit, numCandidates) + + /** + * Returns a `SearchOperator` that performs exact (ENN) vector search within the `\$search` pipeline stage. + * + * @param path The indexed vector field to search. + * @param queryVector The query vector. The number of dimensions must match the index field. + * @param limit The number of results to return. + * @return The requested `VectorSearchOperator`. + * @see [[https://www.mongodb.com/docs/atlas/atlas-search/vector-search/ vectorSearch operator]] + * @since 5.8 + */ + def vectorSearchExact( + path: FieldSearchPath, + queryVector: Iterable[Double], + limit: Int + ): VectorSearchOperator = + JSearchOperator.vectorSearchExact(path, queryVector.map(Double.box).asJava, limit) + /** * Creates a `SearchOperator` from a `Bson` in situations when there is no builder method that better satisfies your needs. * This method cannot be used to validate the syntax. diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/search/package.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/search/package.scala index baa454b1ee7..01ddaffb29f 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/search/package.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/search/package.scala @@ -234,6 +234,16 @@ package object search { @Beta(Array(Reason.CLIENT)) type QueryStringSearchOperator = com.mongodb.client.model.search.QueryStringSearchOperator + /** + * A `SearchOperator` that performs vector search within the `\$search` pipeline stage. + * + * @see `SearchOperator.vectorSearch` + * @since 5.8 + */ + @Sealed + @Beta(Array(Reason.CLIENT)) + type VectorSearchOperator = com.mongodb.client.model.search.VectorSearchOperator + /** * Fuzzy search options that may be used with some [[SearchOperator]]s. * diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala index d5a38ad7bca..4969b149699 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala @@ -37,6 +37,10 @@ import org.mongodb.scala.model.geojson.{ Point, Position } import org.mongodb.scala.model.search.SearchCount.total import org.mongodb.scala.model.search.SearchFacet.stringFacet import org.mongodb.scala.model.search.SearchHighlight.paths +import com.mongodb.client.model.{ Aggregates => JAggregates } +import com.mongodb.client.model.RerankQuery +import com.mongodb.client.model.search.VectorSearchQuery +import org.bson.BinaryVector import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchOperator.exists import org.mongodb.scala.model.search.SearchOptions.searchOptions @@ -816,6 +820,94 @@ class AggregatesSpec extends BaseSpec { ) } + it should "render $vectorSearch with BinaryVector" in { + toBson( + Aggregates.vectorSearch( + fieldPath("fieldName"), + BinaryVector.int8Vector(Array[Byte](0, 1, 2, 3, 4)), + "indexName", + 1, + approximateVectorSearchOptions(2) + ) + ) should equal( + toBson( + JAggregates.vectorSearch( + fieldPath("fieldName"), + BinaryVector.int8Vector(Array[Byte](0, 1, 2, 3, 4)), + "indexName", + 1, + approximateVectorSearchOptions(2) + ) + ) + ) + } + + it should "render $vectorSearch with VectorSearchQuery" in { + toBson( + Aggregates.vectorSearch( + fieldPath("fieldName"), + VectorSearchQuery.textQuery("sample text"), + "indexName", + 1, + approximateVectorSearchOptions(2) + ) + ) should equal( + toBson( + JAggregates.vectorSearch( + fieldPath("fieldName"), + VectorSearchQuery.textQuery("sample text"), + "indexName", + 1, + approximateVectorSearchOptions(2) + ) + ) + ) + } + + it should "render $rerank with single path" in { + toBson( + Aggregates.rerank( + RerankQuery.rerankQuery("machine learning"), + "content", + 25, + "rerank-2.5" + ) + ) should equal( + Document( + """{ + "$rerank": { + "query": {"text": "machine learning"}, + "path": "content", + "numDocsToRerank": 25, + "model": "rerank-2.5" + } + }""" + ) + ) + } + + it should "render $rerank with multiple paths" in { + toBson( + Aggregates.rerank( + RerankQuery.rerankQuery("machine learning"), + List("content", "title"), + 50, + "rerank-2.5-lite" + ) + ) should equal( + Document( + """{ + "$rerank": { + "query": {"text": "machine learning"}, + "path": ["content", "title"], + "numDocsToRerank": 50, + "model": "rerank-2.5-lite" + } + }""" + ) + ) + } + it should "render $unset" in { toBson( Aggregates.unset("title", "author.first") diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/SearchIndexDefinitionSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/SearchIndexDefinitionSpec.scala new file mode 100644 index 00000000000..efbe29917f7 --- /dev/null +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/SearchIndexDefinitionSpec.scala @@ -0,0 +1,104 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.mongodb.scala.model + +import java.lang.reflect.Modifier._ + +import org.bson.{ BsonDocument, BsonString } +import org.mongodb.scala.bson.collection.immutable.Document +import org.mongodb.scala.bson.conversions.Bson +import org.mongodb.scala.model.SearchIndexDefinition._ +import org.mongodb.scala.model.VectorSearchIndexFields._ +import org.mongodb.scala.{ model, BaseSpec, MongoClient } + +class SearchIndexDefinitionSpec extends BaseSpec { + + def toBson(bson: Bson): Document = + Document(bson.toBsonDocument(classOf[BsonDocument], MongoClient.DEFAULT_CODEC_REGISTRY)) + + "SearchIndexDefinition" should "have the same methods as the wrapped SearchIndexDefinition" in { + val wrapped = classOf[com.mongodb.client.model.SearchIndexDefinition].getDeclaredMethods + .filter(f => isStatic(f.getModifiers) && isPublic(f.getModifiers)) + .map(_.getName) + .toSet + val local = model.SearchIndexDefinition.getClass.getDeclaredMethods + .filter(f => isPublic(f.getModifiers)) + .map(_.getName) + .toSet -- DEFAULT_EXCLUSIONS + + local should equal(wrapped) + } + + it should "create a vectorSearch definition with varargs" in { + toBson( + vectorSearch( + vectorField("plot_embedding").numDimensions(1536).similarity("euclidean"), + filterField("genre") + ) + ) should equal( + Document( + """{"fields": [ + |{"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "euclidean"}, + |{"type": "filter", "path": "genre"} + |]}""".stripMargin.replaceAll("\n", " ") + ) + ) + } + + it should "create a vectorSearch definition with a Seq" in { + toBson( + vectorSearch( + Seq( + vectorField("embedding").numDimensions(768).similarity("cosine"), + filterField("category") + ) + ) + ) should equal( + Document( + """{"fields": [ + |{"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "cosine"}, + |{"type": "filter", "path": "category"} + |]}""".stripMargin.replaceAll("\n", " ") + ) + ) + } + + it should "create a vectorSearch definition with storedSource" in { + toBson( + vectorSearch( + vectorField("embedding").numDimensions(1536).similarity("cosine") + ).storedSource(Document("include" -> List("plot", "title"))) + ) should equal( + Document( + """{"fields": [ + |{"type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine"} + |], "storedSource": {"include": ["plot", "title"]}}""".stripMargin.replaceAll("\n", " ") + ) + ) + } + + it should "create a SearchIndexModel with VectorSearchIndexDefinition" in { + val definition = vectorSearch( + vectorField("embedding").numDimensions(1536).similarity("cosine") + ) + val model = SearchIndexModel("my_index", definition) + + model.getName should equal("my_index") + model.getDefinition should equal(definition) + model.getType.toBsonValue should equal(new BsonString("vectorSearch")) + } +} diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/VectorSearchIndexFieldsSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/VectorSearchIndexFieldsSpec.scala new file mode 100644 index 00000000000..72da98a2639 --- /dev/null +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/VectorSearchIndexFieldsSpec.scala @@ -0,0 +1,106 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.mongodb.scala.model + +import java.lang.reflect.Modifier._ + +import org.bson.BsonDocument +import org.mongodb.scala.bson.collection.immutable.Document +import org.mongodb.scala.bson.conversions.Bson +import org.mongodb.scala.model.VectorSearchIndexFields._ +import org.mongodb.scala.{ model, BaseSpec, MongoClient } + +class VectorSearchIndexFieldsSpec extends BaseSpec { + + def toBson(bson: Bson): Document = + Document(bson.toBsonDocument(classOf[BsonDocument], MongoClient.DEFAULT_CODEC_REGISTRY)) + + "VectorSearchIndexFields" should "have the same methods as the wrapped VectorSearchIndexFields" in { + val wrapped = classOf[com.mongodb.client.model.VectorSearchIndexFields].getDeclaredMethods + .filter(f => isStatic(f.getModifiers) && isPublic(f.getModifiers)) + .map(_.getName) + .toSet + val local = model.VectorSearchIndexFields.getClass.getDeclaredMethods + .filter(f => isPublic(f.getModifiers)) + .map(_.getName) + .toSet -- DEFAULT_EXCLUSIONS + + local should equal(wrapped) + } + + it should "create a vectorField with minimal options" in { + toBson(vectorField("vec")) should equal(Document("""{"type": "vector", "path": "vec"}""")) + } + + it should "create a vectorField with all options" in { + toBson( + vectorField("embedding") + .numDimensions(1536) + .similarity("cosine") + .indexingMethod("hnsw") + .hnswOptions(HnswSearchIndexOptions().maxEdges(16)) + ) should equal( + Document( + """{"type": "vector", "path": "embedding", "numDimensions": 1536, + |"similarity": "cosine", "indexingMethod": "hnsw", "hnswOptions": {"maxEdges": 16}}""".stripMargin + .replaceAll("\n", " ") + ) + ) + } + + it should "create a filterField" in { + toBson(filterField("status")) should equal(Document("""{"type": "filter", "path": "status"}""")) + } + + it should "create an autoEmbedField with minimal options" in { + toBson(autoEmbedField("content").modality("text").model("voyage-4")) should equal( + Document("""{"type": "autoEmbed", "path": "content", "modality": "text", "model": "voyage-4"}""") + ) + } + + it should "reject an autoEmbedField missing modality" in { + an[IllegalArgumentException] should be thrownBy { + toBson(autoEmbedField("content").model("voyage-4")) + } + } + + it should "reject an autoEmbedField missing model" in { + an[IllegalArgumentException] should be thrownBy { + toBson(autoEmbedField("content").modality("text")) + } + } + + it should "create an autoEmbedField with all options" in { + toBson( + autoEmbedField("product.description") + .modality("text") + .model("voyage-4-large") + .numDimensions(256) + .quantization("binary") + .similarity("euclidean") + .indexingMethod("hnsw") + .hnswOptions(HnswSearchIndexOptions().maxEdges(16)) + ) should equal( + Document( + """{"type": "autoEmbed", "path": "product.description", "modality": "text", + |"model": "voyage-4-large", "numDimensions": 256, "quantization": "binary", + |"similarity": "euclidean", "indexingMethod": "hnsw", "hnswOptions": {"maxEdges": 16}}""".stripMargin + .replaceAll("\n", " ") + ) + ) + } +} diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/search/SearchOperatorSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/search/SearchOperatorSpec.scala index 3d5481d8368..52795a63fe5 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/model/search/SearchOperatorSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/search/SearchOperatorSpec.scala @@ -28,7 +28,9 @@ import org.mongodb.scala.model.search.SearchOperator.{ exists, near, numberRange, - text + text, + vectorSearch, + vectorSearchExact } import org.mongodb.scala.model.search.SearchPath.{ fieldPath, wildcardPath } import org.mongodb.scala.model.search.SearchScore.function @@ -98,6 +100,38 @@ class SearchOperatorSpec extends BaseSpec { ) } + it should "render vectorSearch operator" in { + toDocument( + vectorSearch(fieldPath("embedding"), Seq(1.0, 2.0, 3.0), 10, 100) + ) should equal( + Document( + """{ "vectorSearch": { "path": "embedding", "queryVector": [1.0, 2.0, 3.0], "limit": 10, "numCandidates": 100 } }""" + ) + ) + } + + it should "render vectorSearchExact operator" in { + toDocument( + vectorSearchExact(fieldPath("embedding"), Seq(1.0, 2.0, 3.0), 5) + ) should equal( + Document( + """{ "vectorSearch": { "path": "embedding", "queryVector": [1.0, 2.0, 3.0], "limit": 5, "exact": true } }""" + ) + ) + } + + it should "render vectorSearch with filter and score" in { + toDocument( + vectorSearch(fieldPath("embedding"), Seq(1.0, 2.0), 10, 50) + .filter(text(fieldPath("title"), "hello")) + .score(SearchScore.boost(2f)) + ) should equal( + Document( + """{ "vectorSearch": { "path": "embedding", "queryVector": [1.0, 2.0], "limit": 10, "numCandidates": 50, "filter": { "text": { "query": "hello", "path": "title" } }, "score": { "boost": { "value": 2.0 } } } }""" + ) + ) + } + def toDocument(bson: Bson): Document = Document(bson.toBsonDocument(classOf[BsonDocument], MongoClient.DEFAULT_CODEC_REGISTRY)) } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 954ea29142f..cd95d0ef003 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -189,6 +189,12 @@ public MongoDatabase getDatabase(final String id) { return getEntity(id, databases, "database"); } + public Set getDatabaseNames() { + return databases.values().stream() + .map(MongoDatabase::getName) + .collect(Collectors.toSet()); + } + public boolean hasCollection(final String id) { return collections.containsKey(id); } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/MicrometerTracingTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/MicrometerTracingTest.java index 8c65317d257..3f37bcedbd9 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/MicrometerTracingTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/MicrometerTracingTest.java @@ -16,12 +16,21 @@ package com.mongodb.client.unified; +import com.mongodb.client.unified.UnifiedTestModifications.TestDef; import org.junit.jupiter.params.provider.Arguments; import java.util.Collection; +import static com.mongodb.client.Fixture.getMongoClient; + final class MicrometerTracingTest extends UnifiedSyncTest { private static Collection data() { return getTestData("open-telemetry/tests"); } + + @Override + protected void postCleanUp(final TestDef testDef) { + super.postCleanUp(testDef); + getEntities().getDatabaseNames().forEach(name -> getMongoClient().getDatabase(name).drop()); + } } diff --git a/gradle.properties b/gradle.properties index 086a36427b7..adafd03f486 100644 --- a/gradle.properties +++ b/gradle.properties @@ -14,7 +14,7 @@ # limitations under the License. # -version=5.8.0-SNAPSHOT +version=5.9.0-SNAPSHOT org.gradle.daemon=true org.gradle.jvmargs=-Dfile.encoding=UTF-8 -Duser.country=US -Duser.language=en diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 79d64ac5200..7686cc15c41 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -39,6 +39,7 @@ scala-v2-v11 = "2.11.12" # Test assertj = "3.24.2" +felix-framework = "7.0.5" aws-lambda-core = "1.2.2" aws-lambda-events = "3.11.1" cglib = "2.2.2" @@ -102,7 +103,7 @@ micrometer-observation = { module = "io.micrometer:micrometer-observation" } graal-sdk = { module = "org.graalvm.sdk:graal-sdk", version.ref = "graal-sdk" } graal-sdk-nativeimage = { module = "org.graalvm.sdk:nativeimage", version.ref = "graal-sdk" } -kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom" } +kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin" } kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8" } kotlinx-coroutines-bom = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-bom", version.ref = "kotlinx-coroutines-bom" } kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core" } @@ -174,6 +175,7 @@ aws-lambda-core = { module = " com.amazonaws:aws-lambda-java-core", version.ref aws-lambda-events = { module = " com.amazonaws:aws-lambda-java-events", version.ref = "aws-lambda-events" } cglib = { module = "cglib:cglib-nodep", version.ref = "cglib" } classgraph = { module = "io.github.classgraph:classgraph", version.ref = "classgraph" } +felix-framework = { module = "org.apache.felix:org.apache.felix.framework", version.ref = "felix-framework" } findbugs-jsr = { module = "com.google.code.findbugs:jsr305", version.ref = "findbugs-jsr" } groovy = { module = "org.codehaus.groovy:groovy-all", version.ref = "groovy" } hamcrest-all = { module = "org.hamcrest:hamcrest-all", version.ref = "hamcrest" } diff --git a/mongodb-crypt/build.gradle.kts b/mongodb-crypt/build.gradle.kts index a59ccefc02f..6a46f2dfc9a 100644 --- a/mongodb-crypt/build.gradle.kts +++ b/mongodb-crypt/build.gradle.kts @@ -16,6 +16,10 @@ import ProjectExtensions.configureJarManifest import ProjectExtensions.configureMavenPublication import de.undercouch.gradle.tasks.download.Download +import java.io.ByteArrayOutputStream +import javax.inject.Inject +import org.gradle.api.GradleException +import org.gradle.process.ExecOperations plugins { id("project.java") @@ -48,72 +52,256 @@ configureJarManifest { */ val jnaDownloadsDir = rootProject.file("build/jnaLibs/downloads/").path val jnaResourcesDir = rootProject.file("build/jnaLibs/resources/").path + +tasks.clean { delete(rootProject.file("build/jnaLibs")) } + val jnaLibPlatform: String = if (com.sun.jna.Platform.RESOURCE_PREFIX.startsWith("darwin")) "darwin" else com.sun.jna.Platform.RESOURCE_PREFIX -val jnaLibsPath: String = System.getProperty("jnaLibsPath", "${jnaResourcesDir}${jnaLibPlatform}") +// When -DjnaLibsPath is set, the user wants to use a pre-existing local copy of the libmongocrypt +// binaries instead of fetching them from the libmongocrypt GitHub release, so we skip the whole +// download / verify / extract chain. +val userSuppliedJnaLibsPath: String? = System.getProperty("jnaLibsPath") +val jnaLibsPath: String = userSuppliedJnaLibsPath ?: "${jnaResourcesDir}/${jnaLibPlatform}" val jnaResources: String = System.getProperty("jna.library.path", jnaLibsPath) -// Download jnaLibs that match the git tag or revision to jnaResourcesBuildDir -val downloadRevision = "1.17.3" -val binariesArchiveName = "libmongocrypt-java.tar.gz" +// Download the libmongocrypt per-platform tarballs (and their signatures) to jnaDownloadsDir. +// To upgrade: change downloadRevision, run `./gradlew clean downloadJnaLibs`, and verify the build. +val downloadRevision = "1.18.1" +val downloadUrlBase = "https://github.com/mongodb/libmongocrypt/releases/download/$downloadRevision" /** - * The name of the archive includes downloadRevision to ensure that: - * - the archive is downloaded if the revision changes. - * - the archive is not downloaded if the revision is the same and archive had already been saved in build output. + * Maps a JNA platform key (the directory consumed by `jna.library.path`) to the libmongocrypt GitHub release tarball + * that ships its native library, plus the path of that library inside the tarball. The tarball name and its internal + * layout differ per platform, so both must be tracked explicitly. + * + * libmongocrypt's signature assets replace the `.tar.gz` suffix with `.asc` (e.g. + * `libmongocrypt-linux-x86_64-glibc_2_7-nocrypto-1.18.1.asc`). */ -val localBinariesArchiveName = "libmongocrypt-java-$downloadRevision.tar.gz" - -val downloadUrl: String = - "https://mciuploads.s3.amazonaws.com/libmongocrypt/java/$downloadRevision/$binariesArchiveName" +data class CryptBinary(val jnaPlatform: String, val tarball: String, val libPathInTarball: String) { + val signature: String = tarball.removeSuffix(".tar.gz") + ".asc" +} -val jnaMapping: Map = - mapOf( - "rhel-62-64-bit" to "linux-x86-64", - "rhel72-zseries-test" to "linux-s390x", - "rhel-71-ppc64el" to "linux-ppc64le", - "ubuntu1604-arm64" to "linux-aarch64", - "windows-test" to "win32-x86-64", - "macos" to "darwin") +val cryptBinaries: List = + listOf( + CryptBinary( + "linux-x86-64", + "libmongocrypt-linux-x86_64-glibc_2_7-nocrypto-$downloadRevision.tar.gz", + "lib64/libmongocrypt.so"), + CryptBinary( + "linux-s390x", + "libmongocrypt-linux-s390x-glibc_2_7-nocrypto-$downloadRevision.tar.gz", + "lib64/libmongocrypt.so"), + CryptBinary( + "linux-ppc64le", + "libmongocrypt-linux-ppc64le-glibc_2_17-nocrypto-$downloadRevision.tar.gz", + "lib64/libmongocrypt.so"), + CryptBinary( + "linux-aarch64", + "libmongocrypt-linux-arm64-glibc_2_17-nocrypto-$downloadRevision.tar.gz", + "lib64/libmongocrypt.so"), + CryptBinary("win32-x86-64", "libmongocrypt-windows-x86_64-$downloadRevision.tar.gz", "bin/mongocrypt.dll"), + CryptBinary("darwin", "libmongocrypt-macos-universal-$downloadRevision.tar.gz", "lib/libmongocrypt.dylib")) sourceSets { main { java { resources { srcDirs(jnaResourcesDir) } } } } -tasks.register("downloadJava") { - src(downloadUrl) - dest("${jnaDownloadsDir}/$localBinariesArchiveName") - overwrite(true) - /* To make sure we don't download archive with binaries if it hasn't been changed in S3 bucket since last download.*/ - onlyIfModified(true) -} +/** + * Public key used to sign libmongocrypt release tarballs. See: + * https://www.mongodb.com/docs/manual/tutorial/verify-mongodb-packages/#std-label-verify-pkgs + */ +val libmongocryptPublicKeyUrl = "https://pgp.mongodb.com/libmongocrypt.pub" +val libmongocryptPublicKeyFile = "libmongocrypt.pub" + +tasks.register("downloadCryptLibs") { + src( + cryptBinaries.flatMap { listOf("$downloadUrlBase/${it.tarball}", "$downloadUrlBase/${it.signature}") } + + libmongocryptPublicKeyUrl) + dest(jnaDownloadsDir) + /* Reuse already-downloaded files. Useful for offline builds and reduces network churn. */ + overwrite(false) + quiet(true) + + /* Bypass entirely when the caller has supplied a local libmongocrypt directory. */ + onlyIf { userSuppliedJnaLibsPath == null } -tasks.register("unzipJava") { - /* - Clean up the directory first if the task is not UP-TO-DATE. - This can happen if the download revision has been changed and the archive is downloaded again. - */ doFirst { - println("Cleaning up $jnaResourcesDir") - delete(jnaResourcesDir) + val missing = cryptBinaries.filter { !file("$jnaDownloadsDir/${it.tarball}").exists() } + if (missing.isNotEmpty()) { + logger.lifecycle("Downloading libmongocrypt $downloadRevision binaries:") + missing.forEach { logger.lifecycle(" ${it.tarball}") } + } + } +} + +/* + * Verify the signature of every downloaded libmongocrypt tarball before extracting it. + * Per DRIVERS-3441, drivers that bundle libmongocrypt must verify GPG signatures of + * release tarballs against the official MongoDB libmongocrypt signing key. + * + * The keyring is kept under `build/` so this task does not touch the developer's + * system GPG keyring and so `./gradlew clean` resets the trust state. + */ +val skipCryptVerify = providers.gradleProperty("skipCryptVerify").map { it.toBoolean() }.orElse(false) + +abstract class VerifyLibmongocryptTask : DefaultTask() { + @get:Inject abstract val execOps: ExecOperations + + @get:InputFiles abstract val tarballs: ConfigurableFileCollection + @get:InputFiles abstract val signatures: ConfigurableFileCollection + @get:InputFile abstract val publicKey: RegularFileProperty + @get:Input abstract val skipVerify: Property + @get:Input abstract val expectedFingerprint: Property + @get:OutputFile abstract val verificationStamp: RegularFileProperty + + /* Scratch keyring directory. Marked @Internal (not @OutputDirectory) because the directory is + * genuinely ephemeral - nothing downstream consumes it. */ + @get:Internal abstract val gnupgHome: DirectoryProperty + + @TaskAction + fun verify() { + if (skipVerify.get()) { + logger.warn( + "SKIPPING libmongocrypt signature verification because -PskipCryptVerify=true was set. " + + "Do not use this for release builds.") + verificationStamp.get().asFile.writeText("Skipped verification at ${System.currentTimeMillis()}") + return + } + + try { + execOps.exec { + commandLine("gpg", "--version") + standardOutput = ByteArrayOutputStream() + } + } catch (e: Exception) { + throw GradleException( + "gpg is required to verify libmongocrypt tarballs since 1.18.0 but was not found on PATH. " + + "Install gpg (e.g. `apt-get install gnupg`, `brew install gnupg`, Gpg4win on Windows), " + + "or pass -PskipCryptVerify=true for offline development builds.", + e) + } + + val home = + gnupgHome.get().asFile.apply { + deleteRecursively() + mkdirs() + // GPG refuses to use a homedir with permissions broader than the owner. + setReadable(false, false) + setReadable(true, true) + setWritable(false, false) + setWritable(true, true) + setExecutable(false, false) + setExecutable(true, true) + } + + execOps.exec { + commandLine( + "gpg", + "--homedir", + home.path, + "--batch", + "--quiet", + "--no-autostart", + "--import", + publicKey.get().asFile.path) + standardOutput = ByteArrayOutputStream() + errorOutput = ByteArrayOutputStream() + } + + try { + execOps.exec { + commandLine( + "gpg", + "--homedir", + home.path, + "--batch", + "--no-autostart", + "--with-colons", + "--fingerprint", + expectedFingerprint.get()) + standardOutput = ByteArrayOutputStream() + errorOutput = ByteArrayOutputStream() + } + } catch (e: Exception) { + throw GradleException( + "Imported libmongocrypt signing key fingerprint does not match expected value " + + "${expectedFingerprint.get()}. The downloaded public key may have been rotated.", + e) + } + + // Pair tarballs with signatures by basename; ConfigurableFileCollection.files is an + // unordered Set, so zipping the two collections could mismatch pairs. + val signaturesByName = signatures.files.associateBy { it.name } + tarballs.files.forEach { tarball -> + val signatureName = tarball.name.removeSuffix(".tar.gz") + ".asc" + val signature = + signaturesByName[signatureName] + ?: throw GradleException( + "Missing signature $signatureName for ${tarball.name}; expected it next to the tarball.") + val verifyErr = ByteArrayOutputStream() + try { + execOps.exec { + commandLine( + "gpg", + "--homedir", + home.path, + "--batch", + "--quiet", + "--no-autostart", + "--trust-model", + "always", + "--verify", + signature.path, + tarball.path) + standardOutput = ByteArrayOutputStream() + errorOutput = verifyErr + } + } catch (e: Exception) { + throw GradleException( + "GPG signature verification failed for ${tarball.name}:\n${verifyErr.toString().trim()}", e) + } + } + + verificationStamp + .get() + .asFile + .writeText( + "verified=${System.currentTimeMillis()}\n" + "tarballs=${tarballs.files.joinToString { it.name }}\n") + } +} + +tasks.register("verifyCryptLibs") { + dependsOn("downloadCryptLibs") + tarballs.from(cryptBinaries.map { "$jnaDownloadsDir/${it.tarball}" }) + signatures.from(cryptBinaries.map { "$jnaDownloadsDir/${it.signature}" }) + publicKey.set(file("$jnaDownloadsDir/$libmongocryptPublicKeyFile")) + skipVerify.set(skipCryptVerify) + expectedFingerprint.set("F2F5BF4ABF517E039AFCADAA81F1404DEBACA586") + gnupgHome.set(layout.buildDirectory.dir("jnaLibs/gnupg")) + verificationStamp.set(layout.buildDirectory.file("jnaLibs/verified.stamp")) + + /* Bypass entirely when the caller has supplied a local libmongocrypt directory. */ + onlyIf { userSuppliedJnaLibsPath == null } +} + +tasks.register("extractCryptLibs") { + cryptBinaries.forEach { spec -> + from(tarTree(resources.gzip("$jnaDownloadsDir/${spec.tarball}"))) { + include(spec.libPathInTarball) + eachFile { path = "${spec.jnaPlatform}/${name}" } + includeEmptyDirs = false + } } - from(tarTree(resources.gzip("${jnaDownloadsDir}/$localBinariesArchiveName"))) - include( - jnaMapping.keys.flatMap { - listOf( - "${it}/nocrypto/**/libmongocrypt.so", "${it}/lib/**/libmongocrypt.dylib", "${it}/bin/**/mongocrypt.dll") - }) - eachFile { path = "${jnaMapping[path.substringBefore("/")]}/${name}" } into(jnaResourcesDir) - dependsOn("downloadJava") + dependsOn("downloadCryptLibs", "verifyCryptLibs") - doLast { println("jna.library.path contents: \n ${fileTree(jnaResourcesDir).files.joinToString(",\n ")}") } + /* Bypass entirely when the caller has supplied a local libmongocrypt directory. */ + onlyIf { userSuppliedJnaLibsPath == null } } // The `processResources` task (defined by the `java-library` plug-in) consumes files in the main -// source set. -// Add a dependency on `unzipJava`. `unzipJava` adds libmongocrypt libraries to the main source set. -tasks.processResources { mustRunAfter(tasks.named("unzipJava")) } +// source set. Extraction must complete first so the native libraries are present. +tasks.processResources { dependsOn("extractCryptLibs") } -tasks.register("downloadJnaLibs") { dependsOn("downloadJava", "unzipJava") } +tasks.register("downloadJnaLibs") { dependsOn("downloadCryptLibs", "verifyCryptLibs", "extractCryptLibs") } tasks.test { systemProperty("jna.debug_load", "true") @@ -122,10 +310,11 @@ tasks.test { testLogging { events("passed", "skipped", "failed") } doFirst { - println("jna.library.path contents:") - println(fileTree(jnaResources) { this.setIncludes(listOf("*.*")) }.files.joinToString(",\n ", " ")) + logger.lifecycle("jna.library.path contents:") + logger.lifecycle( + fileTree(jnaResources) { this.setIncludes(listOf("**/*.*")) }.files.joinToString(",\n ", " ")) } - dependsOn("downloadJnaLibs", "downloadJava", "unzipJava") + dependsOn("downloadJnaLibs") } tasks.withType { @@ -134,8 +323,13 @@ tasks.withType { | System properties: | ================= | - | jnaLibsPath : Custom local JNA library path for inclusion into the build (rather than downloading from s3) - | gitRevision : Optional Git Revision to download the built resources for from s3. + | jnaLibsPath : Custom local JNA library path to use at runtime (bypasses downloading/verifying/extracting libmongocrypt release artifacts). + | + | Project properties: + | =================== + | + | skipCryptVerify : Pass -PskipCryptVerify=true to skip GPG verification of downloaded libmongocrypt tarballs. + | Intended for offline development; do not use for release builds. """.trimMargin() } diff --git a/settings.gradle.kts b/settings.gradle.kts index 29d17792ad4..896e770724d 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -42,6 +42,7 @@ include(":driver-kotlin-sync") include(":driver-scala") include(":driver-benchmarks") +include(":testing:osgi-test") include(":driver-lambda") if (providers.gradleProperty("includeGraalvm").isPresent) { include(":graalvm-native-image-app") diff --git a/testing/osgi-test/build.gradle.kts b/testing/osgi-test/build.gradle.kts new file mode 100644 index 00000000000..253dd9feed0 --- /dev/null +++ b/testing/osgi-test/build.gradle.kts @@ -0,0 +1,65 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id("project.base") + id("checkstyle") + id("conventions.testing-base") +} + +java { + toolchain { languageVersion = JavaLanguageVersion.of(17) } +} + +dependencies { + testImplementation(platform(libs.junit.bom)) + testImplementation(libs.junit.jupiter) + testImplementation(libs.junit.jupiter.platform.launcher) + // AssertJ used here for infrastructure assertions (isDirectory, hasSize, containsExactly) + // which are significantly more readable than JUnit 5 equivalents for this test. + testImplementation(libs.assertj) + testImplementation(libs.felix.framework) + + // These JARs are scanned by buildSystemPackagesFromClasspath() to export packages + // from the Felix system bundle, satisfying non-optional imports from bundles under test. + testImplementation(libs.reactive.streams) + testImplementation(platform(libs.project.reactor.bom)) + testImplementation(libs.project.reactor.core) + testImplementation(platform(libs.kotlin.bom)) + testImplementation(libs.kotlin.stdlib.jdk8) + testImplementation(libs.kotlin.reflect) + testImplementation(platform(libs.kotlinx.coroutines.bom)) + testImplementation(libs.kotlinx.coroutines.core) + testImplementation(libs.kotlinx.coroutines.reactive) + testImplementation(libs.findbugs.jsr) + testImplementation(libs.jna) +} + +tasks.test { + dependsOn( + ":bson:jar", + ":bson-record-codec:jar", + ":mongodb-crypt:jar", + ":driver-core:jar", + ":bson-scala:jar", + ":driver-sync:jar", + ":driver-reactive-streams:jar", + ":driver-scala:jar", + ":driver-kotlin-sync:jar", + ":driver-kotlin-coroutine:jar", + ":driver-kotlin-extensions:jar" + ) + systemProperty("projectRoot", rootProject.projectDir.absolutePath) +} diff --git a/testing/osgi-test/src/test/java/com/mongodb/osgi/OsgiBundleResolutionTest.java b/testing/osgi-test/src/test/java/com/mongodb/osgi/OsgiBundleResolutionTest.java new file mode 100644 index 00000000000..a3697ce8597 --- /dev/null +++ b/testing/osgi-test/src/test/java/com/mongodb/osgi/OsgiBundleResolutionTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.osgi; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; +import java.util.jar.Manifest; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.felix.framework.FrameworkFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.osgi.framework.Bundle; +import org.osgi.framework.BundleContext; +import org.osgi.framework.BundleException; +import org.osgi.framework.FrameworkEvent; +import org.osgi.framework.launch.Framework; + +class OsgiBundleResolutionTest { + + private static final Path PROJECT_ROOT = Paths.get(System.getProperty("projectRoot", "../..")); + + // Listed in dependency order (leaves last) so that the first bundle.start() failure + // identifies the root cause rather than a cascading downstream resolution error. + private static final String[] BUNDLE_MODULES = { + "bson", + "bson-record-codec", + "mongodb-crypt", + "driver-core", + "bson-scala", + "driver-sync", + "driver-reactive-streams", + "driver-scala", + "driver-kotlin-sync", + "driver-kotlin-coroutine", + "driver-kotlin-extensions" + }; + + // JARs on the test classpath whose packages are exported from the Felix system bundle, + // satisfying non-optional imports from the bundles under test. + private static final String[] SYSTEM_PACKAGE_JAR_PREFIXES = { + "reactive-streams", + "reactor-core", + "kotlin-stdlib", + "kotlin-reflect", + "kotlinx-coroutines-core", + "kotlinx-coroutines-reactive", + "jsr305", + "jna" + }; + + // Eagerly computed — the classpath is fixed for the lifetime of the test JVM. + private static final String SYSTEM_PACKAGES = buildSystemPackagesFromClasspath(); + + @TempDir + private Path cacheDir; + + private Framework framework; + + @BeforeEach + void startFramework() throws BundleException { + Map config = new HashMap<>(); + config.put("org.osgi.framework.storage", cacheDir.toString()); + config.put("org.osgi.framework.storage.clean", "onFirstInit"); + config.put("felix.log.level", "1"); + if (!SYSTEM_PACKAGES.isEmpty()) { + config.put("org.osgi.framework.system.packages.extra", SYSTEM_PACKAGES); + } + + framework = new FrameworkFactory().newFramework(config); + framework.start(); + } + + @AfterEach + void stopFramework() throws BundleException, InterruptedException { + if (framework != null) { + framework.stop(); + FrameworkEvent event = framework.waitForStop(10_000); + if (event.getType() == FrameworkEvent.WAIT_TIMEDOUT) { + throw new IllegalStateException("OSGi framework did not stop within 10 seconds"); + } + } + } + + @Test + void bundlesResolveWithoutOptionalDependencies() throws Exception { + List installed = installAllBundles(framework.getBundleContext()); + + for (Bundle bundle : installed) { + try { + bundle.start(); + } catch (BundleException e) { + // Fail immediately on the first resolution error. Bundles are wired by + // Import-Package, so an unresolved bundle (e.g. driver-core missing a + // required import) leaves its exported packages unsatisfied for all + // downstream bundles. Collecting further failures would only add + // cascading noise — the first message identifies the root cause. + fail(formatBundleFailure(bundle, e)); + } + } + } + + @Test + void bundlesReportCorrectSymbolicNames() throws Exception { + List installed = installAllBundles(framework.getBundleContext()); + + List symbolicNames = installed.stream() + .map(Bundle::getSymbolicName) + .collect(Collectors.toList()); + + assertThat(symbolicNames).containsExactly( + "org.mongodb.bson", + "org.mongodb.bson-record-codec", + "com.mongodb.crypt.capi", + "org.mongodb.driver-core", + "org.mongodb.scala.mongo-scala-bson", + "org.mongodb.driver-sync", + "org.mongodb.driver-reactivestreams", + "org.mongodb.scala.mongo-scala-driver", + "org.mongodb.mongodb-driver-kotlin-sync", + "org.mongodb.mongodb-driver-kotlin-coroutine", + "org.mongodb.mongodb-driver-kotlin-extensions"); + } + + private List installAllBundles(final BundleContext ctx) throws Exception { + List installed = new ArrayList<>(); + for (String module : BUNDLE_MODULES) { + File jar = findBundleJar(module); + try (InputStream is = Files.newInputStream(jar.toPath())) { + Bundle bundle = ctx.installBundle("file:" + jar.getAbsolutePath(), is); + installed.add(bundle); + } + } + return installed; + } + + // Parses Felix's error message format to extract the missing package name. + private static String formatBundleFailure(final Bundle bundle, final BundleException e) { + String msg = e.getMessage(); + StringBuilder sb = new StringBuilder(); + sb.append("\n\n====================================================================\n"); + sb.append("BUNDLE RESOLUTION FAILURE: ").append(bundle.getSymbolicName()).append("\n"); + sb.append("====================================================================\n"); + + if (msg != null && msg.contains("missing requirement")) { + int pkgStart = msg.indexOf("osgi.wiring.package="); + if (pkgStart >= 0) { + String remainder = msg.substring(pkgStart + "osgi.wiring.package=".length()); + int pkgEnd = remainder.indexOf(')'); + String missingPackage = pkgEnd >= 0 ? remainder.substring(0, pkgEnd) : remainder; + sb.append("Missing required package: ").append(missingPackage).append("\n\n"); + sb.append("FIX: Add '").append(missingPackage).append(".*;resolution:=optional' to the\n"); + sb.append(" Import-Package list in the module's build.gradle.kts\n"); + } + } + + sb.append("\nFull error: ").append(msg); + sb.append("\n====================================================================\n"); + return sb.toString(); + } + + private static String buildSystemPackagesFromClasspath() { + Set packages = new LinkedHashSet<>(); + String classpath = System.getProperty("java.class.path", ""); + + for (String entry : classpath.split(File.pathSeparator)) { + File file = new File(entry); + String name = file.getName(); + if (!matchesAnyPrefix(name)) { + continue; + } + if (!file.isFile() || !name.endsWith(".jar")) { + continue; + } + try (JarFile jar = new JarFile(file)) { + Manifest manifest = jar.getManifest(); + if (manifest == null) { + continue; + } + String version = manifest.getMainAttributes().getValue("Bundle-Version"); + if (version == null) { + version = "0.0.0"; + } + Enumeration entries = jar.entries(); + while (entries.hasMoreElements()) { + JarEntry jarEntry = entries.nextElement(); + String entryName = jarEntry.getName(); + if (entryName.endsWith(".class") && entryName.contains("/")) { + String pkg = entryName.substring(0, entryName.lastIndexOf('/')).replace('/', '.'); + packages.add(pkg + ";version=\"" + version + "\""); + } + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to read classpath JAR: " + file, e); + } + } + + return String.join(",", packages); + } + + private static boolean matchesAnyPrefix(final String fileName) { + for (String prefix : SYSTEM_PACKAGE_JAR_PREFIXES) { + if (fileName.startsWith(prefix)) { + return true; + } + } + return false; + } + + private static File findBundleJar(final String module) { + Path libsDir = PROJECT_ROOT.resolve(module).resolve("build").resolve("libs"); + assertThat(libsDir) + .as("Build output directory for module '%s' must exist. Run ./gradlew jar first.", module) + .isDirectory(); + + try (Stream files = Files.list(libsDir)) { + List candidates = files + .filter(p -> p.getFileName().toString().endsWith(".jar")) + .filter(p -> !p.getFileName().toString().contains("-test")) + .filter(p -> !p.getFileName().toString().contains("-sources")) + .filter(p -> !p.getFileName().toString().contains("-javadoc")) + .map(Path::toFile) + .collect(Collectors.toList()); + + assertThat(candidates) + .as("Expected exactly one main JAR in %s", libsDir) + .hasSize(1); + + return candidates.get(0); + } catch (IOException e) { + return fail("Failed to list JARs in " + libsDir + ": " + e.getMessage()); + } + } +} diff --git a/testing/osgi-test/src/test/java/com/mongodb/osgi/package-info.java b/testing/osgi-test/src/test/java/com/mongodb/osgi/package-info.java new file mode 100644 index 00000000000..0835da294d3 --- /dev/null +++ b/testing/osgi-test/src/test/java/com/mongodb/osgi/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * OSGi bundle resolution regression tests using Apache Felix. + */ +package com.mongodb.osgi; diff --git a/testing/resources/specifications b/testing/resources/specifications index b519824da64..44840386103 160000 --- a/testing/resources/specifications +++ b/testing/resources/specifications @@ -1 +1 @@ -Subproject commit b519824da64005cdf99ca680fc49c4e278af0ef3 +Subproject commit 44840386103b93e5ebb7d7d595ca605c44eb6e08