Skip to content

Commit e2f103b

Browse files
authored
Merge pull request google#176 from hazyd/length
Support length for various fields
2 parents 09c5a4d + 0196241 commit e2f103b

8 files changed

Lines changed: 104 additions & 10 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__
66
.eggs
77
*.egg-info
88
.pytype
9+
env
910

1011
# Files that may or may not be added to the repo while acquiring the Spanner
1112
# emulator.

spanner_orm/field.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,20 @@ def validate_type(self, value: Any) -> None:
178178
class String(FieldType):
179179
"""Represents a string type."""
180180

181+
def __init__(self, length: Optional[int] = None):
182+
"""Initializer.
183+
184+
Args:
185+
length: Length of the String. MAX is used if not specified.
186+
"""
187+
if length is not None and length <= 0:
188+
raise error.ValidationError('String length must be positive')
189+
self._length = length
190+
181191
def ddl(self) -> str:
182192
"""See base class."""
183-
del self # Unused.
193+
if self._length is not None:
194+
return f'STRING({self._length})'
184195
return 'STRING(MAX)'
185196

186197
def grpc_type(self) -> spanner_v1.Type:
@@ -218,9 +229,20 @@ def validate_type(self, value: Any) -> None:
218229
class BytesBase64(FieldType):
219230
"""Represents a bytes type that must be base64 encoded."""
220231

232+
def __init__(self, length: Optional[int] = None):
233+
"""Initializer.
234+
235+
Args:
236+
length: Length of the Bytes. MAX is used if not specified.
237+
"""
238+
if length is not None and length <= 0:
239+
raise error.ValidationError('Bytes length must be positive')
240+
self._length = length
241+
221242
def ddl(self) -> str:
222243
"""See base class."""
223-
del self # Unused.
244+
if self._length is not None:
245+
return f'BYTES({self._length})'
224246
return 'BYTES(MAX)'
225247

226248
def grpc_type(self) -> spanner_v1.Type:
@@ -298,10 +320,14 @@ def field_type_from_ddl(ddl: str) -> FieldType:
298320
return Float()
299321
elif ddl == 'STRING(MAX)':
300322
return String()
323+
elif (match := re.fullmatch(r'STRING\(([0-9]+)\)', ddl)) is not None:
324+
return String(int(match.group(1)))
301325
elif ddl == 'TIMESTAMP':
302326
return Timestamp()
303327
elif ddl == 'BYTES(MAX)':
304328
return BytesBase64()
329+
elif (match := re.fullmatch(r'BYTES\(([0-9]+)\)', ddl)) is not None:
330+
return BytesBase64(int(match.group(1)))
305331
elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None:
306332
return Array(field_type_from_ddl(match.group(1)))
307333
else:

spanner_orm/tests/field_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,15 @@ class FieldTest(parameterized.TestCase):
3333
(field.Integer(), 'INT64'),
3434
(field.Float(), 'FLOAT64'),
3535
(field.String(), 'STRING(MAX)'),
36+
(field.String(10), 'STRING(10)'),
3637
(field.Timestamp(), 'TIMESTAMP'),
3738
(field.BytesBase64(), 'BYTES(MAX)'),
39+
(field.BytesBase64(10), 'BYTES(10)'),
3840
(field.Array(field.Boolean()), 'ARRAY<BOOL>'),
3941
(field.Array(field.String()), 'ARRAY<STRING(MAX)>'),
42+
(field.Array(field.String(10)), 'ARRAY<STRING(10)>'),
43+
(field.Array(field.BytesBase64()), 'ARRAY<BYTES(MAX)>'),
44+
(field.Array(field.BytesBase64(10)), 'ARRAY<BYTES(10)>'),
4045
)
4146
def test_field_type_ddl(
4247
self,
@@ -50,12 +55,16 @@ def test_field_type_ddl(
5055
(field.Integer(), spanner.param_types.INT64),
5156
(field.Float(), spanner.param_types.FLOAT64),
5257
(field.String(), spanner.param_types.STRING),
58+
(field.String(10), spanner.param_types.STRING),
5359
(field.Timestamp(), spanner.param_types.TIMESTAMP),
5460
(field.BytesBase64(), spanner.param_types.BYTES),
61+
(field.BytesBase64(10), spanner.param_types.BYTES),
5562
(field.Array(field.Boolean()),
5663
spanner.param_types.Array(spanner.param_types.BOOL)),
5764
(field.Array(field.String()),
5865
spanner.param_types.Array(spanner.param_types.STRING)),
66+
(field.Array(field.String(10)),
67+
spanner.param_types.Array(spanner.param_types.STRING)),
5968
)
6069
def test_field_type_grpc_type(
6170
self,
@@ -70,8 +79,10 @@ def test_field_type_grpc_type(
7079
(field.Float(), 1),
7180
(field.Float(), 1.0),
7281
(field.String(), 'foo'),
82+
(field.String(10), 'foo'),
7383
(field.Timestamp(), datetime.datetime(2022, 9, 21)),
7484
(field.BytesBase64(), base64.b64encode(b'\x00')),
85+
(field.BytesBase64(10), base64.b64encode(b'\x00')),
7586
(field.Array(field.Boolean()), [True]),
7687
)
7788
def test_field_type_validate_type_ok(
@@ -86,9 +97,11 @@ def test_field_type_validate_type_ok(
8697
(field.Integer(), 1.0),
8798
(field.Float(), '1.0'),
8899
(field.String(), b'foo'),
100+
(field.String(10), b'foo'),
89101
(field.Timestamp(), datetime.date(2022, 9, 21)),
90102
(field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')),
91103
(field.BytesBase64(), b'!'),
104+
(field.BytesBase64(10), b'!'),
92105
(field.Array(field.Boolean()), {True}),
93106
(field.Array(field.Boolean()), [1]),
94107
)
@@ -103,6 +116,8 @@ def test_field_type_validate_type_error(
103116
@parameterized.parameters(
104117
(field.Boolean(), field.Boolean(), True),
105118
(field.Boolean(), field.String(), False),
119+
(field.String(10), field.String(20), True),
120+
(field.String(), field.String(10), True),
106121
(field.Array(field.Integer()), field.Array(field.Integer()), False),
107122
(field.Array(field.Integer()), field.Integer(), False),
108123
)
@@ -145,17 +160,22 @@ def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self):
145160
'INT64',
146161
'FLOAT64',
147162
'STRING(MAX)',
163+
'STRING(10)',
148164
'TIMESTAMP',
149165
'BYTES(MAX)',
166+
'BYTES(10)',
150167
'ARRAY<INT64>',
151168
'ARRAY<STRING(MAX)>',
169+
'ARRAY<STRING(10)>',
152170
)
153171
def test_ddl_to_field_type_to_ddl(self, ddl: str):
154172
self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl)
155173

156-
def test_field_type_from_ddl_invalid(self):
174+
@parameterized.parameters('UNICORN(MAX)', 'STRING(MAX1)', 'STRING(MIN)',
175+
'ARRAY<STRING(MAX1)>', 'BYTES(MAX1)', 'BYTES(MIN)')
176+
def test_field_type_from_ddl_invalid(self, ddl: str):
157177
with self.assertRaisesRegex(error.SpannerError, 'DDL type'):
158-
field.field_type_from_ddl('UNICORN(MAX)')
178+
field.field_type_from_ddl(ddl)
159179

160180

161181
if __name__ == '__main__':
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Spanner ORM migration: create_custom_length_field.
2+
3+
Migration ID: 'f959b767457d'
4+
Created: 2022-09-13 13:28:34-07:00
5+
"""
6+
7+
import spanner_orm
8+
9+
migration_id = 'f959b767457d'
10+
prev_migration_id = '69a8f072dacf'
11+
12+
13+
class OriginalTeeTable(spanner_orm.model.Model):
14+
"""ORM Model with the original schema for the Commands table.
15+
Don't update this model, create new migrations instead.
16+
"""
17+
18+
__table__ = 'Tee'
19+
id = spanner_orm.Field(spanner_orm.String, primary_key=True)
20+
custom_string_length = spanner_orm.Field(spanner_orm.String(20))
21+
custom_array_string_length = spanner_orm.Field(
22+
spanner_orm.Array(spanner_orm.String(4)))
23+
custom_bytes_length = spanner_orm.Field(spanner_orm.BytesBase64(20))
24+
custom_array_bytes_length = spanner_orm.Field(
25+
spanner_orm.Array(spanner_orm.BytesBase64(4)))
26+
27+
28+
def upgrade() -> spanner_orm.CreateTable:
29+
"""Creates the original Commands table."""
30+
return spanner_orm.CreateTable(OriginalTeeTable)
31+
32+
33+
def downgrade() -> spanner_orm.DropTable:
34+
"""Drops the original Commands table."""
35+
return spanner_orm.DropTable(OriginalTeeTable.__table__)

spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ class OriginalUnittestModelTable(spanner_orm.model.Model):
3434
float_2 = field.Field(field.Float, nullable=True)
3535
string = field.Field(field.String, primary_key=True)
3636
string_2 = field.Field(field.String, nullable=True)
37+
string_3 = field.Field(field.String(20), nullable=True)
3738
bytes_ = field.Field(field.BytesBase64, primary_key=True)
3839
bytes_2 = field.Field(field.BytesBase64, nullable=True)
40+
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
3941
timestamp = field.Field(field.Timestamp)
4042
string_array = field.Field(field.StringArray, nullable=True)
43+
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)
4144

4245

4346
def upgrade() -> spanner_orm.CreateTable:

spanner_orm/tests/model_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ def test_set_error_on_primary_key(self):
226226
with self.assertRaises(AttributeError):
227227
test_model.key = 'error'
228228

229-
@parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'),
230-
('string_2', 5), ('bytes_2', 'string'),
231-
('string_array', 'foo'), ('timestamp', 5))
229+
@parameterized.parameters(
230+
('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_3', 5),
231+
('bytes_2', 'string'), ('string_array', 'foo'), ('timestamp', 5))
232232
def test_set_error_on_invalid_type(self, attribute, value):
233233
string_array = ['foo', 'bar']
234234
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)

spanner_orm/tests/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,13 @@ class UnittestModel(model.Model):
107107
float_2 = field.Field(field.Float, nullable=True)
108108
string = field.Field(field.String, primary_key=True)
109109
string_2 = field.Field(field.String, nullable=True)
110+
string_3 = field.Field(field.String(20), nullable=True)
110111
bytes_ = field.Field(field.BytesBase64, primary_key=True)
111112
bytes_2 = field.Field(field.BytesBase64, nullable=True)
113+
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
112114
timestamp = field.Field(field.Timestamp)
113115
string_array = field.Field(field.StringArray, nullable=True)
116+
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)
114117

115118
test_index = index.Index(['string_2'])
116119

@@ -125,10 +128,13 @@ class UnittestModelWithoutSecondaryIndexes(model.Model):
125128
float_2 = field.Field(field.Float, nullable=True)
126129
string = field.Field(field.String, primary_key=True)
127130
string_2 = field.Field(field.String, nullable=True)
131+
string_3 = field.Field(field.String(20), nullable=True)
128132
bytes_ = field.Field(field.BytesBase64, primary_key=True)
129133
bytes_2 = field.Field(field.BytesBase64, nullable=True)
134+
bytes_3 = field.Field(field.BytesBase64(20), nullable=True)
130135
timestamp = field.Field(field.Timestamp)
131136
string_array = field.Field(field.StringArray, nullable=True)
137+
string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)
132138

133139

134140
class NullFilteredIndexModel(model.Model):

spanner_orm/tests/update_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,13 @@ def test_create_table(self, get_model):
8686
test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,'
8787
' float_ FLOAT64 NOT NULL, float_2 FLOAT64,'
8888
' string STRING(MAX) NOT NULL, string_2 STRING(MAX),'
89+
' string_3 STRING(20),'
8990
' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),'
90-
' timestamp TIMESTAMP NOT NULL, string_array'
91-
' ARRAY<STRING(MAX)>) PRIMARY KEY '
92-
'(int_, float_, string, bytes_)')
91+
' bytes_3 BYTES(20),'
92+
' timestamp TIMESTAMP NOT NULL,'
93+
' string_array ARRAY<STRING(MAX)>,'
94+
' string_array_2 ARRAY<STRING(20)>)'
95+
' PRIMARY KEY (int_, float_, string, bytes_)')
9396
self.assertEqual(test_update.ddl(), test_model_ddl)
9497

9598
@mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model')

0 commit comments

Comments
 (0)