Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
- '3.8'
- '3.9'
- '3.10'
- '3.11'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
14 changes: 8 additions & 6 deletions spanner_orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
"""Helps build SQL for complex Spanner queries."""

import abc
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type
from typing import Any, Dict, Generic, Iterable, List, Sequence, Tuple, Type, TypeVar

from spanner_orm import condition
from spanner_orm import error

ResultType = TypeVar('ResultType')

class SpannerQuery(abc.ABC):

class SpannerQuery(abc.ABC, Generic[ResultType]):
"""Helps build SQL for complex Spanner queries."""

def __init__(self, model: Type[Any],
Expand All @@ -46,7 +48,7 @@ def types(self) -> Dict[str, Any]:
return self._types

@abc.abstractmethod
def process_results(self, results: List[Sequence[Any]]) -> None:
def process_results(self, results: List[Sequence[Any]]) -> ResultType:
pass

def _segments(self,
Expand Down Expand Up @@ -133,7 +135,7 @@ def _limit(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
return (sql, parameters, types)


class CountQuery(SpannerQuery):
class CountQuery(SpannerQuery[int]):
"""Handles COUNT Spanner queries."""

def __init__(self, model: Type[Any],
Expand All @@ -151,7 +153,7 @@ def process_results(self, results: List[Sequence[Any]]) -> int:
return int(results[0][0])


class SelectQuery(SpannerQuery):
class SelectQuery(SpannerQuery[List[Type[Any]]]):
"""Handles SELECT Spanner queries."""

def __init__(self, model: Type[Any],
Expand Down Expand Up @@ -188,7 +190,7 @@ def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
def process_results(self, results: List[Sequence[Any]]) -> List[Type[Any]]:
return [self._process_row(result) for result in results]

def _process_row(self, row: List[Any]) -> Type[Any]:
def _process_row(self, row: Sequence[Any]) -> Type[Any]:
"""Parses a row of results from a Spanner query based on the conditions."""
values = dict(zip(self._model.columns, row))
join_values = row[len(self._model.columns):]
Expand Down