3030from spanner_orm import relationship
3131
3232from google .api_core import datetime_helpers
33- from google .cloud .spanner_v1 .proto import type_pb2
33+ from google .cloud import spanner
34+ from google .cloud import spanner_v1
3435import immutabledict
3536
3637T = TypeVar ('T' )
@@ -105,7 +106,7 @@ def sql(self) -> str:
105106 def _sql (self ) -> str :
106107 pass
107108
108- def types (self ) -> Dict [str , type_pb2 .Type ]:
109+ def types (self ) -> Dict [str , spanner_v1 .Type ]:
109110 """Returns parameter types to be used in the SQL query.
110111
111112 Returns:
@@ -117,7 +118,7 @@ def types(self) -> Dict[str, type_pb2.Type]:
117118 return self ._types ()
118119
119120 @abc .abstractmethod
120- def _types (self ) -> Dict [str , type_pb2 .Type ]:
121+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
121122 raise NotImplementedError
122123
123124 @abc .abstractmethod
@@ -158,7 +159,8 @@ def _validate(self, model_class: Type[Any]) -> None:
158159]
159160
160161
161- def _spanner_type_of_python_object (value : GuessableParamType ) -> type_pb2 .Type :
162+ def _spanner_type_of_python_object (
163+ value : GuessableParamType ) -> spanner_v1 .Type :
162164 """Returns the Cloud Spanner type of the given object.
163165
164166 Args:
@@ -173,31 +175,27 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
173175 raise TypeError (
174176 'Cannot infer type of None, because any SQL type can be NULL.' )
175177 simple_type_code = {
176- bool : type_pb2 .BOOL ,
177- int : type_pb2 .INT64 ,
178- float : type_pb2 .FLOAT64 ,
179- datetime_helpers .DatetimeWithNanoseconds : type_pb2 .TIMESTAMP ,
180- datetime .datetime : type_pb2 .TIMESTAMP ,
181- datetime .date : type_pb2 .DATE ,
182- bytes : type_pb2 .BYTES ,
183- str : type_pb2 .STRING ,
184- decimal .Decimal : type_pb2 .NUMERIC ,
178+ bool : spanner_v1 . TypeCode .BOOL ,
179+ int : spanner_v1 . TypeCode .INT64 ,
180+ float : spanner_v1 . TypeCode .FLOAT64 ,
181+ datetime_helpers .DatetimeWithNanoseconds : spanner_v1 . TypeCode .TIMESTAMP ,
182+ datetime .datetime : spanner_v1 . TypeCode .TIMESTAMP ,
183+ datetime .date : spanner_v1 . TypeCode .DATE ,
184+ bytes : spanner_v1 . TypeCode .BYTES ,
185+ str : spanner_v1 . TypeCode .STRING ,
186+ decimal .Decimal : spanner_v1 . TypeCode .NUMERIC ,
185187 }.get (type (value ))
186188 if simple_type_code is not None :
187- return type_pb2 .Type (code = simple_type_code )
189+ return spanner_v1 .Type (code = simple_type_code )
188190 elif isinstance (value , (list , tuple )):
189191 element_types = tuple (
190192 _spanner_type_of_python_object (item )
191193 for item in value
192194 if item is not None )
193- unique_element_type_count = len ({
194- # Protos aren't hashable, so use their serializations.
195- element_type .SerializeToString (deterministic = True )
196- for element_type in element_types
197- })
198- if unique_element_type_count == 1 :
199- return type_pb2 .Type (
200- code = type_pb2 .ARRAY ,
195+ if element_types and all (
196+ a == b for a , b in zip (element_types , element_types [1 :])):
197+ return spanner_v1 .Type (
198+ code = spanner_v1 .TypeCode .ARRAY ,
201199 array_element_type = element_types [0 ],
202200 )
203201 else :
@@ -211,7 +209,7 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type:
211209class Param :
212210 """Parameter for substitution into a SQL query."""
213211 value : Any
214- type : type_pb2 .Type
212+ type : spanner_v1 .Type
215213
216214 @classmethod
217215 def from_value (cls : Type [T ], value : GuessableParamType ) -> T :
@@ -220,14 +218,13 @@ def from_value(cls: Type[T], value: GuessableParamType) -> T:
220218
221219 # BYTES must be base64-encoded, see
222220 # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110
223- if (isinstance (value , bytes ) and
224- guessed_type == type_pb2 .Type (code = type_pb2 .BYTES )):
221+ if (isinstance (value , bytes ) and guessed_type == spanner .param_types .BYTES ):
225222 encoded_value = base64 .b64encode (value ).decode ()
226223 elif (isinstance (value , (list , tuple )) and
227224 all (isinstance (x , bytes ) for x in value if x is not None ) and
228- guessed_type == type_pb2 .Type (
229- code = type_pb2 .ARRAY ,
230- array_element_type = type_pb2 . Type ( code = type_pb2 .BYTES ) ,
225+ guessed_type == spanner_v1 .Type (
226+ code = spanner_v1 . TypeCode .ARRAY ,
227+ array_element_type = spanner . param_types .BYTES ,
231228 )):
232229 encoded_value = tuple (
233230 None if item is None else base64 .b64encode (item ).decode ()
@@ -299,7 +296,7 @@ def _params(self) -> Dict[str, Any]:
299296 if isinstance (v , Param )
300297 }
301298
302- def _types (self ) -> Dict [str , type_pb2 .Type ]:
299+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
303300 """See base class."""
304301 return {
305302 self .key (k ): v .type
@@ -345,7 +342,7 @@ def _sql(self) -> str:
345342 other_table = self .destination_model_class .table ,
346343 other_column = self .destination_column )
347344
348- def _types (self ) -> Dict [str , type_pb2 .Type ]:
345+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
349346 return {}
350347
351348 def _validate (self , model_class : Type [Any ]) -> None :
@@ -411,7 +408,7 @@ def _sql(self) -> str:
411408 hints = (f'FORCE_INDEX={ self .name } ' , * self ._extra_hints )
412409 return f'@{{{ "," .join (hints )} }}'
413410
414- def _types (self ) -> Dict [str , type_pb2 .Type ]:
411+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
415412 return {}
416413
417414 def _validate (self , model_class : Type [Any ]) -> None :
@@ -433,7 +430,7 @@ def _sql(self) -> str:
433430 return '({})' .format (' AND ' .join (
434431 f'{ column } IS NOT NULL' for column in self .index .columns ))
435432
436- def _types (self ) -> Dict [str , type_pb2 .Type ]:
433+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
437434 return {}
438435
439436
@@ -542,7 +539,7 @@ def segment(self) -> Segment:
542539 def _sql (self ) -> str :
543540 return ''
544541
545- def _types (self ) -> Dict [str , type_pb2 .Type ]:
542+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
546543 return {}
547544
548545 def _validate (self , model_class : Type [Any ]) -> None :
@@ -601,10 +598,10 @@ def _sql(self) -> str:
601598 limit_key = self ._limit_key , offset_key = self ._offset_key )
602599 return 'LIMIT @{limit_key}' .format (limit_key = self ._limit_key )
603600
604- def _types (self ) -> Dict [str , type_pb2 .Type ]:
605- types = {self ._limit_key : type_pb2 . Type ( code = type_pb2 .INT64 ) }
601+ def _types (self ) -> Dict [str , spanner_v1 .Type ]:
602+ types = {self ._limit_key : spanner . param_types .INT64 }
606603 if self .offset :
607- types [self ._offset_key ] = type_pb2 . Type ( code = type_pb2 .INT64 )
604+ types [self ._offset_key ] = spanner . param_types .INT64
608605 return types
609606
610607 def _validate (self , model_class : Type [Any ]) -> None :
@@ -657,7 +654,7 @@ def _sql(self) -> str:
657654 def segment (self ) -> Segment :
658655 return Segment .WHERE
659656
660- def _types (self ) -> type_pb2 .Type :
657+ def _types (self ) -> spanner_v1 .Type :
661658 result = {}
662659 for condition in self .all_conditions :
663660 condition .suffix = str (int (self .suffix or 0 ) + len (result ))
@@ -702,7 +699,7 @@ def _sql(self) -> str:
702699 def segment (self ) -> Segment :
703700 return Segment .ORDER_BY
704701
705- def _types (self ) -> type_pb2 .Type :
702+ def _types (self ) -> spanner_v1 .Type :
706703 return {}
707704
708705 def _validate (self , model_class : Type [Any ]) -> None :
@@ -747,7 +744,7 @@ def _sql(self) -> str:
747744 operator = self .operator ,
748745 column_key = self ._column_key )
749746
750- def _types (self ) -> type_pb2 .Type :
747+ def _types (self ) -> spanner_v1 .Type :
751748 return {self ._column_key : self .model_class .fields [self .column ].grpc_type ()}
752749
753750 def _validate (self , model_class : Type [Any ]) -> None :
@@ -773,9 +770,10 @@ def _sql(self) -> str:
773770 operator = self .operator ,
774771 column_key = self ._column_key )
775772
776- def _types (self ) -> type_pb2 .Type :
773+ def _types (self ) -> spanner_v1 .Type :
777774 grpc_type = self .model_class .fields [self .column ].grpc_type ()
778- list_type = type_pb2 .Type (code = type_pb2 .ARRAY , array_element_type = grpc_type )
775+ list_type = spanner_v1 .Type (
776+ code = spanner_v1 .TypeCode .ARRAY , array_element_type = grpc_type )
779777 return {self ._column_key : list_type }
780778
781779 def _validate (self , model_class : Type [Any ]) -> None :
@@ -815,7 +813,7 @@ def _sql(self) -> str:
815813 operator = self .nullable_operator )
816814 return super ()._sql ()
817815
818- def _types (self ) -> type_pb2 .Type :
816+ def _types (self ) -> spanner_v1 .Type :
819817 if self .is_null ():
820818 return {}
821819 return super ()._types ()
0 commit comments