Skip to content

Commit 71ec2db

Browse files
committed
Upgrade spanner-orm to work with google-cloud-spanner v2
1 parent ff4ae10 commit 71ec2db

7 files changed

Lines changed: 93 additions & 89 deletions

File tree

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
setup(
1919
name='spanner-orm',
20-
version='0.1.10',
20+
version='0.2.0',
2121
description='Basic ORM for Spanner',
2222
maintainer='Python Spanner ORM developers',
2323
maintainer_email='python-spanner-orm@google.com',
@@ -26,7 +26,7 @@
2626
include_package_data=True,
2727
python_requires='~=3.7',
2828
install_requires=[
29-
'google-cloud-spanner >= 1.6, <4',
29+
'google-cloud-spanner >= 2, <4',
3030
'immutabledict',
3131
],
3232
tests_require=['absl-py', 'google-api-core', 'portpicker'],

spanner_orm/condition.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from spanner_orm import relationship
3131

3232
from google.api_core import datetime_helpers
33-
from google.cloud.spanner_v1.proto import type_pb2
33+
from google.cloud import spanner
34+
from google.cloud import spanner_v1
3435
import immutabledict
3536

3637
T = TypeVar('T')
@@ -105,7 +106,7 @@ def sql(self) -> str:
105106
def _sql(self) -> str:
106107
pass
107108

108-
def types(self) -> Dict[str, type_pb2.Type]:
109+
def types(self) -> Dict[str, spanner_v1.Type]:
109110
"""Returns parameter types to be used in the SQL query.
110111
111112
Returns:
@@ -117,7 +118,7 @@ def types(self) -> Dict[str, type_pb2.Type]:
117118
return self._types()
118119

119120
@abc.abstractmethod
120-
def _types(self) -> Dict[str, type_pb2.Type]:
121+
def _types(self) -> Dict[str, spanner_v1.Type]:
121122
raise NotImplementedError
122123

123124
@abc.abstractmethod
@@ -158,7 +159,8 @@ def _validate(self, model_class: Type[Any]) -> None:
158159
]
159160

160161

161-
def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
162+
def _spanner_type_of_python_object(
163+
value: GuessableParamType) -> spanner_v1.Type:
162164
"""Returns the Cloud Spanner type of the given object.
163165
164166
Args:
@@ -173,31 +175,30 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
173175
raise TypeError(
174176
'Cannot infer type of None, because any SQL type can be NULL.')
175177
simple_type_code = {
176-
bool: type_pb2.BOOL,
177-
int: type_pb2.INT64,
178-
float: type_pb2.FLOAT64,
179-
datetime_helpers.DatetimeWithNanoseconds: type_pb2.TIMESTAMP,
180-
datetime.datetime: type_pb2.TIMESTAMP,
181-
datetime.date: type_pb2.DATE,
182-
bytes: type_pb2.BYTES,
183-
str: type_pb2.STRING,
184-
decimal.Decimal: type_pb2.NUMERIC,
178+
bool: spanner_v1.TypeCode.BOOL,
179+
int: spanner_v1.TypeCode.INT64,
180+
float: spanner_v1.TypeCode.FLOAT64,
181+
datetime_helpers.DatetimeWithNanoseconds: spanner_v1.TypeCode.TIMESTAMP,
182+
datetime.datetime: spanner_v1.TypeCode.TIMESTAMP,
183+
datetime.date: spanner_v1.TypeCode.DATE,
184+
bytes: spanner_v1.TypeCode.BYTES,
185+
str: spanner_v1.TypeCode.STRING,
186+
decimal.Decimal: spanner_v1.TypeCode.NUMERIC,
185187
}.get(type(value))
186188
if simple_type_code is not None:
187-
return type_pb2.Type(code=simple_type_code)
189+
return spanner_v1.Type(code=simple_type_code)
188190
elif isinstance(value, (list, tuple)):
189191
element_types = tuple(
190192
_spanner_type_of_python_object(item)
191193
for item in value
192194
if item is not None)
193195
unique_element_type_count = len({
194-
# Protos aren't hashable, so use their serializations.
195-
element_type.SerializeToString(deterministic=True)
196-
for element_type in element_types
196+
# Protos aren't hashable, so serialize them.
197+
str(element_type) for element_type in element_types
197198
})
198199
if unique_element_type_count == 1:
199-
return type_pb2.Type(
200-
code=type_pb2.ARRAY,
200+
return spanner_v1.Type(
201+
code=spanner_v1.TypeCode.ARRAY,
201202
array_element_type=element_types[0],
202203
)
203204
else:
@@ -211,7 +212,7 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
211212
class Param:
212213
"""Parameter for substitution into a SQL query."""
213214
value: Any
214-
type: type_pb2.Type
215+
type: spanner_v1.Type
215216

216217
@classmethod
217218
def from_value(cls: Type[T], value: GuessableParamType) -> T:
@@ -220,14 +221,13 @@ def from_value(cls: Type[T], value: GuessableParamType) -> T:
220221

221222
# BYTES must be base64-encoded, see
222223
# https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110
223-
if (isinstance(value, bytes) and
224-
guessed_type == type_pb2.Type(code=type_pb2.BYTES)):
224+
if (isinstance(value, bytes) and guessed_type == spanner.param_types.BYTES):
225225
encoded_value = base64.b64encode(value).decode()
226226
elif (isinstance(value, (list, tuple)) and
227227
all(isinstance(x, bytes) for x in value if x is not None) and
228-
guessed_type == type_pb2.Type(
229-
code=type_pb2.ARRAY,
230-
array_element_type=type_pb2.Type(code=type_pb2.BYTES),
228+
guessed_type == spanner_v1.Type(
229+
code=spanner_v1.TypeCode.ARRAY,
230+
array_element_type=spanner.param_types.BYTES,
231231
)):
232232
encoded_value = tuple(
233233
None if item is None else base64.b64encode(item).decode()
@@ -299,7 +299,7 @@ def _params(self) -> Dict[str, Any]:
299299
if isinstance(v, Param)
300300
}
301301

302-
def _types(self) -> Dict[str, type_pb2.Type]:
302+
def _types(self) -> Dict[str, spanner_v1.Type]:
303303
"""See base class."""
304304
return {
305305
self.key(k): v.type
@@ -345,7 +345,7 @@ def _sql(self) -> str:
345345
other_table=self.destination_model_class.table,
346346
other_column=self.destination_column)
347347

348-
def _types(self) -> Dict[str, type_pb2.Type]:
348+
def _types(self) -> Dict[str, spanner_v1.Type]:
349349
return {}
350350

351351
def _validate(self, model_class: Type[Any]) -> None:
@@ -389,7 +389,7 @@ def segment(self) -> Segment:
389389
def _sql(self) -> str:
390390
return '@{{FORCE_INDEX={}}}'.format(self.name)
391391

392-
def _types(self) -> Dict[str, type_pb2.Type]:
392+
def _types(self) -> Dict[str, spanner_v1.Type]:
393393
return {}
394394

395395
def _validate(self, model_class: Type[Any]) -> None:
@@ -509,7 +509,7 @@ def segment(self) -> Segment:
509509
def _sql(self) -> str:
510510
return ''
511511

512-
def _types(self) -> Dict[str, type_pb2.Type]:
512+
def _types(self) -> Dict[str, spanner_v1.Type]:
513513
return {}
514514

515515
def _validate(self, model_class: Type[Any]) -> None:
@@ -568,10 +568,10 @@ def _sql(self) -> str:
568568
limit_key=self._limit_key, offset_key=self._offset_key)
569569
return 'LIMIT @{limit_key}'.format(limit_key=self._limit_key)
570570

571-
def _types(self) -> Dict[str, type_pb2.Type]:
572-
types = {self._limit_key: type_pb2.Type(code=type_pb2.INT64)}
571+
def _types(self) -> Dict[str, spanner_v1.Type]:
572+
types = {self._limit_key: spanner.param_types.INT64}
573573
if self.offset:
574-
types[self._offset_key] = type_pb2.Type(code=type_pb2.INT64)
574+
types[self._offset_key] = spanner.param_types.INT64
575575
return types
576576

577577
def _validate(self, model_class: Type[Any]) -> None:
@@ -624,7 +624,7 @@ def _sql(self) -> str:
624624
def segment(self) -> Segment:
625625
return Segment.WHERE
626626

627-
def _types(self) -> type_pb2.Type:
627+
def _types(self) -> spanner_v1.Type:
628628
result = {}
629629
for condition in self.all_conditions:
630630
condition.suffix = str(int(self.suffix or 0) + len(result))
@@ -669,7 +669,7 @@ def _sql(self) -> str:
669669
def segment(self) -> Segment:
670670
return Segment.ORDER_BY
671671

672-
def _types(self) -> type_pb2.Type:
672+
def _types(self) -> spanner_v1.Type:
673673
return {}
674674

675675
def _validate(self, model_class: Type[Any]) -> None:
@@ -714,7 +714,7 @@ def _sql(self) -> str:
714714
operator=self.operator,
715715
column_key=self._column_key)
716716

717-
def _types(self) -> type_pb2.Type:
717+
def _types(self) -> spanner_v1.Type:
718718
return {self._column_key: self.model_class.fields[self.column].grpc_type()}
719719

720720
def _validate(self, model_class: Type[Any]) -> None:
@@ -740,9 +740,10 @@ def _sql(self) -> str:
740740
operator=self.operator,
741741
column_key=self._column_key)
742742

743-
def _types(self) -> type_pb2.Type:
743+
def _types(self) -> spanner_v1.Type:
744744
grpc_type = self.model_class.fields[self.column].grpc_type()
745-
list_type = type_pb2.Type(code=type_pb2.ARRAY, array_element_type=grpc_type)
745+
list_type = spanner_v1.Type(
746+
code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type)
746747
return {self._column_key: list_type}
747748

748749
def _validate(self, model_class: Type[Any]) -> None:
@@ -782,7 +783,7 @@ def _sql(self) -> str:
782783
operator=self.nullable_operator)
783784
return super()._sql()
784785

785-
def _types(self) -> type_pb2.Type:
786+
def _types(self) -> spanner_v1.Type:
786787
if self.is_null():
787788
return {}
788789
return super()._types()

spanner_orm/field.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import datetime
2121
from typing import Any, Type
2222

23-
from google.cloud.spanner_v1.proto import type_pb2
23+
from google.cloud import spanner
24+
from google.cloud import spanner_v1
2425
from spanner_orm import error
2526

2627

@@ -34,7 +35,7 @@ def ddl() -> str:
3435

3536
@staticmethod
3637
@abc.abstractmethod
37-
def grpc_type() -> type_pb2.Type:
38+
def grpc_type() -> spanner_v1.Type:
3839
raise NotImplementedError
3940

4041
@staticmethod
@@ -88,8 +89,8 @@ def ddl() -> str:
8889
return 'BOOL'
8990

9091
@staticmethod
91-
def grpc_type() -> type_pb2.Type:
92-
return type_pb2.Type(code=type_pb2.BOOL)
92+
def grpc_type() -> spanner_v1.Type:
93+
return spanner.param_types.BOOL
9394

9495
@staticmethod
9596
def validate_type(value: Any) -> None:
@@ -105,8 +106,8 @@ def ddl() -> str:
105106
return 'INT64'
106107

107108
@staticmethod
108-
def grpc_type() -> type_pb2.Type:
109-
return type_pb2.Type(code=type_pb2.INT64)
109+
def grpc_type() -> spanner_v1.Type:
110+
return spanner.param_types.INT64
110111

111112
@staticmethod
112113
def validate_type(value: Any) -> None:
@@ -122,8 +123,8 @@ def ddl() -> str:
122123
return 'FLOAT64'
123124

124125
@staticmethod
125-
def grpc_type() -> type_pb2.Type:
126-
return type_pb2.Type(code=type_pb2.FLOAT64)
126+
def grpc_type() -> spanner_v1.Type:
127+
return spanner.param_types.FLOAT64
127128

128129
@staticmethod
129130
def validate_type(value: Any) -> None:
@@ -139,8 +140,8 @@ def ddl() -> str:
139140
return 'STRING(MAX)'
140141

141142
@staticmethod
142-
def grpc_type() -> type_pb2.Type:
143-
return type_pb2.Type(code=type_pb2.STRING)
143+
def grpc_type() -> spanner_v1.Type:
144+
return spanner.param_types.STRING
144145

145146
@staticmethod
146147
def validate_type(value) -> None:
@@ -156,8 +157,8 @@ def ddl() -> str:
156157
return 'ARRAY<STRING(MAX)>'
157158

158159
@staticmethod
159-
def grpc_type() -> type_pb2.Type:
160-
return type_pb2.Type(code=type_pb2.ARRAY)
160+
def grpc_type() -> spanner_v1.Type:
161+
return spanner.param_types.Array(spanner.param_types.STRING)
161162

162163
@staticmethod
163164
def validate_type(value: Any) -> None:
@@ -176,8 +177,8 @@ def ddl() -> str:
176177
return 'TIMESTAMP'
177178

178179
@staticmethod
179-
def grpc_type() -> type_pb2.Type:
180-
return type_pb2.Type(code=type_pb2.TIMESTAMP)
180+
def grpc_type() -> spanner_v1.Type:
181+
return spanner.param_types.TIMESTAMP
181182

182183
@staticmethod
183184
def validate_type(value: Any) -> None:
@@ -193,8 +194,8 @@ def ddl() -> str:
193194
return 'BYTES(MAX)'
194195

195196
@staticmethod
196-
def grpc_type() -> type_pb2.Type:
197-
return type_pb2.Type(code=type_pb2.BYTES)
197+
def grpc_type() -> spanner_v1.Type:
198+
return spanner.param_types.BYTES
198199

199200
@staticmethod
200201
def validate_type(value) -> None:

spanner_orm/table_apis.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable.
2121
from google.cloud import spanner # pytype: disable=import-error
22+
from google.cloud import spanner_v1
2223
from google.cloud.spanner_v1 import transaction as spanner_transaction
23-
from google.cloud.spanner_v1.proto import type_pb2
2424

2525
_logger = logging.getLogger(__name__)
2626

@@ -50,9 +50,10 @@ def find(transaction: spanner_transaction.Transaction, table_name: str,
5050
return list(stream_results)
5151

5252

53-
def sql_query(transaction: spanner_transaction.Transaction, query: str,
54-
parameters: Dict[str, Any],
55-
parameter_types: Dict[str, type_pb2.Type]) -> List[Sequence[Any]]:
53+
def sql_query(
54+
transaction: spanner_transaction.Transaction, query: str,
55+
parameters: Dict[str, Any],
56+
parameter_types: Dict[str, spanner_v1.Type]) -> List[Sequence[Any]]:
5657
"""Executes a given SQL query against the Spanner database.
5758
5859
This isn't technically read-only, but it's necessary to implement the read-

spanner_orm/testlib/spanner_emulator/testlib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def _get_instance(spanner_client: client.Client) -> instance.Instance:
5555
Args:
5656
spanner_client: An initialized spanner client.
5757
"""
58-
existing_instances = list(spanner_client.list_instances())
59-
if existing_instances:
60-
return existing_instances[0]
58+
existing_instances_pb = list(spanner_client.list_instances())
59+
if existing_instances_pb:
60+
return instance.Instance.from_pb(existing_instances_pb[0], spanner_client)
6161

6262
# The emulator has one default config.
6363
config = list(spanner_client.list_instance_configs())[0]

0 commit comments

Comments
 (0)