Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add S3VectorsQueryVectorsOperator
  • Loading branch information
john-jac committed May 11, 2026
commit 3babb47e4e78cb3d31d17c083bd68e0ff59bb61c
15 changes: 15 additions & 0 deletions providers/amazon/docs/operators/s3_vectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ To insert vectors into an Amazon S3 Vectors index, use
:start-after: [START howto_operator_s3vectors_put_vectors]
:end-before: [END howto_operator_s3vectors_put_vectors]


.. _howto/operator:S3VectorsQueryVectorsOperator:

Query Vectors
-------------

To query vectors by similarity in an Amazon S3 Vectors index, use
:class:`~airflow.providers.amazon.aws.operators.s3_vectors.S3VectorsQueryVectorsOperator`.

.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3_vectors.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_s3vectors_query_vectors]
    :end-before: [END howto_operator_s3vectors_query_vectors]

Reference
---------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,72 @@ def execute(self, context: Context) -> None:
vectors=self.vectors,
)
self.log.info("Put %d vectors successfully", len(self.vectors))


class S3VectorsQueryVectorsOperator(AwsBaseOperator[AwsBaseHook]):
    """
    Query vectors by similarity in an Amazon S3 Vectors index.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:S3VectorsQueryVectorsOperator`

    :param vector_bucket_name: The name of the vector bucket. (templated)
    :param index_name: The name of the index. (templated)
    :param top_k: The number of results to return. (templated)
    :param query_vector: The query vector dict (e.g. ``{"float32": [0.1, 0.2, ...]}``)
    :param filter: Optional filter expression dict.
    :param return_metadata: Whether to return metadata with results.
    :param return_distance: Whether to return distance scores.
    """

    aws_hook_class = AwsBaseHook
    template_fields: tuple[str, ...] = (
        *AwsBaseOperator.template_fields,
        "vector_bucket_name",
        "index_name",
        "top_k",
    )

    def __init__(
        self,
        *,
        vector_bucket_name: str,
        index_name: str,
        top_k: int,
        query_vector: dict[str, Any],
        filter: dict[str, Any] | None = None,
        return_metadata: bool = True,
        return_distance: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Template fields are stored untouched here; any normalisation happens in
        # ``execute`` once templates have been rendered.
        self.vector_bucket_name = vector_bucket_name
        self.index_name = index_name
        self.top_k = top_k
        self.query_vector = query_vector
        self.filter = filter
        self.return_metadata = return_metadata
        self.return_distance = return_distance

    @property
    def _hook_parameters(self) -> dict[str, Any]:
        """Extend the base hook parameters with the ``s3vectors`` client type."""
        return {**super()._hook_parameters, "client_type": "s3vectors"}

    def execute(self, context: Context) -> list[dict[str, Any]]:
        """
        Run the similarity query and return the matching vectors.

        :param context: The task execution context (not used by this operator).
        :return: The ``vectors`` list from the ``query_vectors`` response, or an
            empty list when the response has no ``vectors`` key.
        :raises ValueError: If ``top_k`` cannot be converted to an integer.
        """
        # ``top_k`` is a template field, so a templated value arrives here as a
        # rendered string (e.g. ``"3"``). Coerce it once so both the ``%d`` log
        # format and the ``topK`` request parameter receive an integer.
        top_k = int(self.top_k)
        self.log.info("Querying top %d vectors from index %s", top_k, self.index_name)
        # prune_dict removes the unset ``filter`` so it is not sent as an explicit null.
        request: dict[str, Any] = prune_dict(
            {
                "vectorBucketName": self.vector_bucket_name,
                "indexName": self.index_name,
                "topK": top_k,
                "queryVector": self.query_vector,
                "filter": self.filter,
                "returnMetadata": self.return_metadata,
                "returnDistance": self.return_distance,
            }
        )
        response = self.hook.conn.query_vectors(**request)
        vectors = response.get("vectors", [])
        self.log.info("Query returned %d results", len(vectors))
        return vectors
12 changes: 12 additions & 0 deletions providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
S3VectorsDeleteIndexOperator,
S3VectorsDeleteVectorBucketOperator,
S3VectorsPutVectorsOperator,
S3VectorsQueryVectorsOperator,
)
from airflow.providers.common.compat.sdk import DAG, chain

Expand Down Expand Up @@ -86,6 +87,16 @@
)
# [END howto_operator_s3vectors_delete_vector_bucket]

# [START howto_operator_s3vectors_query_vectors]
query_vectors = S3VectorsQueryVectorsOperator(
task_id="query_vectors",
vector_bucket_name=bucket_name,
index_name=index_name,
top_k=3,
query_vector={"float32": [0.1, 0.2, 0.3, 0.4]},
)
# [END howto_operator_s3vectors_query_vectors]

# [START howto_operator_s3vectors_delete_index]
delete_index = S3VectorsDeleteIndexOperator(
task_id="delete_index",
Expand All @@ -100,6 +111,7 @@
create_vector_bucket,
create_index,
put_vectors,
query_vectors,
delete_index,
delete_vector_bucket,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
S3VectorsDeleteIndexOperator,
S3VectorsDeleteVectorBucketOperator,
S3VectorsPutVectorsOperator,
S3VectorsQueryVectorsOperator,
)

from unit.amazon.aws.utils.test_template_fields import validate_template_fields
Expand Down Expand Up @@ -250,3 +251,49 @@ def test_execute(self):

def test_template_fields(self):
validate_template_fields(self.operator)


# Query vector handed to the operator under test and expected in the client call.
QUERY_VECTOR = {"float32": [0.1, 0.2, 0.3, 0.4]}
# Result list the mocked ``query_vectors`` call returns under the "vectors" key.
QUERY_RESULTS = [{"key": "vec1", "distance": 0.95, "metadata": {"label": "test"}}]


class TestS3VectorsQueryVectorsOperator:
    """Unit tests for ``S3VectorsQueryVectorsOperator``."""

    def setup_method(self):
        """Build a fresh operator with default flags and no filter for every test."""
        self.operator = S3VectorsQueryVectorsOperator(
            task_id="query_vectors",
            vector_bucket_name=BUCKET_NAME,
            index_name=INDEX_NAME,
            top_k=5,
            query_vector=QUERY_VECTOR,
        )

    def test_execute(self):
        """The request omits the unset filter and the ``vectors`` list is returned."""
        client = MagicMock()
        client.query_vectors.return_value = {"vectors": QUERY_RESULTS, "distanceMetric": "cosine"}
        self.operator.hook.conn = client

        returned = self.operator.execute({})

        expected_request = {
            "vectorBucketName": BUCKET_NAME,
            "indexName": INDEX_NAME,
            "topK": 5,
            "queryVector": QUERY_VECTOR,
            "returnMetadata": True,
            "returnDistance": True,
        }
        client.query_vectors.assert_called_once_with(**expected_request)
        assert returned == QUERY_RESULTS

    def test_execute_with_filter(self):
        """A filter set on the operator is forwarded to the client unchanged."""
        client = MagicMock()
        client.query_vectors.return_value = {"vectors": []}
        self.operator.hook.conn = client
        self.operator.filter = {"equals": {"key": "label", "value": "test"}}

        self.operator.execute({})

        sent_kwargs = client.query_vectors.call_args.kwargs
        assert sent_kwargs["filter"] == {"equals": {"key": "label", "value": "test"}}

    def test_template_fields(self):
        """Every declared template field is valid for the operator."""
        validate_template_fields(self.operator)
Loading