Skip to content

Commit faf6663

Browse files
authored
Firestore: Add Watch Support (googleapis#6191)
Firestore watch
1 parent 20f2b92 commit faf6663

File tree

16 files changed

+2497
-5
lines changed

16 files changed

+2497
-5
lines changed

api_core/google/api_core/bidi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def on_response(response):
481481
print(response)
482482
483483
consumer = BackgroundConsumer(rpc, on_response)
484-
consume.start()
484+
consumer.start()
485485
486486
Note that error handling *must* be done by using the provided
487487
``bidi_rpc``'s ``add_done_callback``. This helper will automatically exit

firestore/google/cloud/firestore.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.cloud.firestore_v1beta1 import Transaction
3232
from google.cloud.firestore_v1beta1 import transactional
3333
from google.cloud.firestore_v1beta1 import types
34+
from google.cloud.firestore_v1beta1 import Watch
3435
from google.cloud.firestore_v1beta1 import WriteBatch
3536
from google.cloud.firestore_v1beta1 import WriteOption
3637

@@ -52,6 +53,7 @@
5253
'Transaction',
5354
'transactional',
5455
'types',
56+
'Watch',
5557
'WriteBatch',
5658
'WriteOption',
5759
]

firestore/google/cloud/firestore_v1beta1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from google.cloud.firestore_v1beta1.query import Query
3535
from google.cloud.firestore_v1beta1.transaction import Transaction
3636
from google.cloud.firestore_v1beta1.transaction import transactional
37+
from google.cloud.firestore_v1beta1.watch import Watch
3738

3839

3940
__all__ = [
@@ -53,6 +54,7 @@
5354
'Transaction',
5455
'transactional',
5556
'types',
57+
'Watch',
5658
'WriteBatch',
5759
'WriteOption',
5860
]

firestore/google/cloud/firestore_v1beta1/_helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
"""Common helpers shared across Google Cloud Firestore modules."""
1616

1717

18-
import collections
18+
try:
19+
from collections import abc
20+
except ImportError: # python 2.7
21+
import collections as abc
22+
1923
import datetime
2024
import re
2125

@@ -745,7 +749,7 @@ def get_nested_value(field_path, data):
745749

746750
nested_data = data
747751
for index, field_name in enumerate(field_names):
748-
if isinstance(nested_data, collections.Mapping):
752+
if isinstance(nested_data, abc.Mapping):
749753
if field_name in nested_data:
750754
nested_data = nested_data[field_name]
751755
else:

firestore/google/cloud/firestore_v1beta1/collection.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from google.cloud.firestore_v1beta1 import _helpers
2323
from google.cloud.firestore_v1beta1 import query as query_mod
2424
from google.cloud.firestore_v1beta1.proto import document_pb2
25-
25+
from google.cloud.firestore_v1beta1.watch import Watch
26+
from google.cloud.firestore_v1beta1 import document
2627

2728
_AUTO_ID_CHARS = (
2829
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789')
@@ -371,6 +372,37 @@ def get(self, transaction=None):
371372
query = query_mod.Query(self)
372373
return query.get(transaction=transaction)
373374

375+
def on_snapshot(self, callback):
376+
"""Monitor the documents in this collection.
377+
378+
This starts a watch on this collection using a background thread. The
379+
provided callback is run on the snapshot of the documents.
380+
381+
Args:
382+
callback(~.firestore.collection.CollectionSnapshot): a callback
383+
to run when a change occurs.
384+
385+
Example:
386+
from google.cloud import firestore
387+
388+
db = firestore.Client()
389+
collection_ref = db.collection(u'users')
390+
391+
def on_snapshot(collection_snapshot):
392+
for doc in collection_snapshot.documents:
393+
print(u'{} => {}'.format(doc.id, doc.to_dict()))
394+
395+
# Watch this collection
396+
collection_watch = collection_ref.on_snapshot(on_snapshot)
397+
398+
# Terminate this watch
399+
collection_watch.unsubscribe()
400+
"""
401+
return Watch.for_query(query_mod.Query(self),
402+
callback,
403+
document.DocumentSnapshot,
404+
document.DocumentReference)
405+
374406

375407
def _auto_id():
376408
"""Generate a "random" automatically generated ID.

firestore/google/cloud/firestore_v1beta1/document.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import six
2020

2121
from google.cloud.firestore_v1beta1 import _helpers
22+
from google.cloud.firestore_v1beta1.watch import Watch
2223

2324

2425
class DocumentReference(object):
@@ -445,6 +446,38 @@ def collections(self, page_size=None):
445446
iterator.item_to_value = _item_to_collection_ref
446447
return iterator
447448

449+
def on_snapshot(self, callback):
450+
"""Watch this document.
451+
452+
This starts a watch on this document using a background thread. The
453+
provided callback is run on the snapshot.
454+
455+
Args:
456+
callback(~.firestore.document.DocumentSnapshot):a callback to run
457+
when a change occurs
458+
459+
Example:
460+
from google.cloud import firestore
461+
462+
db = firestore.Client()
463+
collection_ref = db.collection(u'users')
464+
465+
def on_snapshot(document_snapshot):
466+
doc = document_snapshot
467+
print(u'{} => {}'.format(doc.id, doc.to_dict()))
468+
469+
doc_ref = db.collection(u'users').document(
470+
u'alovelace' + unique_resource_id())
471+
472+
# Watch this document
473+
doc_watch = doc_ref.on_snapshot(on_snapshot)
474+
475+
# Terminate this watch
476+
doc_watch.unsubscribe()
477+
"""
478+
return Watch.for_document(self, callback, DocumentSnapshot,
479+
DocumentReference)
480+
448481

449482
class DocumentSnapshot(object):
450483
"""A snapshot of document data in a Firestore database.
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2017 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from enum import Enum
16+
from google.cloud.firestore_v1beta1._helpers import decode_value
17+
import math
18+
19+
20+
class TypeOrder(Enum):
21+
# NOTE: This order is defined by the backend and cannot be changed.
22+
NULL = 0
23+
BOOLEAN = 1
24+
NUMBER = 2
25+
TIMESTAMP = 3
26+
STRING = 4
27+
BLOB = 5
28+
REF = 6
29+
GEO_POINT = 7
30+
ARRAY = 8
31+
OBJECT = 9
32+
33+
@staticmethod
34+
def from_value(value):
35+
v = value.WhichOneof('value_type')
36+
37+
lut = {
38+
'null_value': TypeOrder.NULL,
39+
'boolean_value': TypeOrder.BOOLEAN,
40+
'integer_value': TypeOrder.NUMBER,
41+
'double_value': TypeOrder.NUMBER,
42+
'timestamp_value': TypeOrder.TIMESTAMP,
43+
'string_value': TypeOrder.STRING,
44+
'bytes_value': TypeOrder.BLOB,
45+
'reference_value': TypeOrder.REF,
46+
'geo_point_value': TypeOrder.GEO_POINT,
47+
'array_value': TypeOrder.ARRAY,
48+
'map_value': TypeOrder.OBJECT,
49+
}
50+
51+
if v not in lut:
52+
raise ValueError(
53+
"Could not detect value type for " + v)
54+
return lut[v]
55+
56+
57+
class Order(object):
58+
'''
59+
Order implements the ordering semantics of the backend.
60+
'''
61+
62+
@classmethod
63+
def compare(cls, left, right):
64+
'''
65+
Main comparison function for all Firestore types.
66+
@return -1 is left < right, 0 if left == right, otherwise 1
67+
'''
68+
# First compare the types.
69+
leftType = TypeOrder.from_value(left).value
70+
rightType = TypeOrder.from_value(right).value
71+
72+
if leftType != rightType:
73+
if leftType < rightType:
74+
return -1
75+
return 1
76+
77+
value_type = left.WhichOneof('value_type')
78+
79+
if value_type == 'null_value':
80+
return 0 # nulls are all equal
81+
elif value_type == 'boolean_value':
82+
return cls._compare_to(left.boolean_value, right.boolean_value)
83+
elif value_type == 'integer_value':
84+
return cls.compare_numbers(left, right)
85+
elif value_type == 'double_value':
86+
return cls.compare_numbers(left, right)
87+
elif value_type == 'timestamp_value':
88+
return cls.compare_timestamps(left, right)
89+
elif value_type == 'string_value':
90+
return cls._compare_to(left.string_value, right.string_value)
91+
elif value_type == 'bytes_value':
92+
return cls.compare_blobs(left, right)
93+
elif value_type == 'reference_value':
94+
return cls.compare_resource_paths(left, right)
95+
elif value_type == 'geo_point_value':
96+
return cls.compare_geo_points(left, right)
97+
elif value_type == 'array_value':
98+
return cls.compare_arrays(left, right)
99+
elif value_type == 'map_value':
100+
return cls.compare_objects(left, right)
101+
else:
102+
raise ValueError('Unknown ``value_type``', str(value_type))
103+
104+
@staticmethod
105+
def compare_blobs(left, right):
106+
left_bytes = left.bytes_value
107+
right_bytes = right.bytes_value
108+
109+
return Order._compare_to(left_bytes, right_bytes)
110+
111+
@staticmethod
112+
def compare_timestamps(left, right):
113+
left = left.timestamp_value
114+
right = right.timestamp_value
115+
116+
seconds = Order._compare_to(left.seconds or 0, right.seconds or 0)
117+
if seconds != 0:
118+
return seconds
119+
120+
return Order._compare_to(left.nanos or 0, right.nanos or 0)
121+
122+
@staticmethod
123+
def compare_geo_points(left, right):
124+
left_value = decode_value(left, None)
125+
right_value = decode_value(right, None)
126+
cmp = (
127+
(left_value.latitude > right_value.latitude) -
128+
(left_value.latitude < right_value.latitude)
129+
)
130+
131+
if cmp != 0:
132+
return cmp
133+
return (
134+
(left_value.longitude > right_value.longitude) -
135+
(left_value.longitude < right_value.longitude)
136+
)
137+
138+
@staticmethod
139+
def compare_resource_paths(left, right):
140+
left = left.reference_value
141+
right = right.reference_value
142+
143+
left_segments = left.split('/')
144+
right_segments = right.split('/')
145+
shorter = min(len(left_segments), len(right_segments))
146+
# compare segments
147+
for i in range(shorter):
148+
if (left_segments[i] < right_segments[i]):
149+
return -1
150+
if (left_segments[i] > right_segments[i]):
151+
return 1
152+
153+
left_length = len(left)
154+
right_length = len(right)
155+
return (left_length > right_length) - (left_length < right_length)
156+
157+
@staticmethod
158+
def compare_arrays(left, right):
159+
l_values = left.array_value.values
160+
r_values = right.array_value.values
161+
162+
length = min(len(l_values), len(r_values))
163+
for i in range(length):
164+
cmp = Order.compare(l_values[i], r_values[i])
165+
if cmp != 0:
166+
return cmp
167+
168+
return Order._compare_to(len(l_values), len(r_values))
169+
170+
@staticmethod
171+
def compare_objects(left, right):
172+
left_fields = left.map_value.fields
173+
right_fields = right.map_value.fields
174+
175+
for left_key, right_key in zip(
176+
sorted(left_fields), sorted(right_fields)
177+
):
178+
keyCompare = Order._compare_to(left_key, right_key)
179+
if keyCompare != 0:
180+
return keyCompare
181+
182+
value_compare = Order.compare(
183+
left_fields[left_key], right_fields[right_key])
184+
if value_compare != 0:
185+
return value_compare
186+
187+
return Order._compare_to(len(left_fields), len(right_fields))
188+
189+
@staticmethod
190+
def compare_numbers(left, right):
191+
left_value = decode_value(left, None)
192+
right_value = decode_value(right, None)
193+
return Order.compare_doubles(left_value, right_value)
194+
195+
@staticmethod
196+
def compare_doubles(left, right):
197+
if math.isnan(left):
198+
if math.isnan(right):
199+
return 0
200+
return -1
201+
if math.isnan(right):
202+
return 1
203+
204+
return Order._compare_to(left, right)
205+
206+
@staticmethod
207+
def _compare_to(left, right):
208+
# We can't just use cmp(left, right) because cmp doesn't exist
209+
# in Python 3, so this is an equivalent suggested by
210+
# https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons
211+
return (left > right) - (left < right)

0 commit comments

Comments
 (0)