4444from cassandra import (ConsistencyLevel , AuthenticationFailed ,
4545 OperationTimedOut , UnsupportedOperation )
4646from cassandra .connection import ConnectionException , ConnectionShutdown
47+ from cassandra .encoder import cql_encode_all_types , cql_encoders
4748from cassandra .protocol import (QueryMessage , ResultMessage ,
4849 ErrorMessage , ReadTimeoutErrorMessage ,
4950 WriteTimeoutErrorMessage ,
@@ -409,8 +410,7 @@ def __init__(self,
409410 self ._listener_lock = Lock ()
410411
411412 # let Session objects be GC'ed (and shutdown) when the user no longer
412- # holds a reference. Normally the cycle detector would handle this,
413- # but implementing __del__ prevents that.
413+ # holds a reference.
414414 self .sessions = WeakSet ()
415415 self .metadata = Metadata (self )
416416 self .control_connection = None
@@ -451,8 +451,10 @@ def __init__(self,
451451 self .control_connection = ControlConnection (
452452 self , self .control_connection_timeout )
453453
454- def register_type_class (self , keyspace , user_type , klass ):
454+ def register_user_type (self , keyspace , user_type , klass ):
455455 self ._user_types [keyspace ][user_type ] = klass
456+ for session in self .sessions :
457+ self .session .user_type_registered (keyspace , user_type , klass )
456458
457459 def get_min_requests_per_connection (self , host_distance ):
458460 return self ._min_requests_per_connection [host_distance ]
@@ -602,6 +604,9 @@ def shutdown(self):
602604
603605 def _new_session (self ):
604606 session = Session (self , self .metadata .all_hosts ())
607+ for keyspace , type_map in six .iteritems (self ._user_types ):
608+ for udt_name , klass in six .iteritems (type_map ):
609+ session .user_type_registered (keyspace , udt_name , klass )
605610 self .sessions .add (session )
606611 return session
607612
@@ -1064,6 +1069,19 @@ class Session(object):
10641069 _metrics = None
10651070 _protocol_version = None
10661071
1072+ encoders = None
1073+
1074+ def user_type_registered (self , keyspace , user_type , klass ):
1075+ type_meta = self .cluster .metadata .keyspaces [keyspace ].user_types [user_type ]
1076+
1077+ def encode (val ):
1078+ return '{ %s }' % ' , ' .join ('%s : %s' % (
1079+ field_name ,
1080+ cql_encode_all_types (getattr (val , field_name ))
1081+ ) for field_name in type_meta .field_names )
1082+
1083+ self ._encoders [klass ] = encode
1084+
10671085 def __init__ (self , cluster , hosts ):
10681086 self .cluster = cluster
10691087 self .hosts = hosts
@@ -1074,6 +1092,8 @@ def __init__(self, cluster, hosts):
10741092 self ._metrics = cluster .metrics
10751093 self ._protocol_version = self .cluster .protocol_version
10761094
1095+ self ._encoders = cql_encoders .copy ()
1096+
10771097 # create connection pools in parallel
10781098 futures = []
10791099 for host in hosts :
@@ -1196,7 +1216,7 @@ def _create_response_future(self, query, parameters, trace):
11961216 if isinstance (query , SimpleStatement ):
11971217 query_string = query .query_string
11981218 if parameters :
1199- query_string = bind_params (query .query_string , parameters )
1219+ query_string = bind_params (query .query_string , parameters , self . _encoders )
12001220 message = QueryMessage (
12011221 query_string , cl , query .serial_consistency_level ,
12021222 fetch_size , timestamp = timestamp )
@@ -1701,24 +1721,24 @@ def _refresh_schema(self, connection, keyspace=None, table=None, usertype=None,
17011721 cf_query , col_query )
17021722
17031723 log .debug ("[control connection] Fetched table info for %s.%s, rebuilding metadata" , (keyspace , table ))
1704- cf_result = dict_factory (* cf_result .results )
1705- col_result = dict_factory (* col_result .results )
1724+ cf_result = dict_factory (* cf_result .results ) if cf_result else {}
1725+ col_result = dict_factory (* col_result .results ) if col_result else {}
17061726 self ._cluster .metadata .table_changed (keyspace , table , cf_result , col_result )
17071727 elif usertype :
17081728 # user defined types within this keyspace changed
17091729 where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace , usertype )
17101730 types_query = QueryMessage (query = self ._SELECT_USERTYPES + where_clause , consistency_level = cl )
17111731 types_result = connection .wait_for_response (types_query )
17121732 log .debug ("[control connection] Fetched user type info for %s.%s, rebuilding metadata" , (keyspace , usertype ))
1713- types_result = dict_factory (* types_result )
1733+ types_result = dict_factory (* types_result . results ) if types_result . results else {}
17141734 self ._cluster .metadata .usertype_changed (keyspace , usertype , types_result )
17151735 elif keyspace :
17161736 # only the keyspace itself changed (such as replication settings)
17171737 where_clause = " WHERE keyspace_name = '%s'" % (keyspace ,)
17181738 ks_query = QueryMessage (query = self ._SELECT_KEYSPACES + where_clause , consistency_level = cl )
17191739 ks_result = connection .wait_for_response (ks_query )
17201740 log .debug ("[control connection] Fetched keyspace info for %s, rebuilding metadata" , (keyspace ,))
1721- ks_result = dict_factory (* types_result )
1741+ ks_result = dict_factory (* ks_result . results ) if ks_result . results else {}
17221742 self ._cluster .metadata .keyspace_changed (keyspace , ks_result )
17231743 else :
17241744 # build everything from scratch
@@ -1730,12 +1750,12 @@ def _refresh_schema(self, connection, keyspace=None, table=None, usertype=None,
17301750 if self ._protocol_version >= 3 :
17311751 queries .append (QueryMessage (query = self ._SELECT_USERTYPES , consistency_level = cl ))
17321752 ks_result , cf_result , col_result , types_result = connection .wait_for_responses (* queries )
1733- types_result = dict_factory (* types_result )
1753+ types_result = dict_factory (* types_result . results ) if types_result . results else {}
17341754 else :
17351755 ks_result , cf_result , col_result = connection .wait_for_responses (* queries )
17361756 types_result = {}
17371757
1738- ks_result = dict_factory (* types_result )
1758+ ks_result = dict_factory (* ks_result . results )
17391759 cf_result = dict_factory (* cf_result .results )
17401760 col_result = dict_factory (* col_result .results )
17411761
0 commit comments