Skip to content

Commit a189bfb

Browse files
authored
Merge pull request google#171 from google/spanner-2
Upgrade spanner-orm to work with google-cloud-spanner v2
2 parents adf63ef + eb8e89b commit a189bfb

8 files changed

Lines changed: 95 additions & 94 deletions

File tree

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
pip install \
4444
absl-py \
4545
google-api-core \
46-
'google-cloud-spanner >= 1.6, <2.0.0dev' \
46+
'google-cloud-spanner >= 2, <4' \
4747
immutabledict \
4848
portpicker \
4949
pytest

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
setup(
1919
name='spanner-orm',
20-
version='0.1.10',
20+
version='0.2.0',
2121
description='Basic ORM for Spanner',
2222
maintainer='Python Spanner ORM developers',
2323
maintainer_email='python-spanner-orm@google.com',
@@ -26,7 +26,7 @@
2626
include_package_data=True,
2727
python_requires='~=3.7',
2828
install_requires=[
29-
'google-cloud-spanner >= 1.6, <4',
29+
'google-cloud-spanner >= 2, <4',
3030
'immutabledict',
3131
],
3232
tests_require=['absl-py', 'google-api-core', 'portpicker'],

spanner_orm/condition.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from spanner_orm import relationship
3131

3232
from google.api_core import datetime_helpers
33-
from google.cloud.spanner_v1.proto import type_pb2
33+
from google.cloud import spanner
34+
from google.cloud import spanner_v1
3435
import immutabledict
3536

3637
T = TypeVar('T')
@@ -105,7 +106,7 @@ def sql(self) -> str:
105106
def _sql(self) -> str:
106107
pass
107108

108-
def types(self) -> Dict[str, type_pb2.Type]:
109+
def types(self) -> Dict[str, spanner_v1.Type]:
109110
"""Returns parameter types to be used in the SQL query.
110111
111112
Returns:
@@ -117,7 +118,7 @@ def types(self) -> Dict[str, type_pb2.Type]:
117118
return self._types()
118119

119120
@abc.abstractmethod
120-
def _types(self) -> Dict[str, type_pb2.Type]:
121+
def _types(self) -> Dict[str, spanner_v1.Type]:
121122
raise NotImplementedError
122123

123124
@abc.abstractmethod
@@ -158,7 +159,8 @@ def _validate(self, model_class: Type[Any]) -> None:
158159
]
159160

160161

161-
def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
162+
def _spanner_type_of_python_object(
163+
value: GuessableParamType) -> spanner_v1.Type:
162164
"""Returns the Cloud Spanner type of the given object.
163165
164166
Args:
@@ -173,31 +175,27 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
173175
raise TypeError(
174176
'Cannot infer type of None, because any SQL type can be NULL.')
175177
simple_type_code = {
176-
bool: type_pb2.BOOL,
177-
int: type_pb2.INT64,
178-
float: type_pb2.FLOAT64,
179-
datetime_helpers.DatetimeWithNanoseconds: type_pb2.TIMESTAMP,
180-
datetime.datetime: type_pb2.TIMESTAMP,
181-
datetime.date: type_pb2.DATE,
182-
bytes: type_pb2.BYTES,
183-
str: type_pb2.STRING,
184-
decimal.Decimal: type_pb2.NUMERIC,
178+
bool: spanner_v1.TypeCode.BOOL,
179+
int: spanner_v1.TypeCode.INT64,
180+
float: spanner_v1.TypeCode.FLOAT64,
181+
datetime_helpers.DatetimeWithNanoseconds: spanner_v1.TypeCode.TIMESTAMP,
182+
datetime.datetime: spanner_v1.TypeCode.TIMESTAMP,
183+
datetime.date: spanner_v1.TypeCode.DATE,
184+
bytes: spanner_v1.TypeCode.BYTES,
185+
str: spanner_v1.TypeCode.STRING,
186+
decimal.Decimal: spanner_v1.TypeCode.NUMERIC,
185187
}.get(type(value))
186188
if simple_type_code is not None:
187-
return type_pb2.Type(code=simple_type_code)
189+
return spanner_v1.Type(code=simple_type_code)
188190
elif isinstance(value, (list, tuple)):
189191
element_types = tuple(
190192
_spanner_type_of_python_object(item)
191193
for item in value
192194
if item is not None)
193-
unique_element_type_count = len({
194-
# Protos aren't hashable, so use their serializations.
195-
element_type.SerializeToString(deterministic=True)
196-
for element_type in element_types
197-
})
198-
if unique_element_type_count == 1:
199-
return type_pb2.Type(
200-
code=type_pb2.ARRAY,
195+
if element_types and all(
196+
a == b for a, b in zip(element_types, element_types[1:])):
197+
return spanner_v1.Type(
198+
code=spanner_v1.TypeCode.ARRAY,
201199
array_element_type=element_types[0],
202200
)
203201
else:
@@ -211,7 +209,7 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
211209
class Param:
212210
"""Parameter for substitution into a SQL query."""
213211
value: Any
214-
type: type_pb2.Type
212+
type: spanner_v1.Type
215213

216214
@classmethod
217215
def from_value(cls: Type[T], value: GuessableParamType) -> T:
@@ -220,14 +218,13 @@ def from_value(cls: Type[T], value: GuessableParamType) -> T:
220218

221219
# BYTES must be base64-encoded, see
222220
# https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110
223-
if (isinstance(value, bytes) and
224-
guessed_type == type_pb2.Type(code=type_pb2.BYTES)):
221+
if (isinstance(value, bytes) and guessed_type == spanner.param_types.BYTES):
225222
encoded_value = base64.b64encode(value).decode()
226223
elif (isinstance(value, (list, tuple)) and
227224
all(isinstance(x, bytes) for x in value if x is not None) and
228-
guessed_type == type_pb2.Type(
229-
code=type_pb2.ARRAY,
230-
array_element_type=type_pb2.Type(code=type_pb2.BYTES),
225+
guessed_type == spanner_v1.Type(
226+
code=spanner_v1.TypeCode.ARRAY,
227+
array_element_type=spanner.param_types.BYTES,
231228
)):
232229
encoded_value = tuple(
233230
None if item is None else base64.b64encode(item).decode()
@@ -299,7 +296,7 @@ def _params(self) -> Dict[str, Any]:
299296
if isinstance(v, Param)
300297
}
301298

302-
def _types(self) -> Dict[str, type_pb2.Type]:
299+
def _types(self) -> Dict[str, spanner_v1.Type]:
303300
"""See base class."""
304301
return {
305302
self.key(k): v.type
@@ -345,7 +342,7 @@ def _sql(self) -> str:
345342
other_table=self.destination_model_class.table,
346343
other_column=self.destination_column)
347344

348-
def _types(self) -> Dict[str, type_pb2.Type]:
345+
def _types(self) -> Dict[str, spanner_v1.Type]:
349346
return {}
350347

351348
def _validate(self, model_class: Type[Any]) -> None:
@@ -411,7 +408,7 @@ def _sql(self) -> str:
411408
hints = (f'FORCE_INDEX={self.name}', *self._extra_hints)
412409
return f'@{{{",".join(hints)}}}'
413410

414-
def _types(self) -> Dict[str, type_pb2.Type]:
411+
def _types(self) -> Dict[str, spanner_v1.Type]:
415412
return {}
416413

417414
def _validate(self, model_class: Type[Any]) -> None:
@@ -433,7 +430,7 @@ def _sql(self) -> str:
433430
return '({})'.format(' AND '.join(
434431
f'{column} IS NOT NULL' for column in self.index.columns))
435432

436-
def _types(self) -> Dict[str, type_pb2.Type]:
433+
def _types(self) -> Dict[str, spanner_v1.Type]:
437434
return {}
438435

439436

@@ -542,7 +539,7 @@ def segment(self) -> Segment:
542539
def _sql(self) -> str:
543540
return ''
544541

545-
def _types(self) -> Dict[str, type_pb2.Type]:
542+
def _types(self) -> Dict[str, spanner_v1.Type]:
546543
return {}
547544

548545
def _validate(self, model_class: Type[Any]) -> None:
@@ -601,10 +598,10 @@ def _sql(self) -> str:
601598
limit_key=self._limit_key, offset_key=self._offset_key)
602599
return 'LIMIT @{limit_key}'.format(limit_key=self._limit_key)
603600

604-
def _types(self) -> Dict[str, type_pb2.Type]:
605-
types = {self._limit_key: type_pb2.Type(code=type_pb2.INT64)}
601+
def _types(self) -> Dict[str, spanner_v1.Type]:
602+
types = {self._limit_key: spanner.param_types.INT64}
606603
if self.offset:
607-
types[self._offset_key] = type_pb2.Type(code=type_pb2.INT64)
604+
types[self._offset_key] = spanner.param_types.INT64
608605
return types
609606

610607
def _validate(self, model_class: Type[Any]) -> None:
@@ -657,7 +654,7 @@ def _sql(self) -> str:
657654
def segment(self) -> Segment:
658655
return Segment.WHERE
659656

660-
def _types(self) -> type_pb2.Type:
657+
def _types(self) -> spanner_v1.Type:
661658
result = {}
662659
for condition in self.all_conditions:
663660
condition.suffix = str(int(self.suffix or 0) + len(result))
@@ -702,7 +699,7 @@ def _sql(self) -> str:
702699
def segment(self) -> Segment:
703700
return Segment.ORDER_BY
704701

705-
def _types(self) -> type_pb2.Type:
702+
def _types(self) -> spanner_v1.Type:
706703
return {}
707704

708705
def _validate(self, model_class: Type[Any]) -> None:
@@ -747,7 +744,7 @@ def _sql(self) -> str:
747744
operator=self.operator,
748745
column_key=self._column_key)
749746

750-
def _types(self) -> type_pb2.Type:
747+
def _types(self) -> spanner_v1.Type:
751748
return {self._column_key: self.model_class.fields[self.column].grpc_type()}
752749

753750
def _validate(self, model_class: Type[Any]) -> None:
@@ -773,9 +770,10 @@ def _sql(self) -> str:
773770
operator=self.operator,
774771
column_key=self._column_key)
775772

776-
def _types(self) -> type_pb2.Type:
773+
def _types(self) -> spanner_v1.Type:
777774
grpc_type = self.model_class.fields[self.column].grpc_type()
778-
list_type = type_pb2.Type(code=type_pb2.ARRAY, array_element_type=grpc_type)
775+
list_type = spanner_v1.Type(
776+
code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type)
779777
return {self._column_key: list_type}
780778

781779
def _validate(self, model_class: Type[Any]) -> None:
@@ -815,7 +813,7 @@ def _sql(self) -> str:
815813
operator=self.nullable_operator)
816814
return super()._sql()
817815

818-
def _types(self) -> type_pb2.Type:
816+
def _types(self) -> spanner_v1.Type:
819817
if self.is_null():
820818
return {}
821819
return super()._types()

spanner_orm/field.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import datetime
2121
from typing import Any, Type
2222

23-
from google.cloud.spanner_v1.proto import type_pb2
23+
from google.cloud import spanner
24+
from google.cloud import spanner_v1
2425
from spanner_orm import error
2526

2627

@@ -34,7 +35,7 @@ def ddl() -> str:
3435

3536
@staticmethod
3637
@abc.abstractmethod
37-
def grpc_type() -> type_pb2.Type:
38+
def grpc_type() -> spanner_v1.Type:
3839
raise NotImplementedError
3940

4041
@staticmethod
@@ -88,8 +89,8 @@ def ddl() -> str:
8889
return 'BOOL'
8990

9091
@staticmethod
91-
def grpc_type() -> type_pb2.Type:
92-
return type_pb2.Type(code=type_pb2.BOOL)
92+
def grpc_type() -> spanner_v1.Type:
93+
return spanner.param_types.BOOL
9394

9495
@staticmethod
9596
def validate_type(value: Any) -> None:
@@ -105,8 +106,8 @@ def ddl() -> str:
105106
return 'INT64'
106107

107108
@staticmethod
108-
def grpc_type() -> type_pb2.Type:
109-
return type_pb2.Type(code=type_pb2.INT64)
109+
def grpc_type() -> spanner_v1.Type:
110+
return spanner.param_types.INT64
110111

111112
@staticmethod
112113
def validate_type(value: Any) -> None:
@@ -122,8 +123,8 @@ def ddl() -> str:
122123
return 'FLOAT64'
123124

124125
@staticmethod
125-
def grpc_type() -> type_pb2.Type:
126-
return type_pb2.Type(code=type_pb2.FLOAT64)
126+
def grpc_type() -> spanner_v1.Type:
127+
return spanner.param_types.FLOAT64
127128

128129
@staticmethod
129130
def validate_type(value: Any) -> None:
@@ -139,8 +140,8 @@ def ddl() -> str:
139140
return 'STRING(MAX)'
140141

141142
@staticmethod
142-
def grpc_type() -> type_pb2.Type:
143-
return type_pb2.Type(code=type_pb2.STRING)
143+
def grpc_type() -> spanner_v1.Type:
144+
return spanner.param_types.STRING
144145

145146
@staticmethod
146147
def validate_type(value) -> None:
@@ -156,8 +157,8 @@ def ddl() -> str:
156157
return 'ARRAY<STRING(MAX)>'
157158

158159
@staticmethod
159-
def grpc_type() -> type_pb2.Type:
160-
return type_pb2.Type(code=type_pb2.ARRAY)
160+
def grpc_type() -> spanner_v1.Type:
161+
return spanner.param_types.Array(spanner.param_types.STRING)
161162

162163
@staticmethod
163164
def validate_type(value: Any) -> None:
@@ -176,8 +177,8 @@ def ddl() -> str:
176177
return 'TIMESTAMP'
177178

178179
@staticmethod
179-
def grpc_type() -> type_pb2.Type:
180-
return type_pb2.Type(code=type_pb2.TIMESTAMP)
180+
def grpc_type() -> spanner_v1.Type:
181+
return spanner.param_types.TIMESTAMP
181182

182183
@staticmethod
183184
def validate_type(value: Any) -> None:
@@ -193,8 +194,8 @@ def ddl() -> str:
193194
return 'BYTES(MAX)'
194195

195196
@staticmethod
196-
def grpc_type() -> type_pb2.Type:
197-
return type_pb2.Type(code=type_pb2.BYTES)
197+
def grpc_type() -> spanner_v1.Type:
198+
return spanner.param_types.BYTES
198199

199200
@staticmethod
200201
def validate_type(value) -> None:

spanner_orm/table_apis.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from typing import Any, Dict, Iterable, List, Sequence
1919

2020
from google.cloud import spanner
21+
from google.cloud import spanner_v1
2122
from google.cloud.spanner_v1 import transaction as spanner_transaction
22-
from google.cloud.spanner_v1.proto import type_pb2
2323

2424
_logger = logging.getLogger(__name__)
2525

@@ -49,9 +49,10 @@ def find(transaction: spanner_transaction.Transaction, table_name: str,
4949
return list(stream_results)
5050

5151

52-
def sql_query(transaction: spanner_transaction.Transaction, query: str,
53-
parameters: Dict[str, Any],
54-
parameter_types: Dict[str, type_pb2.Type]) -> List[Sequence[Any]]:
52+
def sql_query(
53+
transaction: spanner_transaction.Transaction, query: str,
54+
parameters: Dict[str, Any],
55+
parameter_types: Dict[str, spanner_v1.Type]) -> List[Sequence[Any]]:
5556
"""Executes a given SQL query against the Spanner database.
5657
5758
This isn't technically read-only, but it's necessary to implement the read-

spanner_orm/testlib/spanner_emulator/testlib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def _get_instance(spanner_client: client.Client) -> instance.Instance:
5555
Args:
5656
spanner_client: An initialized spanner client.
5757
"""
58-
existing_instances = list(spanner_client.list_instances())
59-
if existing_instances:
60-
return existing_instances[0]
58+
existing_instances_pb = list(spanner_client.list_instances())
59+
if existing_instances_pb:
60+
return instance.Instance.from_pb(existing_instances_pb[0], spanner_client)
6161

6262
# The emulator has one default config.
6363
config = list(spanner_client.list_instance_configs())[0]

0 commit comments

Comments
 (0)