3030from spanner_orm import relationship
3131
3232from google .api_core import datetime_helpers
33- from google .cloud .spanner_v1 .proto import type_pb2
33+ from google .cloud import spanner
34+ from google .cloud import spanner_v1
3435import immutabledict
3536
3637T = TypeVar ('T' )
@@ -105,7 +106,7 @@ def sql(self) -> str:
105106 def _sql (self ) -> str :
106107 pass
107108
108- def types (self ) -> Dict [str , type_pb2 .Type ]:
109+ def types (self ) -> Dict [str , spanner_v1 .Type ]:
109110 """Returns parameter types to be used in the SQL query.
110111
111112 Returns:
@@ -117,7 +118,7 @@ def types(self) -> Dict[str, type_pb2.Type]:
117118 return self ._types ()
118119
119120 @abc .abstractmethod
120- def _types (self ) -> Dict [str , type_pb2 .Type ]:
121+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
121122 raise NotImplementedError
122123
123124 @abc .abstractmethod
@@ -158,7 +159,8 @@ def _validate(self, model_class: Type[Any]) -> None:
158159]
159160
160161
161- def _spanner_type_of_python_object (value : GuessableParamType ) -> type_pb2 .Type :
162+ def _spanner_type_of_python_object (
163+ value : GuessableParamType ) -> spanner_v1 .Type :
162164 """Returns the Cloud Spanner type of the given object.
163165
164166 Args:
@@ -173,31 +175,30 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
173175 raise TypeError (
174176 'Cannot infer type of None, because any SQL type can be NULL.' )
175177 simple_type_code = {
176- bool : type_pb2 .BOOL ,
177- int : type_pb2 .INT64 ,
178- float : type_pb2 .FLOAT64 ,
179- datetime_helpers .DatetimeWithNanoseconds : type_pb2 .TIMESTAMP ,
180- datetime .datetime : type_pb2 .TIMESTAMP ,
181- datetime .date : type_pb2 .DATE ,
182- bytes : type_pb2 .BYTES ,
183- str : type_pb2 .STRING ,
184- decimal .Decimal : type_pb2 .NUMERIC ,
178+ bool : spanner_v1 . TypeCode .BOOL ,
179+ int : spanner_v1 . TypeCode .INT64 ,
180+ float : spanner_v1 . TypeCode .FLOAT64 ,
181+ datetime_helpers .DatetimeWithNanoseconds : spanner_v1 . TypeCode .TIMESTAMP ,
182+ datetime .datetime : spanner_v1 . TypeCode .TIMESTAMP ,
183+ datetime .date : spanner_v1 . TypeCode .DATE ,
184+ bytes : spanner_v1 . TypeCode .BYTES ,
185+ str : spanner_v1 . TypeCode .STRING ,
186+ decimal .Decimal : spanner_v1 . TypeCode .NUMERIC ,
185187 }.get (type (value ))
186188 if simple_type_code is not None :
187- return type_pb2 .Type (code = simple_type_code )
189+ return spanner_v1 .Type (code = simple_type_code )
188190 elif isinstance (value , (list , tuple )):
189191 element_types = tuple (
190192 _spanner_type_of_python_object (item )
191193 for item in value
192194 if item is not None )
193195 unique_element_type_count = len ({
194- # Protos aren't hashable, so use their serializations.
195- element_type .SerializeToString (deterministic = True )
196- for element_type in element_types
196+ # Protos aren't hashable, so serialize them.
197+ str (element_type ) for element_type in element_types
197198 })
198199 if unique_element_type_count == 1 :
199- return type_pb2 .Type (
200- code = type_pb2 .ARRAY ,
200+ return spanner_v1 .Type (
201+ code = spanner_v1 . TypeCode .ARRAY ,
201202 array_element_type = element_types [0 ],
202203 )
203204 else :
@@ -211,7 +212,7 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
211212class Param :
212213 """Parameter for substitution into a SQL query."""
213214 value : Any
214- type : type_pb2 .Type
215+ type : spanner_v1 .Type
215216
216217 @classmethod
217218 def from_value (cls : Type [T ], value : GuessableParamType ) -> T :
@@ -220,14 +221,13 @@ def from_value(cls: Type[T], value: GuessableParamType) -> T:
220221
221222 # BYTES must be base64-encoded, see
222223 # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110
223- if (isinstance (value , bytes ) and
224- guessed_type == type_pb2 .Type (code = type_pb2 .BYTES )):
224+ if (isinstance (value , bytes ) and guessed_type == spanner .param_types .BYTES ):
225225 encoded_value = base64 .b64encode (value ).decode ()
226226 elif (isinstance (value , (list , tuple )) and
227227 all (isinstance (x , bytes ) for x in value if x is not None ) and
228- guessed_type == type_pb2 .Type (
229- code = type_pb2 .ARRAY ,
230- array_element_type = type_pb2 . Type ( code = type_pb2 .BYTES ) ,
228+ guessed_type == spanner_v1 .Type (
229+ code = spanner_v1 . TypeCode .ARRAY ,
230+ array_element_type = spanner . param_types .BYTES ,
231231 )):
232232 encoded_value = tuple (
233233 None if item is None else base64 .b64encode (item ).decode ()
@@ -299,7 +299,7 @@ def _params(self) -> Dict[str, Any]:
299299 if isinstance (v , Param )
300300 }
301301
302- def _types (self ) -> Dict [str , type_pb2 .Type ]:
302+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
303303 """See base class."""
304304 return {
305305 self .key (k ): v .type
@@ -345,7 +345,7 @@ def _sql(self) -> str:
345345 other_table = self .destination_model_class .table ,
346346 other_column = self .destination_column )
347347
348- def _types (self ) -> Dict [str , type_pb2 .Type ]:
348+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
349349 return {}
350350
351351 def _validate (self , model_class : Type [Any ]) -> None :
@@ -389,7 +389,7 @@ def segment(self) -> Segment:
389389 def _sql (self ) -> str :
390390 return '@{{FORCE_INDEX={}}}' .format (self .name )
391391
392- def _types (self ) -> Dict [str , type_pb2 .Type ]:
392+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
393393 return {}
394394
395395 def _validate (self , model_class : Type [Any ]) -> None :
@@ -509,7 +509,7 @@ def segment(self) -> Segment:
509509 def _sql (self ) -> str :
510510 return ''
511511
512- def _types (self ) -> Dict [str , type_pb2 .Type ]:
512+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
513513 return {}
514514
515515 def _validate (self , model_class : Type [Any ]) -> None :
@@ -568,10 +568,10 @@ def _sql(self) -> str:
568568 limit_key = self ._limit_key , offset_key = self ._offset_key )
569569 return 'LIMIT @{limit_key}' .format (limit_key = self ._limit_key )
570570
571- def _types (self ) -> Dict [str , type_pb2 .Type ]:
572- types = {self ._limit_key : type_pb2 . Type ( code = type_pb2 .INT64 ) }
571+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
572+ types = {self ._limit_key : spanner . param_types .INT64 }
573573 if self .offset :
574- types [self ._offset_key ] = type_pb2 . Type ( code = type_pb2 .INT64 )
574+ types [self ._offset_key ] = spanner . param_types .INT64
575575 return types
576576
577577 def _validate (self , model_class : Type [Any ]) -> None :
@@ -624,7 +624,7 @@ def _sql(self) -> str:
624624 def segment (self ) -> Segment :
625625 return Segment .WHERE
626626
627- def _types (self ) -> type_pb2 .Type :
627+ def _types (self ) -> spanner_v1 .Type :
628628 result = {}
629629 for condition in self .all_conditions :
630630 condition .suffix = str (int (self .suffix or 0 ) + len (result ))
@@ -669,7 +669,7 @@ def _sql(self) -> str:
669669 def segment (self ) -> Segment :
670670 return Segment .ORDER_BY
671671
672- def _types (self ) -> type_pb2 .Type :
672+ def _types (self ) -> spanner_v1 .Type :
673673 return {}
674674
675675 def _validate (self , model_class : Type [Any ]) -> None :
@@ -714,7 +714,7 @@ def _sql(self) -> str:
714714 operator = self .operator ,
715715 column_key = self ._column_key )
716716
717- def _types (self ) -> type_pb2 .Type :
717+ def _types (self ) -> spanner_v1 .Type :
718718 return {self ._column_key : self .model_class .fields [self .column ].grpc_type ()}
719719
720720 def _validate (self , model_class : Type [Any ]) -> None :
@@ -740,9 +740,10 @@ def _sql(self) -> str:
740740 operator = self .operator ,
741741 column_key = self ._column_key )
742742
743- def _types (self ) -> type_pb2 .Type :
743+ def _types (self ) -> spanner_v1 .Type :
744744 grpc_type = self .model_class .fields [self .column ].grpc_type ()
745- list_type = type_pb2 .Type (code = type_pb2 .ARRAY , array_element_type = grpc_type )
745+ list_type = spanner_v1 .Type (
746+ code = spanner_v1 .TypeCode .ARRAY , array_element_type = grpc_type )
746747 return {self ._column_key : list_type }
747748
748749 def _validate (self , model_class : Type [Any ]) -> None :
@@ -782,7 +783,7 @@ def _sql(self) -> str:
782783 operator = self .nullable_operator )
783784 return super ()._sql ()
784785
785- def _types (self ) -> type_pb2 .Type :
786+ def _types (self ) -> spanner_v1 .Type :
786787 if self .is_null ():
787788 return {}
788789 return super ()._types ()
0 commit comments