Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
.eggs
*.egg-info
.pytype
env

# Files that may or may not be added to the repo while acquiring the Spanner
# emulator.
Expand Down
30 changes: 28 additions & 2 deletions spanner_orm/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,20 @@ def validate_type(self, value: Any) -> None:
class String(FieldType):
"""Represents a string type."""

def __init__(self, length: Optional[int] = None):
"""Initializer.

Args:
length: Length of the String. MAX is used if not specified.
"""
if length is not None and length <= 0:
raise error.ValidationError('String length must be positive')
self._length = length

def ddl(self) -> str:
"""See base class."""
del self # Unused.
if self._length is not None:
return f'STRING({self._length})'
return 'STRING(MAX)'

def grpc_type(self) -> spanner_v1.Type:
Expand Down Expand Up @@ -218,9 +229,20 @@ def validate_type(self, value: Any) -> None:
class BytesBase64(FieldType):
"""Represents a bytes type that must be base64 encoded."""

def __init__(self, length: Optional[int] = None):
"""Initializer.

Args:
length: Length of the Bytes. MAX is used if not specified.
"""
if length is not None and length <= 0:
raise error.ValidationError('Bytes length must be positive')
self._length = length

def ddl(self) -> str:
"""See base class."""
del self # Unused.
if self._length is not None:
return f'BYTES({self._length})'
return 'BYTES(MAX)'

def grpc_type(self) -> spanner_v1.Type:
Expand Down Expand Up @@ -298,10 +320,14 @@ def field_type_from_ddl(ddl: str) -> FieldType:
return Float()
elif ddl == 'STRING(MAX)':
return String()
elif (match := re.fullmatch(r'STRING\(([0-9]+)\)', ddl)) is not None:
return String(int(match.group(1)))
elif ddl == 'TIMESTAMP':
return Timestamp()
elif ddl == 'BYTES(MAX)':
return BytesBase64()
elif (match := re.fullmatch(r'BYTES\(([0-9]+)\)', ddl)) is not None:
return BytesBase64(int(match.group(1)))
elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None:
return Array(field_type_from_ddl(match.group(1)))
else:
Expand Down
24 changes: 22 additions & 2 deletions spanner_orm/tests/field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@ class FieldTest(parameterized.TestCase):
(field.Integer(), 'INT64'),
(field.Float(), 'FLOAT64'),
(field.String(), 'STRING(MAX)'),
(field.String(10), 'STRING(10)'),
(field.Timestamp(), 'TIMESTAMP'),
(field.BytesBase64(), 'BYTES(MAX)'),
(field.BytesBase64(10), 'BYTES(10)'),
(field.Array(field.Boolean()), 'ARRAY<BOOL>'),
(field.Array(field.String()), 'ARRAY<STRING(MAX)>'),
(field.Array(field.String(10)), 'ARRAY<STRING(10)>'),
(field.Array(field.BytesBase64()), 'ARRAY<BYTES(MAX)>'),
(field.Array(field.BytesBase64(10)), 'ARRAY<BYTES(10)>'),
)
def test_field_type_ddl(
self,
Expand All @@ -50,12 +55,16 @@ def test_field_type_ddl(
(field.Integer(), spanner.param_types.INT64),
(field.Float(), spanner.param_types.FLOAT64),
(field.String(), spanner.param_types.STRING),
(field.String(10), spanner.param_types.STRING),
(field.Timestamp(), spanner.param_types.TIMESTAMP),
(field.BytesBase64(), spanner.param_types.BYTES),
(field.BytesBase64(10), spanner.param_types.BYTES),
(field.Array(field.Boolean()),
spanner.param_types.Array(spanner.param_types.BOOL)),
(field.Array(field.String()),
spanner.param_types.Array(spanner.param_types.STRING)),
(field.Array(field.String(10)),
spanner.param_types.Array(spanner.param_types.STRING)),
)
def test_field_type_grpc_type(
self,
Expand All @@ -70,8 +79,10 @@ def test_field_type_grpc_type(
(field.Float(), 1),
(field.Float(), 1.0),
(field.String(), 'foo'),
(field.String(10), 'foo'),
(field.Timestamp(), datetime.datetime(2022, 9, 21)),
(field.BytesBase64(), base64.b64encode(b'\x00')),
(field.BytesBase64(10), base64.b64encode(b'\x00')),
(field.Array(field.Boolean()), [True]),
)
def test_field_type_validate_type_ok(
Expand All @@ -86,9 +97,11 @@ def test_field_type_validate_type_ok(
(field.Integer(), 1.0),
(field.Float(), '1.0'),
(field.String(), b'foo'),
(field.String(10), b'foo'),
(field.Timestamp(), datetime.date(2022, 9, 21)),
(field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')),
(field.BytesBase64(), b'!'),
(field.BytesBase64(10), b'!'),
(field.Array(field.Boolean()), {True}),
(field.Array(field.Boolean()), [1]),
)
Expand All @@ -103,6 +116,8 @@ def test_field_type_validate_type_error(
@parameterized.parameters(
(field.Boolean(), field.Boolean(), True),
(field.Boolean(), field.String(), False),
(field.String(10), field.String(20), True),
Comment thread
hazyd marked this conversation as resolved.
(field.String(), field.String(10), True),
(field.Array(field.Integer()), field.Array(field.Integer()), False),
(field.Array(field.Integer()), field.Integer(), False),
)
Expand Down Expand Up @@ -145,17 +160,22 @@ def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self):
'INT64',
'FLOAT64',
'STRING(MAX)',
'STRING(10)',
'TIMESTAMP',
'BYTES(MAX)',
'BYTES(10)',
'ARRAY<INT64>',
'ARRAY<STRING(MAX)>',
'ARRAY<STRING(10)>',
)
def test_ddl_to_field_type_to_ddl(self, ddl: str):
self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl)
Comment thread
hazyd marked this conversation as resolved.

def test_field_type_from_ddl_invalid(self):
@parameterized.parameters('UNICORN(MAX)', 'STRING(MAX1)', 'STRING(MIN)',
'ARRAY<STRING(MAX1)>', 'BYTES(MAX1)', 'BYTES(MIN)')
def test_field_type_from_ddl_invalid(self, ddl: str):
with self.assertRaisesRegex(error.SpannerError, 'DDL type'):
field.field_type_from_ddl('UNICORN(MAX)')
field.field_type_from_ddl(ddl)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Spanner ORM migration: create_custom_length_field.

Migration ID: 'f959b767457d'
Created: 2022-09-13 13:28:34-07:00
"""

import spanner_orm

migration_id = 'f959b767457d'
prev_migration_id = '69a8f072dacf'


class OriginalTeeTable(spanner_orm.model.Model):
"""ORM Model with the original schema for the Commands table.
Don't update this model, create new migrations instead.
"""

__table__ = 'Tee'
id = spanner_orm.Field(spanner_orm.String, primary_key=True)
custom_string_length = spanner_orm.Field(spanner_orm.String(20))
custom_array_string_length = spanner_orm.Field(
spanner_orm.Array(spanner_orm.String(4)))
custom_bytes_length = spanner_orm.Field(spanner_orm.BytesBase64(20))
custom_array_bytes_length = spanner_orm.Field(
spanner_orm.Array(spanner_orm.BytesBase64(4)))


def upgrade() -> spanner_orm.CreateTable:
Comment thread
dseomn marked this conversation as resolved.
"""Creates the original Commands table."""
return spanner_orm.CreateTable(OriginalTeeTable)


def downgrade() -> spanner_orm.DropTable:
"""Drops the original Commands table."""
return spanner_orm.DropTable(OriginalTeeTable.__table__)
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ class OriginalUnittestModelTable(spanner_orm.model.Model):
float_2 = field.Field(field.Float, nullable=True)
string = field.Field(field.String, primary_key=True)
string_2 = field.Field(field.String, nullable=True)
string_3 = field.Field(field.String(20), nullable=True)
bytes_ = field.Field(field.BytesBase64, primary_key=True)
bytes_2 = field.Field(field.BytesBase64, nullable=True)
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
timestamp = field.Field(field.Timestamp)
string_array = field.Field(field.StringArray, nullable=True)
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)


def upgrade() -> spanner_orm.CreateTable:
Expand Down
6 changes: 3 additions & 3 deletions spanner_orm/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def test_set_error_on_primary_key(self):
with self.assertRaises(AttributeError):
test_model.key = 'error'

@parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'),
('string_2', 5), ('bytes_2', 'string'),
('string_array', 'foo'), ('timestamp', 5))
@parameterized.parameters(
('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_3', 5),
('bytes_2', 'string'), ('string_array', 'foo'), ('timestamp', 5))
def test_set_error_on_invalid_type(self, attribute, value):
string_array = ['foo', 'bar']
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
Expand Down
6 changes: 6 additions & 0 deletions spanner_orm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,13 @@ class UnittestModel(model.Model):
float_2 = field.Field(field.Float, nullable=True)
string = field.Field(field.String, primary_key=True)
string_2 = field.Field(field.String, nullable=True)
string_3 = field.Field(field.String(20), nullable=True)
bytes_ = field.Field(field.BytesBase64, primary_key=True)
bytes_2 = field.Field(field.BytesBase64, nullable=True)
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
timestamp = field.Field(field.Timestamp)
string_array = field.Field(field.StringArray, nullable=True)
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)

test_index = index.Index(['string_2'])

Expand All @@ -125,10 +128,13 @@ class UnittestModelWithoutSecondaryIndexes(model.Model):
float_2 = field.Field(field.Float, nullable=True)
string = field.Field(field.String, primary_key=True)
string_2 = field.Field(field.String, nullable=True)
string_3 = field.Field(field.String(20), nullable=True)
bytes_ = field.Field(field.BytesBase64, primary_key=True)
bytes_2 = field.Field(field.BytesBase64, nullable=True)
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
timestamp = field.Field(field.Timestamp)
string_array = field.Field(field.StringArray, nullable=True)
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)


class NullFilteredIndexModel(model.Model):
Expand Down
9 changes: 6 additions & 3 deletions spanner_orm/tests/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,13 @@ def test_create_table(self, get_model):
test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,'
' float_ FLOAT64 NOT NULL, float_2 FLOAT64,'
' string STRING(MAX) NOT NULL, string_2 STRING(MAX),'
' string_3 STRING(20),'
' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),'
' timestamp TIMESTAMP NOT NULL, string_array'
' ARRAY<STRING(MAX)>) PRIMARY KEY '
'(int_, float_, string, bytes_)')
' bytes_3 BYTES(20),'
' timestamp TIMESTAMP NOT NULL,'
' string_array ARRAY<STRING(MAX)>,'
' string_array_2 ARRAY<STRING(20)>)'
' PRIMARY KEY (int_, float_, string, bytes_)')
self.assertEqual(test_update.ddl(), test_model_ddl)

@mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model')
Expand Down