Skip to content

Commit 281eaae

Browse files
feat(firestore): add DML support to pipelines (#16473)
Add support for Data Manipulation Language, making pipelines read/write instead of read-only.

New stages:
- `Delete`
- `Update`

New e2e test assertion:
- `assert_end_state`, which asserts that the database is in a certain state after running a pipeline.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ae55a1b commit 281eaae

File tree

5 files changed

+210
-6
lines changed

5 files changed

+210
-6
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,62 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
669669
"""
670670
return self._append(stages.Distinct(*fields))
671671

672+
def delete(self) -> "_BasePipeline":
    """
    Removes from the database every document produced by the preceding
    pipeline stages.

    Example:
        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
        >>> pipeline = client.pipeline().collection("logs")
        >>> # Remove every "logs" document whose "status" is "archived"
        >>> pipeline = pipeline.where(Field.of("status").equal("archived")).delete()
        >>> pipeline.execute()

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    # Pipelines are immutable; _append returns a new pipeline instance.
    delete_stage = stages.Delete()
    return self._append(delete_stage)
687+
688+
def update(self, *transformed_fields: "Selectable") -> "_BasePipeline":
    """
    Writes the documents flowing through the pipeline back to the database.

    When called with no `transformed_fields`, each document is updated in
    place using the data currently carried by the pipeline. To overwrite
    specific fields instead, pass `Selectable` expressions describing the
    transformations to apply.

    Example 1: Update a collection's schema by adding a new field and removing an old one.
        >>> from google.cloud.firestore_v1.pipeline_expressions import Constant
        >>> pipeline = client.pipeline().collection("books")
        >>> pipeline = pipeline.add_fields(Constant.of("Fiction").as_("genre"))
        >>> pipeline = pipeline.remove_fields("old_genre").update()
        >>> pipeline.execute()

    Example 2: Update documents in place with data from literals.
        >>> pipeline = client.pipeline().literals(
        ...     {"__name__": client.collection("books").document("book1"), "status": "Updated"}
        ... ).update()
        >>> pipeline.execute()

    Example 3: Update documents from previous stages with specified transformations.
        >>> from google.cloud.firestore_v1.pipeline_expressions import Field, Constant
        >>> pipeline = client.pipeline().collection("books")
        >>> # Set "status" to "Discounted" for every book priced above 50
        >>> pipeline = pipeline.where(Field.of("price").greater_than(50))
        >>> pipeline = pipeline.update(Constant.of("Discounted").as_("status"))
        >>> pipeline.execute()

    Args:
        *transformed_fields: Optional. The transformations to apply. When
            omitted, documents are updated in place from the data flowing
            through the pipeline.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    # Pipelines are immutable; _append returns a new pipeline instance.
    update_stage = stages.Update(*transformed_fields)
    return self._append(update_stage)
727+
672728
def define(self, *aliased_expressions: AliasedExpression) -> "_BasePipeline":
673729
"""
674730
Binds one or more expressions to Variables that can be accessed in subsequent stages

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,27 @@ def _pb_args(self):
496496
return [self.condition._to_pb()]
497497

498498

499+
class Delete(Stage):
    """Pipeline stage that removes the documents flowing into it."""

    def __init__(self):
        # "delete" is the wire name of this stage.
        super().__init__("delete")

    def _pb_args(self) -> list[Value]:
        # The delete stage carries no protobuf arguments.
        return []
507+
508+
509+
class Update(Stage):
    """Pipeline stage that writes transformed fields back to documents."""

    def __init__(self, *transformed_fields: Selectable):
        # "update" is the wire name of this stage.
        super().__init__("update")
        self.transformed_fields = [*transformed_fields]

    def _pb_args(self) -> list[Value]:
        # The selectables are packed into a single map-valued argument.
        return [Selectable._to_value(self.transformed_fields)]
518+
519+
499520
class Define(Stage):
500521
"""Binds one or more expressions to variables."""
501522

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Seed documents written to the emulator/database before each test runs.
data:
  dml_delete_coll:
    doc1: { score: 10 }
    doc2: { score: 60 }
  dml_update_coll:
    doc1: { status: "pending", score: 50 }

tests:
  - description: "Basic DML delete"
    # Delete every document in dml_delete_coll with score < 50.
    pipeline:
      - Collection: dml_delete_coll
      - Where:
          FunctionExpression.less_than:
            - Field: score
            - Constant: 50
      - Delete:
    # assert_end_state checks the database contents after execution;
    # null means the document must no longer exist.
    assert_end_state:
      dml_delete_coll/doc1: null
      dml_delete_coll/doc2: { score: 60 }
    # assert_proto pins the exact request proto the client builds.
    assert_proto:
      pipeline:
        stages:
          - args:
              - referenceValue: /dml_delete_coll
            name: collection
          - args:
              - functionValue:
                  args:
                    - fieldReferenceValue: score
                    - integerValue: '50'
                  name: less_than
            name: where
          - name: delete

  - description: "Basic DML update"
    # Overwrite the "status" field of every document in dml_update_coll.
    pipeline:
      - Collection: dml_update_coll
      - Update:
          - AliasedExpression:
              - Constant: "active"
              - "status"
    assert_end_state:
      dml_update_coll/doc1: { status: "active", score: 50 }
    assert_proto:
      pipeline:
        stages:
          - args:
              - referenceValue: /dml_update_coll
            name: collection
          - args:
              - mapValue:
                  fields:
                    status:
                      stringValue: active
            name: update

packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_pipeline_expected_errors(test_dict, client):
119119
if "assert_results" in t
120120
or "assert_count" in t
121121
or "assert_results_approximate" in t
122+
or "assert_end_state" in t
122123
],
123124
ids=id_format,
124125
)
@@ -131,6 +132,7 @@ def test_pipeline_results(test_dict, client):
131132
test_dict.get("assert_results_approximate", None)
132133
)
133134
expected_count = test_dict.get("assert_count", None)
135+
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
134136
pipeline = parse_pipeline(client, test_dict["pipeline"])
135137
# check if server responds as expected
136138
got_results = [snapshot.data() for snapshot in pipeline.stream()]
@@ -146,6 +148,19 @@ def test_pipeline_results(test_dict, client):
146148
)
147149
if expected_count is not None:
148150
assert len(got_results) == expected_count
151+
if expected_end_state:
152+
for doc_path, expected_content in expected_end_state.items():
153+
doc_ref = client.document(doc_path)
154+
snapshot = doc_ref.get()
155+
if expected_content is None:
156+
assert not snapshot.exists, (
157+
f"Expected {doc_path} to be absent, but it exists"
158+
)
159+
else:
160+
assert snapshot.exists, (
161+
f"Expected {doc_path} to exist, but it was absent"
162+
)
163+
assert snapshot.to_dict() == expected_content
149164

150165

151166
@pytest.mark.parametrize(
@@ -176,6 +191,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client):
176191
if "assert_results" in t
177192
or "assert_count" in t
178193
or "assert_results_approximate" in t
194+
or "assert_end_state" in t
179195
],
180196
ids=id_format,
181197
)
@@ -189,6 +205,7 @@ async def test_pipeline_results_async(test_dict, async_client):
189205
test_dict.get("assert_results_approximate", None)
190206
)
191207
expected_count = test_dict.get("assert_count", None)
208+
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
192209
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
193210
# check if server responds as expected
194211
got_results = [snapshot.data() async for snapshot in pipeline.stream()]
@@ -204,6 +221,19 @@ async def test_pipeline_results_async(test_dict, async_client):
204221
)
205222
if expected_count is not None:
206223
assert len(got_results) == expected_count
224+
if expected_end_state:
225+
for doc_path, expected_content in expected_end_state.items():
226+
doc_ref = async_client.document(doc_path)
227+
snapshot = await doc_ref.get()
228+
if expected_content is None:
229+
assert not snapshot.exists, (
230+
f"Expected {doc_path} to be absent, but it exists"
231+
)
232+
else:
233+
assert snapshot.exists, (
234+
f"Expected {doc_path} to exist, but it was absent"
235+
)
236+
assert snapshot.to_dict() == expected_content
207237

208238

209239
#################################################################################
@@ -223,7 +253,12 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]):
223253
# find arguments if given
224254
if isinstance(stage, dict):
225255
stage_yaml_args = stage[stage_name]
226-
stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args)
256+
if stage_yaml_args is None:
257+
stage_obj = stage_cls()
258+
else:
259+
stage_obj = _apply_yaml_args_to_callable(
260+
stage_cls, client, stage_yaml_args
261+
)
227262
else:
228263
# yaml has no arguments
229264
stage_obj = stage_cls()
@@ -291,20 +326,21 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args):
291326
Helper to instantiate a class with yaml arguments. The arguments will be applied
292327
as positional or keyword arguments, based on type
293328
"""
294-
if isinstance(yaml_args, dict):
295-
return callable_obj(**_parse_expressions(client, yaml_args))
329+
parsed = _parse_expressions(client, yaml_args)
330+
if isinstance(yaml_args, dict) and isinstance(parsed, dict):
331+
return callable_obj(**parsed)
296332
elif isinstance(yaml_args, list) and not (
297333
callable_obj == expr.Constant
298334
or callable_obj == Vector
299335
or callable_obj == expr.Array
300336
):
301337
# yaml has an array of arguments. Treat as args
302-
return callable_obj(*_parse_expressions(client, yaml_args))
303-
elif yaml_args is None:
338+
return callable_obj(*parsed)
339+
elif yaml_args is None and callable_obj != expr.Constant:
304340
return callable_obj()
305341
else:
306342
# yaml has a single argument
307-
return callable_obj(_parse_expressions(client, yaml_args))
343+
return callable_obj(parsed)
308344

309345

310346
def _is_expr_string(yaml_str):

packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,3 +960,39 @@ def test_to_pb(self):
960960
assert got_fn.args[0].field_reference_value == "city"
961961
assert got_fn.args[1].string_value == "SF"
962962
assert len(result.options) == 0
963+
964+
965+
class TestDelete:
    """Unit tests for the ``Delete`` pipeline stage."""

    def _make_one(self):
        # Delete takes no constructor arguments.
        return stages.Delete()

    def test_to_pb(self):
        stage_pb = self._make_one()._to_pb()
        assert stage_pb.name == "delete"
        # The delete stage serializes with no args and no options.
        assert len(stage_pb.args) == 0
        assert len(stage_pb.options) == 0
975+
976+
977+
class TestUpdate:
    """Unit tests for the ``Update`` pipeline stage."""

    def _make_one(self, *args):
        return stages.Update(*args)

    def test_to_pb_empty(self):
        # With no transformed fields, the stage still emits one empty map arg.
        stage_pb = self._make_one()._to_pb()
        assert stage_pb.name == "update"
        assert len(stage_pb.args) == 1
        assert stage_pb.args[0].map_value.fields == {}
        assert len(stage_pb.options) == 0

    def test_to_pb_with_fields(self):
        # Each aliased expression becomes an entry in the single map arg.
        stage = self._make_one(
            Field.of("score").add(10).as_("score"),
            Constant.of("active").as_("status"),
        )
        stage_pb = stage._to_pb()
        assert stage_pb.name == "update"
        assert len(stage_pb.args) == 1
        field_map = stage_pb.args[0].map_value.fields
        assert "score" in field_map
        assert "status" in field_map
        assert len(stage_pb.options) == 0

0 commit comments

Comments
 (0)