Skip to content

Commit 7f1cfee

Browse files
authored
Allow passing metadata as part of creating a bidi (googleapis#7514)
* allows providing rpc metadata for bidi streams
1 parent e5a5912 commit 7f1cfee

6 files changed

Lines changed: 39 additions & 12 deletions

File tree

api_core/google/api_core/bidi.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,11 @@ class BidiRpc(object):
147147
148148
initial_request = example_pb2.StreamingRpcRequest(
149149
setting='example')
150-
rpc = BidiRpc(stub.StreamingRpc, initial_request=initial_request)
150+
rpc = BidiRpc(
151+
stub.StreamingRpc,
152+
initial_request=initial_request,
153+
metadata=[('name', 'value')]
154+
)
151155
152156
rpc.open()
153157
@@ -165,11 +169,14 @@ class BidiRpc(object):
165169
Callable[None, protobuf.Message]]): The initial request to
166170
yield. This is useful if an initial request is needed to start the
167171
stream.
172+
metadata (Sequence[Tuple(str, str)]): RPC metadata to include in
173+
the request.
168174
"""
169175

170-
def __init__(self, start_rpc, initial_request=None):
176+
def __init__(self, start_rpc, initial_request=None, metadata=None):
171177
self._start_rpc = start_rpc
172178
self._initial_request = initial_request
179+
self._rpc_metadata = metadata
173180
self._request_queue = queue.Queue()
174181
self._request_generator = None
175182
self._is_active = False
@@ -200,7 +207,7 @@ def open(self):
200207
request_generator = _RequestQueueGenerator(
201208
self._request_queue, initial_request=self._initial_request
202209
)
203-
call = self._start_rpc(iter(request_generator))
210+
call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata)
204211

205212
request_generator.call = call
206213

@@ -288,10 +295,14 @@ def should_recover(exc):
288295
initial_request = example_pb2.StreamingRpcRequest(
289296
setting='example')
290297
291-
rpc = ResumeableBidiRpc(
298+
metadata = [('header_name', 'value')]
299+
300+
rpc = ResumableBidiRpc(
292301
stub.StreamingRpc,
302+
should_recover=should_recover,
293303
initial_request=initial_request,
294-
should_recover=should_recover)
304+
metadata=metadata
305+
)
295306
296307
rpc.open()
297308
@@ -310,10 +321,12 @@ def should_recover(exc):
310321
should_recover (Callable[[Exception], bool]): A function that returns
311322
True if the stream should be recovered. This will be called
312323
whenever an error is encountered on the stream.
324+
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
325+
the request.
313326
"""
314327

315-
def __init__(self, start_rpc, should_recover, initial_request=None):
316-
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request)
328+
def __init__(self, start_rpc, should_recover, initial_request=None, metadata=None):
329+
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
317330
self._should_recover = should_recover
318331
self._operational_lock = threading.RLock()
319332
self._finalized = False

api_core/noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def lint_setup_py(session):
8888
def pytype(session):
8989
"""Run type-checking."""
9090
session.install(
91-
".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2", "pytype >= 2018.9.26"
91+
".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2", "pytype >= 2019.3.21"
9292
)
9393
session.run("pytype")
9494

api_core/tests/unit/test_bidi.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ def make_rpc():
125125
call = mock.create_autospec(_CallAndFuture, instance=True)
126126
rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
127127

128-
def rpc_side_effect(request):
128+
def rpc_side_effect(request, metadata=None):
129129
call.is_active.return_value = True
130130
call.request = request
131+
call.metadata = metadata
131132
return call
132133

133134
rpc.side_effect = rpc_side_effect
@@ -172,6 +173,15 @@ def test_done_callbacks(self):
172173

173174
callback.assert_called_once_with(mock.sentinel.future)
174175

176+
def test_metadata(self):
177+
rpc, call = make_rpc()
178+
bidi_rpc = bidi.BidiRpc(rpc, metadata=mock.sentinel.A)
179+
assert bidi_rpc._rpc_metadata == mock.sentinel.A
180+
181+
bidi_rpc.open()
182+
assert bidi_rpc.call == call
183+
assert bidi_rpc.call.metadata == mock.sentinel.A
184+
175185
def test_open(self):
176186
rpc, call = make_rpc()
177187
bidi_rpc = bidi.BidiRpc(rpc)

firestore/google/cloud/firestore_v1beta1/watch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,10 @@ def should_recover(exc): # pragma: NO COVER
213213
ResumableBidiRpc = self.ResumableBidiRpc # FBO unit tests
214214

215215
self._rpc = ResumableBidiRpc(
216-
self._api.transport._stubs["firestore_stub"].Listen,
216+
self._api.transport.listen,
217217
initial_request=initial_request,
218218
should_recover=should_recover,
219+
rpc_metadata=self._firestore._rpc_metadata,
219220
)
220221

221222
self._rpc.add_done_callback(self._on_rpc_done)

firestore/tests/unit/v1beta1/test_cross_language.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,13 @@ def convert_precondition(precond):
342342

343343

344344
class DummyRpc(object): # pragma: NO COVER
345-
def __init__(self, listen, initial_request, should_recover):
345+
def __init__(self, listen, initial_request, should_recover, rpc_metadata=None):
346346
self.listen = listen
347347
self.initial_request = initial_request
348348
self.should_recover = should_recover
349349
self.closed = False
350350
self.callbacks = []
351+
self._rpc_metadata = rpc_metadata
351352

352353
def add_done_callback(self, callback):
353354
self.callbacks.append(callback)

firestore/tests/unit/v1beta1/test_watch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ def _to_protobuf(self):
713713
class DummyFirestore(object):
714714
_firestore_api = DummyFirestoreClient()
715715
_database_string = "abc://bar/"
716+
_rpc_metadata = None
716717

717718
def document(self, *document_path): # pragma: NO COVER
718719
if len(document_path) == 1:
@@ -781,12 +782,13 @@ def Thread(self, name, target, kwargs):
781782

782783

783784
class DummyRpc(object):
784-
def __init__(self, listen, initial_request, should_recover):
785+
def __init__(self, listen, initial_request, should_recover, rpc_metadata=None):
785786
self.listen = listen
786787
self.initial_request = initial_request
787788
self.should_recover = should_recover
788789
self.closed = False
789790
self.callbacks = []
791+
self._rpc_metadata = rpc_metadata
790792

791793
def add_done_callback(self, callback):
792794
self.callbacks.append(callback)

0 commit comments

Comments
 (0)