Skip to content
This repository was archived by the owner on Mar 2, 2026. It is now read-only.

Commit fed7af2

Browse files
feat: query to pipeline conversion (#1071)
1 parent cd578c1 commit fed7af2

18 files changed

+1019
-53
lines changed

google/cloud/firestore_v1/_pipeline_stages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def __init__(self, collection_id: str):
216216
self.collection_id = collection_id
217217

218218
def _pb_args(self):
219-
return [Value(string_value=self.collection_id)]
219+
return [Value(reference_value=""), Value(string_value=self.collection_id)]
220220

221221

222222
class Database(Stage):

google/cloud/firestore_v1/base_aggregation.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from __future__ import annotations
2222

2323
import abc
24+
import itertools
2425

2526
from abc import ABC
26-
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
27+
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable
2728

2829
from google.api_core import gapic_v1
2930
from google.api_core import retry as retries
@@ -33,6 +34,10 @@
3334
from google.cloud.firestore_v1.types import (
3435
StructuredAggregationQuery,
3536
)
37+
from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction
38+
from google.cloud.firestore_v1.pipeline_expressions import Count
39+
from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr
40+
from google.cloud.firestore_v1.pipeline_expressions import Field
3641

3742
# Types needed only for Type Hints
3843
if TYPE_CHECKING: # pragma: NO COVER
@@ -66,6 +71,9 @@ def __init__(self, alias: str, value: float, read_time=None):
6671
def __repr__(self):
6772
return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"
6873

74+
def _to_dict(self):
75+
return {self.alias: self.value}
76+
6977

7078
class BaseAggregation(ABC):
7179
def __init__(self, alias: str | None = None):
@@ -75,6 +83,27 @@ def __init__(self, alias: str | None = None):
7583
def _to_protobuf(self):
7684
"""Convert this instance to the protobuf representation"""
7785

86+
@abc.abstractmethod
87+
def _to_pipeline_expr(
88+
self, autoindexer: Iterable[int]
89+
) -> AliasedExpr[AggregateFunction]:
90+
"""
91+
Convert this instance to a pipeline expression for use with pipeline.aggregate()
92+
93+
Args:
94+
autoindexer: If an alias isn't supplied, one should be created with the format "field_n"
95+
The autoindexer is an iterator (advanced with `next()`) that provides the `n` value to use for each un-aliased expression
96+
"""
97+
98+
def _pipeline_alias(self, autoindexer):
99+
"""
100+
Helper to build the alias for the pipeline expression
101+
"""
102+
if self.alias is not None:
103+
return self.alias
104+
else:
105+
return f"field_{next(autoindexer)}"
106+
78107

79108
class CountAggregation(BaseAggregation):
80109
def __init__(self, alias: str | None = None):
@@ -88,6 +117,9 @@ def _to_protobuf(self):
88117
aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
89118
return aggregation_pb
90119

120+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
121+
return Count().as_(self._pipeline_alias(autoindexer))
122+
91123

92124
class SumAggregation(BaseAggregation):
93125
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -107,6 +139,9 @@ def _to_protobuf(self):
107139
aggregation_pb.sum.field.field_path = self.field_ref
108140
return aggregation_pb
109141

142+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
143+
return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer))
144+
110145

111146
class AvgAggregation(BaseAggregation):
112147
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -126,6 +161,9 @@ def _to_protobuf(self):
126161
aggregation_pb.avg.field.field_path = self.field_ref
127162
return aggregation_pb
128163

164+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
165+
return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer))
166+
129167

130168
def _query_response_to_result(
131169
response_pb,
@@ -317,3 +355,20 @@ def stream(
317355
StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
318356
A generator of the query results.
319357
"""
358+
359+
def pipeline(self):
360+
"""
361+
Convert this query into a Pipeline
362+
363+
Queries containing a `cursor` or `limit_to_last` are not currently supported
364+
365+
Raises:
366+
- ValueError: raised if Query wasn't created with an associated client
367+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
368+
Returns:
369+
a Pipeline representing the query
370+
"""
371+
# use autoindexer to keep track of which field number to use for un-aliased fields
372+
autoindexer = itertools.count(start=1)
373+
exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
374+
return self._nested_query.pipeline().aggregate(*exprs)

google/cloud/firestore_v1/base_collection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,19 @@ def find_nearest(
602602
distance_threshold=distance_threshold,
603603
)
604604

605+
def pipeline(self):
606+
"""
607+
Convert this collection's underlying query into a Pipeline
608+
609+
Queries containing a `cursor` or `limit_to_last` are not currently supported
610+
611+
Raises:
612+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
613+
Returns:
614+
a Pipeline representing the query
615+
"""
616+
return self._query().pipeline()
617+
605618

606619
def _auto_id() -> str:
607620
"""Generate a "random" automatically generated ID.

google/cloud/firestore_v1/base_query.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
query,
6060
)
6161
from google.cloud.firestore_v1.vector import Vector
62+
from google.cloud.firestore_v1 import pipeline_expressions
6263

6364
if TYPE_CHECKING: # pragma: NO COVER
6465
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
@@ -1128,6 +1129,74 @@ def recursive(self: QueryType) -> QueryType:
11281129

11291130
return copied
11301131

1132+
def pipeline(self):
1133+
"""
1134+
Convert this query into a Pipeline
1135+
1136+
Queries containing a `cursor` or `limit_to_last` are not currently supported
1137+
1138+
Raises:
1139+
- ValueError: raised if Query wasn't created with an associated client
1140+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
1141+
Returns:
1142+
a Pipeline representing the query
1143+
"""
1144+
if not self._client:
1145+
raise ValueError("Query does not have an associated client")
1146+
if self._all_descendants:
1147+
ppl = self._client.pipeline().collection_group(self._parent.id)
1148+
else:
1149+
ppl = self._client.pipeline().collection(self._parent._path)
1150+
1151+
# Filters
1152+
for filter_ in self._field_filters:
1153+
ppl = ppl.where(
1154+
pipeline_expressions.BooleanExpr._from_query_filter_pb(
1155+
filter_, self._client
1156+
)
1157+
)
1158+
1159+
# Projections
1160+
if self._projection and self._projection.fields:
1161+
ppl = ppl.select(*[field.field_path for field in self._projection.fields])
1162+
1163+
# Orders
1164+
orders = self._normalize_orders()
1165+
if orders:
1166+
exists = []
1167+
orderings = []
1168+
for order in orders:
1169+
field = pipeline_expressions.Field.of(order.field.field_path)
1170+
exists.append(field.exists())
1171+
direction = (
1172+
"ascending"
1173+
if order.direction == StructuredQuery.Direction.ASCENDING
1174+
else "descending"
1175+
)
1176+
orderings.append(pipeline_expressions.Ordering(field, direction))
1177+
1178+
# Add exists filters to match Query's implicit orderby semantics.
1179+
if len(exists) == 1:
1180+
ppl = ppl.where(exists[0])
1181+
else:
1182+
ppl = ppl.where(pipeline_expressions.And(*exists))
1183+
1184+
# Add sort orderings
1185+
ppl = ppl.sort(*orderings)
1186+
1187+
# Cursors, Limit and Offset
1188+
if self._start_at or self._end_at or self._limit_to_last:
1189+
raise NotImplementedError(
1190+
"Query to Pipeline conversion: cursors and limit_to_last is not supported yet."
1191+
)
1192+
else: # Limit & Offset without cursors
1193+
if self._offset:
1194+
ppl = ppl.offset(self._offset)
1195+
if self._limit:
1196+
ppl = ppl.limit(self._limit)
1197+
1198+
return ppl
1199+
11311200
def _comparator(self, doc1, doc2) -> int:
11321201
_orders = self._orders
11331202

google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,18 @@ def is_nan(self) -> "BooleanExpr":
587587
"""
588588
return BooleanExpr("is_nan", [self])
589589

590+
@expose_as_static
591+
def is_null(self) -> "BooleanExpr":
592+
"""Creates an expression that checks if this expression evaluates to 'Null'.
593+
594+
Example:
595+
>>> Field.of("value").is_null()
596+
597+
Returns:
598+
A new `BooleanExpr` representing the 'is_null' check.
599+
"""
600+
return BooleanExpr("is_null", [self])
601+
590602
@expose_as_static
591603
def exists(self) -> "BooleanExpr":
592604
"""Creates an expression that checks if a field exists in the document.
@@ -627,6 +639,7 @@ def average(self) -> "Expr":
627639
"""
628640
return AggregateFunction("average", [self])
629641

642+
@expose_as_static
630643
def count(self) -> "Expr":
631644
"""Creates an aggregation that counts the number of stage inputs with valid evaluations of the
632645
expression or field.
@@ -1312,9 +1325,9 @@ def _from_query_filter_pb(filter_pb, client):
13121325
elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN:
13131326
return And(field.exists(), Not(field.is_nan()))
13141327
elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL:
1315-
return And(field.exists(), field.equal(None))
1328+
return And(field.exists(), field.is_null())
13161329
elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL:
1317-
return And(field.exists(), Not(field.equal(None)))
1330+
return And(field.exists(), Not(field.is_null()))
13181331
else:
13191332
raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}")
13201333
elif isinstance(filter_pb, Query_pb.FieldFilter):
@@ -1361,7 +1374,7 @@ class And(BooleanExpr):
13611374
Example:
13621375
>>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND
13631376
>>> # the 'status' field is "active"
1364-
>>> Expr.And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
1377+
>>> And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
13651378
13661379
Args:
13671380
*conditions: The filter conditions to 'AND' together.
@@ -1377,7 +1390,7 @@ class Not(BooleanExpr):
13771390
13781391
Example:
13791392
>>> # Find documents where the 'completed' field is NOT true
1380-
>>> Expr.Not(Field.of("completed").equal(True))
1393+
>>> Not(Field.of("completed").equal(True))
13811394
13821395
Args:
13831396
condition: The filter condition to negate.
@@ -1394,7 +1407,7 @@ class Or(BooleanExpr):
13941407
Example:
13951408
>>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR
13961409
>>> # the 'status' field is "active"
1397-
>>> Expr.Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
1410+
>>> Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
13981411
13991412
Args:
14001413
*conditions: The filter conditions to 'OR' together.
@@ -1411,7 +1424,7 @@ class Xor(BooleanExpr):
14111424
Example:
14121425
>>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London",
14131426
>>> # or 'status' is "active".
1414-
>>> Expr.Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
1427+
>>> Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
14151428
14161429
Args:
14171430
*conditions: The filter conditions to 'XOR' together.
@@ -1428,7 +1441,7 @@ class Conditional(BooleanExpr):
14281441
14291442
Example:
14301443
>>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor".
1431-
>>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor"));
1444+
>>> Conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor"))
14321445
14331446
Args:
14341447
condition: The condition to evaluate.
@@ -1440,3 +1453,24 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr):
14401453
super().__init__(
14411454
"conditional", [condition, then_expr, else_expr], use_infix_repr=False
14421455
)
1456+
1457+
class Count(AggregateFunction):
1458+
"""
1459+
Represents an aggregation that counts the number of stage inputs with valid evaluations of the
1460+
expression or field.
1461+
1462+
Example:
1463+
>>> # Count the total number of products
1464+
>>> Field.of("productId").count().as_("totalProducts")
1465+
>>> Count(Field.of("productId"))
1466+
>>> Count().as_("count")
1467+
1468+
Args:
1469+
expression: The expression or field to count. If None, counts all stage inputs.
1470+
"""
1471+
1472+
def __init__(self, expression: Expr | None = None):
1473+
expression_list = [expression] if expression else []
1474+
super().__init__(
1475+
"count", expression_list, use_infix_repr=bool(expression_list)
1476+
)

tests/system/test__helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
# run all tests against default database, and a named database
2121
# TODO: add enterprise mode when GA (RunQuery not currently supported)
2222
TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
23+
TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB]

0 commit comments

Comments
 (0)