From 3babb47e4e78cb3d31d17c083bd68e0ff59bb61c Mon Sep 17 00:00:00 2001
From: john-jac <75442233+john-jac@users.noreply.github.com>
Date: Mon, 11 May 2026 09:56:11 -0700
Subject: [PATCH] Add `S3VectorsQueryVectorsOperator`

---
 .../amazon/docs/operators/s3_vectors.rst      | 15 ++++
 .../amazon/aws/operators/s3_vectors.py        | 69 +++++++++++++++++++
 .../system/amazon/aws/example_s3_vectors.py   | 12 ++++
 .../amazon/aws/operators/test_s3_vectors.py   | 47 +++++++++++++
 4 files changed, 143 insertions(+)

diff --git a/providers/amazon/docs/operators/s3_vectors.rst b/providers/amazon/docs/operators/s3_vectors.rst
index 6e1d364377a42..ac8bf6dedacf5 100644
--- a/providers/amazon/docs/operators/s3_vectors.rst
+++ b/providers/amazon/docs/operators/s3_vectors.rst
@@ -92,6 +92,21 @@ To insert vectors into an Amazon S3 Vectors index, use
     :start-after: [START howto_operator_s3vectors_put_vectors]
     :end-before: [END howto_operator_s3vectors_put_vectors]
 
+
+.. _howto/operator:S3VectorsQueryVectorsOperator:
+
+Query Vectors
+-------------
+
+To query vectors by similarity in an Amazon S3 Vectors index, use
+:class:`~airflow.providers.amazon.aws.operators.s3_vectors.S3VectorsQueryVectorsOperator`.
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3_vectors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_s3vectors_query_vectors]
+    :end-before: [END howto_operator_s3vectors_query_vectors]
+
 Reference
 ---------
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
index 1ca23e05bf2f8..cac1d2409f9f5 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
@@ -298,3 +298,72 @@ def execute(self, context: Context) -> None:
             vectors=self.vectors,
         )
         self.log.info("Put %d vectors successfully", len(self.vectors))
+
+
+class S3VectorsQueryVectorsOperator(AwsBaseOperator[AwsBaseHook]):
+    """
+    Query vectors by similarity in an Amazon S3 Vectors index.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:S3VectorsQueryVectorsOperator`
+
+    :param vector_bucket_name: The name of the vector bucket. (templated)
+    :param index_name: The name of the index. (templated)
+    :param top_k: The number of results to return. (templated)
+    :param query_vector: The query vector dict (e.g. ``{"float32": [0.1, 0.2, ...]}``)
+    :param filter: Optional filter expression dict.
+    :param return_metadata: Whether to return metadata with results.
+    :param return_distance: Whether to return distance scores.
+    """
+
+    aws_hook_class = AwsBaseHook
+    template_fields: tuple[str, ...] = (
+        *AwsBaseOperator.template_fields,
+        "vector_bucket_name",
+        "index_name",
+        "top_k",
+    )
+
+    def __init__(
+        self,
+        *,
+        vector_bucket_name: str,
+        index_name: str,
+        top_k: int,
+        query_vector: dict[str, Any],
+        filter: dict[str, Any] | None = None,
+        return_metadata: bool = True,
+        return_distance: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.vector_bucket_name = vector_bucket_name
+        self.index_name = index_name
+        self.top_k = top_k
+        self.query_vector = query_vector
+        self.filter = filter
+        self.return_metadata = return_metadata
+        self.return_distance = return_distance
+
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "client_type": "s3vectors"}
+
+    def execute(self, context: Context) -> list[dict[str, Any]]:
+        self.log.info("Querying top %s vectors from index %s", self.top_k, self.index_name)
+        kwargs: dict[str, Any] = prune_dict(
+            {
+                "vectorBucketName": self.vector_bucket_name,
+                "indexName": self.index_name,
+                "topK": int(self.top_k),  # a templated top_k renders to str; the API requires int
+                "queryVector": self.query_vector,
+                "filter": self.filter,
+                "returnMetadata": self.return_metadata,
+                "returnDistance": self.return_distance,
+            }
+        )
+        response = self.hook.conn.query_vectors(**kwargs)
+        vectors = response.get("vectors", [])
+        self.log.info("Query returned %d results", len(vectors))
+        return vectors
diff --git a/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py b/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
index 9faedd9549e6d..dda071944c8a9 100644
--- a/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
+++ b/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
@@ -24,6 +24,7 @@
     S3VectorsDeleteIndexOperator,
     S3VectorsDeleteVectorBucketOperator,
     S3VectorsPutVectorsOperator,
+    S3VectorsQueryVectorsOperator,
 )
 from airflow.providers.common.compat.sdk import DAG, chain
 
@@ -86,6 +87,16 @@
     )
     # [END howto_operator_s3vectors_delete_vector_bucket]
 
+    # [START howto_operator_s3vectors_query_vectors]
+    query_vectors = S3VectorsQueryVectorsOperator(
+        task_id="query_vectors",
+        vector_bucket_name=bucket_name,
+        index_name=index_name,
+        top_k=3,
+        query_vector={"float32": [0.1, 0.2, 0.3, 0.4]},
+    )
+    # [END howto_operator_s3vectors_query_vectors]
+
     # [START howto_operator_s3vectors_delete_index]
     delete_index = S3VectorsDeleteIndexOperator(
         task_id="delete_index",
@@ -100,6 +111,7 @@
         create_vector_bucket,
         create_index,
         put_vectors,
+        query_vectors,
         delete_index,
         delete_vector_bucket,
     )
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
index e483d0396db07..df564d498b059 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
@@ -27,6 +27,7 @@
     S3VectorsDeleteIndexOperator,
     S3VectorsDeleteVectorBucketOperator,
     S3VectorsPutVectorsOperator,
+    S3VectorsQueryVectorsOperator,
 )
 
 from unit.amazon.aws.utils.test_template_fields import validate_template_fields
@@ -250,3 +251,49 @@ def test_execute(self):
 
     def test_template_fields(self):
         validate_template_fields(self.operator)
+
+
+QUERY_VECTOR = {"float32": [0.1, 0.2, 0.3, 0.4]}
+QUERY_RESULTS = [{"key": "vec1", "distance": 0.95, "metadata": {"label": "test"}}]
+
+
+class TestS3VectorsQueryVectorsOperator:
+    def setup_method(self):
+        self.operator = S3VectorsQueryVectorsOperator(
+            task_id="query_vectors",
+            vector_bucket_name=BUCKET_NAME,
+            index_name=INDEX_NAME,
+            top_k=5,
+            query_vector=QUERY_VECTOR,
+        )
+
+    def test_execute(self):
+        mock_conn = MagicMock()
+        mock_conn.query_vectors.return_value = {"vectors": QUERY_RESULTS, "distanceMetric": "cosine"}
+        self.operator.hook.conn = mock_conn
+
+        result = self.operator.execute({})
+
+        mock_conn.query_vectors.assert_called_once_with(
+            vectorBucketName=BUCKET_NAME,
+            indexName=INDEX_NAME,
+            topK=5,
+            queryVector=QUERY_VECTOR,
+            returnMetadata=True,
+            returnDistance=True,
+        )
+        assert result == QUERY_RESULTS
+
+    def test_execute_with_filter(self):
+        mock_conn = MagicMock()
+        mock_conn.query_vectors.return_value = {"vectors": []}
+        self.operator.hook.conn = mock_conn
+        self.operator.filter = {"equals": {"key": "label", "value": "test"}}
+
+        self.operator.execute({})
+
+        call_kwargs = mock_conn.query_vectors.call_args[1]
+        assert call_kwargs["filter"] == {"equals": {"key": "label", "value": "test"}}
+
+    def test_template_fields(self):
+        validate_template_fields(self.operator)