Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
maintainer_email='dbrandao@google.com',
url='https://github.com/google/python-spanner-orm',
packages=['spanner_orm', 'spanner_orm.admin'],
python_requires='~=3.7',
install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'],
tests_require=['absl-py'],
entry_points={
Expand Down
4 changes: 2 additions & 2 deletions spanner_orm/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ def equal_to(column, value):
return EqualityCondition(column, value)


def force_index(index):
return ForceIndexCondition(index)
def force_index(forced_index):
return ForceIndexCondition(forced_index)


def greater_than(column, value):
Expand Down
56 changes: 31 additions & 25 deletions spanner_orm/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,45 @@
# limitations under the License.
"""Helper to deal with field types in Spanner interactions."""

from __future__ import annotations

import abc
import datetime
from typing import Any, Type

from google.cloud.spanner_v1.proto import type_pb2


class Field(object):
"""Represents a column in a table as a field in a model."""

def __init__(self, field_type, nullable=False, primary_key=False):
def __init__(self,
field_type: Type[FieldType],
nullable: bool = False,
primary_key: bool = False):
self._type = field_type
self._nullable = nullable
self._primary_key = primary_key
self.name = None

def ddl(self):
def ddl(self) -> str:
if self._nullable:
return self._type.ddl()
return '{field_type} NOT NULL'.format(field_type=self._type.ddl())

def field_type(self):
def field_type(self) -> Type[FieldType]:
return self._type

def grpc_type(self):
def grpc_type(self) -> str:
return self._type.grpc_type()

def nullable(self):
def nullable(self) -> bool:
return self._nullable

def primary_key(self):
def primary_key(self) -> bool:
return self._primary_key

def validate(self, value):
def validate(self, value) -> None:
if value is None:
assert self._nullable, 'None set for non-nullable field'
else:
Expand All @@ -58,81 +64,81 @@ class FieldType(abc.ABC):

@staticmethod
@abc.abstractmethod
def ddl():
def ddl() -> str:
raise NotImplementedError

@staticmethod
@abc.abstractmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
raise NotImplementedError

@staticmethod
@abc.abstractmethod
def validate_type(value):
def validate_type(value: Any) -> None:
raise NotImplementedError


class Boolean(FieldType):
"""Represents a boolean type."""

@staticmethod
def ddl():
def ddl() -> str:
return 'BOOL'

@staticmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
return type_pb2.Type(code=type_pb2.BOOL)

@staticmethod
def validate_type(value):
def validate_type(value: Any) -> None:
assert isinstance(value, bool), '{} is not of type bool'.format(value)


class Integer(FieldType):
"""Represents an integer type."""

@staticmethod
def ddl():
def ddl() -> str:
return 'INT64'

@staticmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
return type_pb2.Type(code=type_pb2.INT64)

@staticmethod
def validate_type(value):
def validate_type(value: Any) -> None:
assert isinstance(value, int), '{} is not of type int'.format(value)


class String(FieldType):
"""Represents a string type."""

@staticmethod
def ddl():
def ddl() -> str:
return 'STRING(MAX)'

@staticmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
return type_pb2.Type(code=type_pb2.STRING)

@staticmethod
def validate_type(value):
def validate_type(value) -> None:
assert isinstance(value, str), '{} is not of type str'.format(value)


class StringArray(FieldType):
"""Represents an array of strings type."""

@staticmethod
def ddl():
def ddl() -> str:
return 'ARRAY<STRING(MAX)>'

@staticmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
return type_pb2.Type(code=type_pb2.ARRAY)

@staticmethod
def validate_type(value):
def validate_type(value: Any) -> None:
assert isinstance(value, list), '{} is not of type list'.format(value)
for item in value:
assert isinstance(item, str), '{} is not of type str'.format(item)
Expand All @@ -142,15 +148,15 @@ class Timestamp(FieldType):
"""Represents a timestamp type."""

@staticmethod
def ddl():
def ddl() -> str:
return 'TIMESTAMP'

@staticmethod
def grpc_type():
def grpc_type() -> type_pb2.Type:
return type_pb2.Type(code=type_pb2.TIMESTAMP)

@staticmethod
def validate_type(value):
def validate_type(value: Any) -> None:
assert isinstance(value, datetime.datetime)


Expand Down
17 changes: 10 additions & 7 deletions spanner_orm/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helps to deal with indices on Models."""
"""Represents an index on a Model."""

from typing import List, Optional


class Index(object):
"""Represents an index on a Model."""
PRIMARY_INDEX = 'PRIMARY_KEY'

def __init__(self,
columns,
parent=None,
null_filtered=False,
unique=False,
storing_columns=None):
columns: List[str],
parent: Optional[str] = None,
null_filtered: bool = False,
unique: bool = False,
storing_columns: Optional[List[str]] = None):
assert columns, 'An index must have at least one column'
self.columns = columns
self.name = None
Expand All @@ -33,5 +36,5 @@ def __init__(self,
self.storing_columns = storing_columns or []

@property
def primary(self):
def primary(self) -> bool:
return self.name == self.PRIMARY_INDEX
20 changes: 12 additions & 8 deletions spanner_orm/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
# limitations under the License.
"""Helps define a foreign key relationship between two models."""

from __future__ import annotations

from typing import Dict, List, Type, Union

from spanner_orm import condition
from spanner_orm import error
from spanner_orm import model
Expand All @@ -23,10 +27,10 @@ class Relationship(object):
"""Helps define a foreign key relationship between two models."""

def __init__(self,
destination_handle,
constraints,
is_parent=False,
single=False):
destination_handle: Union[Type[model.Model], str],
constraints: Dict[str, str],
is_parent: bool = False,
single: bool = False):
"""Creates a ModelRelationship.

Args:
Expand All @@ -48,21 +52,21 @@ def __init__(self,
self.origin = None

@property
def conditions(self):
def conditions(self) -> List[condition.Condition]:
assert self.origin, 'Origin must be set before conditions is called'
return self._parse_constraints()

@property
def destination(self):
def destination(self) -> Type[model.Model]:
if not self._destination:
self._destination = model.load_model(self._destination_handle)
return self._destination

@property
def single(self):
def single(self) -> bool:
return self._single

def _parse_constraints(self):
def _parse_constraints(self) -> List[condition.Condition]:
"""Validates the dictionary of constraints and turns it into Conditions."""
conditions = []
for origin_column, destination_column in self._constraints.items():
Expand Down
10 changes: 5 additions & 5 deletions spanner_orm/tests/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ def test_includes(self):
self.assertEmpty(select_query.types())

def test_includes_with_object(self):
select_query = self.includes(models.ChildTestModel.parent)
select_query = self.includes(models.RelationshipTestModel.parent)

# The column order varies between test runs
expected_sql = (
r'SELECT ChildTestModel\S* ChildTestModel\S* ARRAY\(SELECT AS '
r'STRUCT SmallTestModel\S* SmallTestModel\S* SmallTestModel\S* FROM '
r'SmallTestModel WHERE SmallTestModel.key = '
r'ChildTestModel.parent_key\)')
r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* '
r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* '
r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = '
r'RelationshipTestModel.parent_key\)')
self.assertRegex(select_query.sql(), expected_sql)
self.assertEmpty(select_query.parameters())
self.assertEmpty(select_query.types())
Expand Down