1515"""Helpers for applying Google Cloud Firestore changes in a transaction."""
1616from __future__ import annotations
1717
18- from typing import TYPE_CHECKING , Any , AsyncGenerator , Callable , Coroutine , Optional
18+ from typing import (
19+ TYPE_CHECKING ,
20+ Any ,
21+ AsyncGenerator ,
22+ Awaitable ,
23+ Callable ,
24+ Coroutine ,
25+ Optional ,
26+ TypeVar ,
27+ ParamSpec ,
28+ Concatenate ,
29+ )
1930
2031from google .api_core import exceptions , gapic_v1
2132from google .api_core import retry_async as retries
4152 from google .cloud .firestore_v1 .query_profile import ExplainOptions
4253
4354
55+ T = TypeVar ("T" )
56+ P = ParamSpec ("P" )
57+
58+
4459class AsyncTransaction (async_batch .AsyncWriteBatch , BaseTransaction ):
4560 """Accumulate read-and-write operations to be sent in a transaction.
4661
@@ -236,11 +251,13 @@ class _AsyncTransactional(_BaseTransactional):
236251 A coroutine that should be run (and retried) in a transaction.
237252 """
238253
239- def __init__ (self , to_wrap ) -> None :
254+ def __init__ (
255+ self , to_wrap : Callable [Concatenate [AsyncTransaction , P ], Awaitable [T ]]
256+ ) -> None :
240257 super (_AsyncTransactional , self ).__init__ (to_wrap )
241258
242259 async def _pre_commit (
243- self , transaction : AsyncTransaction , * args , ** kwargs
260+ self , transaction : AsyncTransaction , * args : P . args , ** kwargs : P . kwargs
244261 ) -> Coroutine :
245262 """Begin transaction and call the wrapped coroutine.
246263
@@ -254,7 +271,7 @@ async def _pre_commit(
254271 along to the wrapped coroutine.
255272
256273 Returns:
257- Any : result of the wrapped coroutine.
274+ T : result of the wrapped coroutine.
258275
259276 Raises:
260277 Exception: Any failure caused by ``to_wrap``.
@@ -269,20 +286,22 @@ async def _pre_commit(
269286 self .retry_id = self .current_id
270287 return await self .to_wrap (transaction , * args , ** kwargs )
271288
272- async def __call__ (self , transaction , * args , ** kwargs ):
289+ async def __call__ (
290+ self , transaction : AsyncTransaction , * args : P .args , ** kwargs : P .kwargs
291+ ) -> T :
273292 """Execute the wrapped callable within a transaction.
274293
275294 Args:
276295 transaction
277- (:class:`~google.cloud.firestore_v1.transaction.Transaction `):
296+ (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `):
278297 A transaction to execute the callable within.
279298 args (Tuple[Any, ...]): The extra positional arguments to pass
280299 along to the wrapped callable.
281300 kwargs (Dict[str, Any]): The extra keyword arguments to pass
282301 along to the wrapped callable.
283302
284303 Returns:
285- Any : The result of the wrapped callable.
304+ T : The result of the wrapped callable.
286305
287306 Raises:
288307 ValueError: If the transaction does not succeed in
@@ -321,13 +340,13 @@ async def __call__(self, transaction, *args, **kwargs):
321340
322341
323342def async_transactional (
324- to_wrap : Callable [[AsyncTransaction ], Any ]
325- ) -> _AsyncTransactional :
343+ to_wrap : Callable [Concatenate [AsyncTransaction , P ], Awaitable [ T ] ]
344+ ) -> Callable [ Concatenate [ AsyncTransaction , P ], Awaitable [ T ]] :
326345 """Decorate a callable so that it runs in a transaction.
327346
328347 Args:
329348 to_wrap
330- (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction `, ...], Any]):
349+ (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `, ...], Any]):
331350 A callable that should be run (and retried) in a transaction.
332351
333352 Returns:
0 commit comments