Skip to content
This repository was archived by the owner on Mar 2, 2026. It is now read-only.

Commit b77002d

Browse files
feat: pipelines create_from() (#1124)
1 parent 8e83a40 commit b77002d

15 files changed

+98
-65
lines changed

google/cloud/firestore_v1/base_aggregation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from google.cloud.firestore_v1.stream_generator import (
4949
StreamGenerator,
5050
)
51+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
5152

5253
import datetime
5354

@@ -356,19 +357,20 @@ def stream(
356357
A generator of the query results.
357358
"""
358359

359-
def pipeline(self):
360+
def _build_pipeline(self, source: "PipelineSource"):
360361
"""
361362
Convert this query into a Pipeline
362363
363364
Queries containing a `cursor` or `limit_to_last` are not currently supported
364365
366+
Args:
367+
source: the PipelineSource to build the pipeline off of
365368
Raises:
366-
- ValueError: raised if Query wasn't created with an associated client
367369
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
368370
Returns:
369371
a Pipeline representing the query
370372
"""
371373
# use autoindexer to keep track of which field number to use for un-aliased fields
372374
autoindexer = itertools.count(start=1)
373375
exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
374-
return self._nested_query.pipeline().aggregate(*exprs)
376+
return self._nested_query._build_pipeline(source).aggregate(*exprs)

google/cloud/firestore_v1/base_collection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
4949
from google.cloud.firestore_v1.document import DocumentReference
5050
from google.cloud.firestore_v1.field_path import FieldPath
51+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
5152
from google.cloud.firestore_v1.query_profile import ExplainOptions
5253
from google.cloud.firestore_v1.query_results import QueryResultsList
5354
from google.cloud.firestore_v1.stream_generator import StreamGenerator
@@ -602,18 +603,20 @@ def find_nearest(
602603
distance_threshold=distance_threshold,
603604
)
604605

605-
def pipeline(self):
606+
def _build_pipeline(self, source: "PipelineSource"):
606607
"""
607608
Convert this query into a Pipeline
608609
609610
Queries containing a `cursor` or `limit_to_last` are not currently supported
610611
612+
Args:
613+
source: the PipelineSource to build the pipeline off o
611614
Raises:
612615
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
613616
Returns:
614617
a Pipeline representing the query
615618
"""
616-
return self._query().pipeline()
619+
return self._query()._build_pipeline(source)
617620

618621

619622
def _auto_id() -> str:

google/cloud/firestore_v1/base_query.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from google.cloud.firestore_v1.query_profile import ExplainOptions
6868
from google.cloud.firestore_v1.query_results import QueryResultsList
6969
from google.cloud.firestore_v1.stream_generator import StreamGenerator
70+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
7071

7172
import datetime
7273

@@ -1129,24 +1130,23 @@ def recursive(self: QueryType) -> QueryType:
11291130

11301131
return copied
11311132

1132-
def pipeline(self):
1133+
def _build_pipeline(self, source: "PipelineSource"):
11331134
"""
11341135
Convert this query into a Pipeline
11351136
11361137
Queries containing a `cursor` or `limit_to_last` are not currently supported
11371138
1139+
Args:
1140+
source: the PipelineSource to build the pipeline off of
11381141
Raises:
1139-
- ValueError: raised if Query wasn't created with an associated client
11401142
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
11411143
Returns:
11421144
a Pipeline representing the query
11431145
"""
1144-
if not self._client:
1145-
raise ValueError("Query does not have an associated client")
11461146
if self._all_descendants:
1147-
ppl = self._client.pipeline().collection_group(self._parent.id)
1147+
ppl = source.collection_group(self._parent.id)
11481148
else:
1149-
ppl = self._client.pipeline().collection(self._parent._path)
1149+
ppl = source.collection(self._parent._path)
11501150

11511151
# Filters
11521152
for filter_ in self._field_filters:

google/cloud/firestore_v1/pipeline_source.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from google.cloud.firestore_v1.client import Client
2323
from google.cloud.firestore_v1.async_client import AsyncClient
2424
from google.cloud.firestore_v1.base_document import BaseDocumentReference
25+
from google.cloud.firestore_v1.base_query import BaseQuery
26+
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
27+
from google.cloud.firestore_v1.base_collection import BaseCollectionReference
2528

2629

2730
PipelineType = TypeVar("PipelineType", bound=_BasePipeline)
@@ -43,6 +46,23 @@ def __init__(self, client: Client | AsyncClient):
4346
def _create_pipeline(self, source_stage):
4447
return self.client._pipeline_cls._create_with_stages(self.client, source_stage)
4548

49+
def create_from(
50+
self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference"
51+
) -> PipelineType:
52+
"""
53+
Create a pipeline from an existing query
54+
55+
Queries containing a `cursor` or `limit_to_last` are not currently supported
56+
57+
Args:
58+
query: the query to build the pipeline off of
59+
Raises:
60+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
61+
Returns:
62+
a new pipeline instance representing the query
63+
"""
64+
return query._build_pipeline(self)
65+
4666
def collection(self, path: str | tuple[str]) -> PipelineType:
4767
"""
4868
Creates a new Pipeline that operates on a specified Firestore collection.

tests/system/test_system.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ def _clean_results(results):
133133
except Exception as e:
134134
# if we expect the query to fail, capture the exception
135135
query_exception = e
136-
pipeline = query.pipeline()
136+
client = query._client
137+
pipeline = client.pipeline().create_from(query)
137138
if query_exception:
138139
# ensure that the pipeline uses same error as query
139140
with pytest.raises(query_exception.__class__):

tests/system/test_system_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def _clean_results(results):
213213
except Exception as e:
214214
# if we expect the query to fail, capture the exception
215215
query_exception = e
216-
pipeline = query.pipeline()
216+
client = query._client
217+
pipeline = client.pipeline().create_from(query)
217218
if query_exception:
218219
# ensure that the pipeline uses same error as query
219220
with pytest.raises(query_exception.__class__):

tests/unit/v1/test_aggregation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias):
10401040
query = make_query(parent)
10411041
aggregation_query = make_aggregation_query(query)
10421042
aggregation_query.sum(field, alias=in_alias)
1043-
pipeline = aggregation_query.pipeline()
1043+
pipeline = aggregation_query._build_pipeline(client.pipeline())
10441044
assert isinstance(pipeline, Pipeline)
10451045
assert len(pipeline.stages) == 2
10461046
assert isinstance(pipeline.stages[0], Collection)
@@ -1071,7 +1071,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias):
10711071
query = make_query(parent)
10721072
aggregation_query = make_aggregation_query(query)
10731073
aggregation_query.avg(field, alias=in_alias)
1074-
pipeline = aggregation_query.pipeline()
1074+
pipeline = aggregation_query._build_pipeline(client.pipeline())
10751075
assert isinstance(pipeline, Pipeline)
10761076
assert len(pipeline.stages) == 2
10771077
assert isinstance(pipeline.stages[0], Collection)
@@ -1102,7 +1102,7 @@ def test_aggreation_to_pipeline_count(in_alias, out_alias):
11021102
query = make_query(parent)
11031103
aggregation_query = make_aggregation_query(query)
11041104
aggregation_query.count(alias=in_alias)
1105-
pipeline = aggregation_query.pipeline()
1105+
pipeline = aggregation_query._build_pipeline(client.pipeline())
11061106
assert isinstance(pipeline, Pipeline)
11071107
assert len(pipeline.stages) == 2
11081108
assert isinstance(pipeline.stages[0], Collection)
@@ -1127,7 +1127,7 @@ def test_aggreation_to_pipeline_count_increment():
11271127
aggregation_query = make_aggregation_query(query)
11281128
for _ in range(n):
11291129
aggregation_query.count()
1130-
pipeline = aggregation_query.pipeline()
1130+
pipeline = aggregation_query._build_pipeline(client.pipeline())
11311131
aggregate_stage = pipeline.stages[1]
11321132
assert len(aggregate_stage.accumulators) == n
11331133
for i in range(n):
@@ -1146,7 +1146,7 @@ def test_aggreation_to_pipeline_complex():
11461146
aggregation_query.count()
11471147
aggregation_query.avg("other")
11481148
aggregation_query.sum("final")
1149-
pipeline = aggregation_query.pipeline()
1149+
pipeline = aggregation_query._build_pipeline(client.pipeline())
11501150
assert isinstance(pipeline, Pipeline)
11511151
assert len(pipeline.stages) == 3
11521152
assert isinstance(pipeline.stages[0], Collection)

tests/unit/v1/test_async_aggregation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias):
716716
query = make_async_query(parent)
717717
aggregation_query = make_async_aggregation_query(query)
718718
aggregation_query.sum(field, alias=in_alias)
719-
pipeline = aggregation_query.pipeline()
719+
pipeline = aggregation_query._build_pipeline(client.pipeline())
720720
assert isinstance(pipeline, AsyncPipeline)
721721
assert len(pipeline.stages) == 2
722722
assert isinstance(pipeline.stages[0], Collection)
@@ -747,7 +747,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias):
747747
query = make_async_query(parent)
748748
aggregation_query = make_async_aggregation_query(query)
749749
aggregation_query.avg(field, alias=in_alias)
750-
pipeline = aggregation_query.pipeline()
750+
pipeline = aggregation_query._build_pipeline(client.pipeline())
751751
assert isinstance(pipeline, AsyncPipeline)
752752
assert len(pipeline.stages) == 2
753753
assert isinstance(pipeline.stages[0], Collection)
@@ -778,7 +778,7 @@ def test_async_aggreation_to_pipeline_count(in_alias, out_alias):
778778
query = make_async_query(parent)
779779
aggregation_query = make_async_aggregation_query(query)
780780
aggregation_query.count(alias=in_alias)
781-
pipeline = aggregation_query.pipeline()
781+
pipeline = aggregation_query._build_pipeline(client.pipeline())
782782
assert isinstance(pipeline, AsyncPipeline)
783783
assert len(pipeline.stages) == 2
784784
assert isinstance(pipeline.stages[0], Collection)
@@ -803,7 +803,7 @@ def test_aggreation_to_pipeline_count_increment():
803803
aggregation_query = make_async_aggregation_query(query)
804804
for _ in range(n):
805805
aggregation_query.count()
806-
pipeline = aggregation_query.pipeline()
806+
pipeline = aggregation_query._build_pipeline(client.pipeline())
807807
aggregate_stage = pipeline.stages[1]
808808
assert len(aggregate_stage.accumulators) == n
809809
for i in range(n):
@@ -822,7 +822,7 @@ def test_async_aggreation_to_pipeline_complex():
822822
aggregation_query.count()
823823
aggregation_query.avg("other")
824824
aggregation_query.sum("final")
825-
pipeline = aggregation_query.pipeline()
825+
pipeline = aggregation_query._build_pipeline(client.pipeline())
826826
assert isinstance(pipeline, AsyncPipeline)
827827
assert len(pipeline.stages) == 3
828828
assert isinstance(pipeline.stages[0], Collection)

tests/unit/v1/test_async_collection.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -609,15 +609,9 @@ def test_asynccollectionreference_pipeline():
609609

610610
client = make_async_client()
611611
collection = _make_async_collection_reference("collection", client=client)
612-
pipeline = collection.pipeline()
612+
pipeline = collection._build_pipeline(client.pipeline())
613613
assert isinstance(pipeline, AsyncPipeline)
614614
# should have single "Collection" stage
615615
assert len(pipeline.stages) == 1
616616
assert isinstance(pipeline.stages[0], Collection)
617617
assert pipeline.stages[0].path == "/collection"
618-
619-
620-
def test_asynccollectionreference_pipeline_no_client():
621-
collection = _make_async_collection_reference("collection")
622-
with pytest.raises(ValueError, match="client"):
623-
collection.pipeline()

tests/unit/v1/test_async_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ def test_asyncquery_collection_pipeline_type():
917917
client = make_async_client()
918918
parent = client.collection("test")
919919
query = parent._query()
920-
ppl = query.pipeline()
920+
ppl = query._build_pipeline(client.pipeline())
921921
assert isinstance(ppl, AsyncPipeline)
922922

923923

@@ -926,5 +926,5 @@ def test_asyncquery_collectiongroup_pipeline_type():
926926

927927
client = make_async_client()
928928
query = client.collection_group("test")
929-
ppl = query.pipeline()
929+
ppl = query._build_pipeline(client.pipeline())
930930
assert isinstance(ppl, AsyncPipeline)

0 commit comments

Comments
 (0)