diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e25a9f7..b198738 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -26,6 +26,7 @@ jobs: - '3.8' - '3.9' - '3.10' + - '3.11' runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/spanner_orm/query.py b/spanner_orm/query.py index 85a35aa..ac09f9a 100644 --- a/spanner_orm/query.py +++ b/spanner_orm/query.py @@ -14,13 +14,15 @@ """Helps build SQL for complex Spanner queries.""" import abc -from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type +from typing import Any, Dict, Generic, Iterable, List, Sequence, Tuple, Type, TypeVar from spanner_orm import condition from spanner_orm import error +ResultType = TypeVar('ResultType') -class SpannerQuery(abc.ABC): + +class SpannerQuery(abc.ABC, Generic[ResultType]): """Helps build SQL for complex Spanner queries.""" def __init__(self, model: Type[Any], @@ -46,7 +48,7 @@ def types(self) -> Dict[str, Any]: return self._types @abc.abstractmethod - def process_results(self, results: List[Sequence[Any]]) -> None: + def process_results(self, results: List[Sequence[Any]]) -> ResultType: pass def _segments(self, @@ -133,7 +135,7 @@ def _limit(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: return (sql, parameters, types) -class CountQuery(SpannerQuery): +class CountQuery(SpannerQuery[int]): """Handles COUNT Spanner queries.""" def __init__(self, model: Type[Any], @@ -151,7 +153,7 @@ def process_results(self, results: List[Sequence[Any]]) -> int: return int(results[0][0]) -class SelectQuery(SpannerQuery): +class SelectQuery(SpannerQuery[List[Type[Any]]]): """Handles SELECT Spanner queries.""" def __init__(self, model: Type[Any], @@ -188,7 +190,7 @@ def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: def process_results(self, results: List[Sequence[Any]]) -> List[Type[Any]]: return [self._process_row(result) for result in results] - def _process_row(self, row: List[Any]) -> Type[Any]: + def _process_row(self, row: Sequence[Any]) -> Type[Any]: """Parses a row of results from a Spanner query based on the conditions.""" values = dict(zip(self._model.columns, row)) join_values = row[len(self._model.columns):]