Skip to content

Commit 8480123

Browse files
committed
Use registered UDTs for non-prepared encoding
1 parent 7a838e2 commit 8480123

4 files changed

Lines changed: 54 additions & 29 deletions

File tree

cassandra/cluster.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from cassandra import (ConsistencyLevel, AuthenticationFailed,
4545
OperationTimedOut, UnsupportedOperation)
4646
from cassandra.connection import ConnectionException, ConnectionShutdown
47+
from cassandra.encoder import cql_encode_all_types, cql_encoders
4748
from cassandra.protocol import (QueryMessage, ResultMessage,
4849
ErrorMessage, ReadTimeoutErrorMessage,
4950
WriteTimeoutErrorMessage,
@@ -409,8 +410,7 @@ def __init__(self,
409410
self._listener_lock = Lock()
410411

411412
# let Session objects be GC'ed (and shutdown) when the user no longer
412-
# holds a reference. Normally the cycle detector would handle this,
413-
# but implementing __del__ prevents that.
413+
# holds a reference.
414414
self.sessions = WeakSet()
415415
self.metadata = Metadata(self)
416416
self.control_connection = None
@@ -451,8 +451,10 @@ def __init__(self,
451451
self.control_connection = ControlConnection(
452452
self, self.control_connection_timeout)
453453

454-
def register_type_class(self, keyspace, user_type, klass):
454+
def register_user_type(self, keyspace, user_type, klass):
455455
self._user_types[keyspace][user_type] = klass
456+
for session in self.sessions:
457+
self.session.user_type_registered(keyspace, user_type, klass)
456458

457459
def get_min_requests_per_connection(self, host_distance):
458460
return self._min_requests_per_connection[host_distance]
@@ -602,6 +604,9 @@ def shutdown(self):
602604

603605
def _new_session(self):
604606
session = Session(self, self.metadata.all_hosts())
607+
for keyspace, type_map in six.iteritems(self._user_types):
608+
for udt_name, klass in six.iteritems(type_map):
609+
session.user_type_registered(keyspace, udt_name, klass)
605610
self.sessions.add(session)
606611
return session
607612

@@ -1064,6 +1069,19 @@ class Session(object):
10641069
_metrics = None
10651070
_protocol_version = None
10661071

1072+
encoders = None
1073+
1074+
def user_type_registered(self, keyspace, user_type, klass):
1075+
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
1076+
1077+
def encode(val):
1078+
return '{ %s }' % ' , '.join('%s : %s' % (
1079+
field_name,
1080+
cql_encode_all_types(getattr(val, field_name))
1081+
) for field_name in type_meta.field_names)
1082+
1083+
self._encoders[klass] = encode
1084+
10671085
def __init__(self, cluster, hosts):
10681086
self.cluster = cluster
10691087
self.hosts = hosts
@@ -1074,6 +1092,8 @@ def __init__(self, cluster, hosts):
10741092
self._metrics = cluster.metrics
10751093
self._protocol_version = self.cluster.protocol_version
10761094

1095+
self._encoders = cql_encoders.copy()
1096+
10771097
# create connection pools in parallel
10781098
futures = []
10791099
for host in hosts:
@@ -1196,7 +1216,7 @@ def _create_response_future(self, query, parameters, trace):
11961216
if isinstance(query, SimpleStatement):
11971217
query_string = query.query_string
11981218
if parameters:
1199-
query_string = bind_params(query.query_string, parameters)
1219+
query_string = bind_params(query.query_string, parameters, self._encoders)
12001220
message = QueryMessage(
12011221
query_string, cl, query.serial_consistency_level,
12021222
fetch_size, timestamp=timestamp)
@@ -1701,24 +1721,24 @@ def _refresh_schema(self, connection, keyspace=None, table=None, usertype=None,
17011721
cf_query, col_query)
17021722

17031723
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table))
1704-
cf_result = dict_factory(*cf_result.results)
1705-
col_result = dict_factory(*col_result.results)
1724+
cf_result = dict_factory(*cf_result.results) if cf_result else {}
1725+
col_result = dict_factory(*col_result.results) if col_result else {}
17061726
self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result)
17071727
elif usertype:
17081728
# user defined types within this keyspace changed
17091729
where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace, usertype)
17101730
types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl)
17111731
types_result = connection.wait_for_response(types_query)
17121732
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype))
1713-
types_result = dict_factory(*types_result)
1733+
types_result = dict_factory(*types_result.results) if types_result.results else {}
17141734
self._cluster.metadata.usertype_changed(keyspace, usertype, types_result)
17151735
elif keyspace:
17161736
# only the keyspace itself changed (such as replication settings)
17171737
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
17181738
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
17191739
ks_result = connection.wait_for_response(ks_query)
17201740
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,))
1721-
ks_result = dict_factory(*types_result)
1741+
ks_result = dict_factory(*ks_result.results) if ks_result.results else {}
17221742
self._cluster.metadata.keyspace_changed(keyspace, ks_result)
17231743
else:
17241744
# build everything from scratch
@@ -1730,12 +1750,12 @@ def _refresh_schema(self, connection, keyspace=None, table=None, usertype=None,
17301750
if self._protocol_version >= 3:
17311751
queries.append(QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl))
17321752
ks_result, cf_result, col_result, types_result = connection.wait_for_responses(*queries)
1733-
types_result = dict_factory(*types_result)
1753+
types_result = dict_factory(*types_result.results) if types_result.results else {}
17341754
else:
17351755
ks_result, cf_result, col_result = connection.wait_for_responses(*queries)
17361756
types_result = {}
17371757

1738-
ks_result = dict_factory(*types_result)
1758+
ks_result = dict_factory(*ks_result.results)
17391759
cf_result = dict_factory(*cf_result.results)
17401760
col_result = dict_factory(*col_result.results)
17411761

cassandra/protocol.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,14 @@ def recv_body(cls, f, protocol_version, user_type_map):
560560
ksname = read_string(f)
561561
results = ksname
562562
elif kind == RESULT_KIND_PREPARED:
563-
results = cls.recv_results_prepared(f)
563+
results = cls.recv_results_prepared(f, user_type_map)
564564
elif kind == RESULT_KIND_SCHEMA_CHANGE:
565565
results = cls.recv_results_schema_change(f, protocol_version)
566566
return cls(kind, results, paging_state)
567567

568568
@classmethod
569569
def recv_results_rows(cls, f, protocol_version, user_type_map):
570-
paging_state, column_metadata = cls.recv_results_metadata(f)
570+
paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)
571571
rowcount = read_int(f)
572572
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
573573
colnames = [c[2] for c in column_metadata]
@@ -579,9 +579,9 @@ def recv_results_rows(cls, f, protocol_version, user_type_map):
579579
return (paging_state, (colnames, parsed_rows))
580580

581581
@classmethod
582-
def recv_results_prepared(cls, f):
582+
def recv_results_prepared(cls, f, user_type_map):
583583
query_id = read_binary_string(f)
584-
_, column_metadata = cls.recv_results_metadata(f)
584+
_, column_metadata = cls.recv_results_metadata(f, user_type_map)
585585
return (query_id, column_metadata)
586586

587587
@classmethod

cassandra/query.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,10 @@ class BatchStatement(Statement):
567567
"""
568568

569569
_statements_and_parameters = None
570+
_session = None
570571

571572
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
572-
consistency_level=None):
573+
consistency_level=None, session=None):
573574
"""
574575
`batch_type` specifies The :class:`.BatchType` for the batch operation.
575576
Defaults to :attr:`.BatchType.LOGGED`.
@@ -605,6 +606,7 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
605606
"""
606607
self.batch_type = batch_type
607608
self._statements_and_parameters = []
609+
self._session = session
608610
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
609611

610612
def add(self, statement, parameters=None):
@@ -617,7 +619,8 @@ def add(self, statement, parameters=None):
617619
"""
618620
if isinstance(statement, six.string_types):
619621
if parameters:
620-
statement = bind_params(statement, parameters)
622+
encoders = cql_encoders if self._session is None else self._session.encoders
623+
statement = bind_params(statement, parameters, encoders)
621624
self._statements_and_parameters.append((False, statement, ()))
622625
elif isinstance(statement, PreparedStatement):
623626
query_id = statement.query_id
@@ -635,7 +638,8 @@ def add(self, statement, parameters=None):
635638
# it must be a SimpleStatement
636639
query_string = statement.query_string
637640
if parameters:
638-
query_string = bind_params(query_string, parameters)
641+
encoders = cql_encoders if self._session is None else self._session.encoders
642+
query_string = bind_params(query_string, parameters, encoders)
639643
self._statements_and_parameters.append((False, query_string, ()))
640644
return self
641645

@@ -677,11 +681,11 @@ def __str__(self):
677681
return cql_encode_sequence(self.sequence)
678682

679683

680-
def bind_params(query, params):
684+
def bind_params(query, params, encoders):
681685
if isinstance(params, dict):
682-
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
686+
return query % dict((k, encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
683687
else:
684-
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
688+
return query % tuple(encoders.get(type(v), cql_encode_object)(v) for v in params)
685689

686690

687691
class TraceUnavailable(Exception):

tests/unit/test_parameter_binding.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
except ImportError:
1818
import unittest # noqa
1919

20+
from cassandra.encoder import cql_encoders
2021
from cassandra.query import bind_params, ValueSequence
2122
from cassandra.query import PreparedStatement, BoundStatement
2223
from cassandra.cqltypes import Int32Type
@@ -28,43 +29,43 @@
2829
class ParamBindingTest(unittest.TestCase):
2930

3031
def test_bind_sequence(self):
31-
result = bind_params("%s %s %s", (1, "a", 2.0))
32+
result = bind_params("%s %s %s", (1, "a", 2.0), cql_encoders)
3233
self.assertEqual(result, "1 'a' 2.0")
3334

3435
def test_bind_map(self):
35-
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
36+
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), cql_encoders)
3637
self.assertEqual(result, "1 'a' 2.0")
3738

3839
def test_sequence_param(self):
39-
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
40+
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), cql_encoders)
4041
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
4142

4243
def test_generator_param(self):
43-
result = bind_params("%s", ((i for i in xrange(3)),))
44+
result = bind_params("%s", ((i for i in xrange(3)),), cql_encoders)
4445
self.assertEqual(result, "[ 0 , 1 , 2 ]")
4546

4647
def test_none_param(self):
47-
result = bind_params("%s", (None,))
48+
result = bind_params("%s", (None,), cql_encoders)
4849
self.assertEqual(result, "NULL")
4950

5051
def test_list_collection(self):
51-
result = bind_params("%s", (['a', 'b', 'c'],))
52+
result = bind_params("%s", (['a', 'b', 'c'],), cql_encoders)
5253
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
5354

5455
def test_set_collection(self):
55-
result = bind_params("%s", (set(['a', 'b']),))
56+
result = bind_params("%s", (set(['a', 'b']),), cql_encoders)
5657
self.assertIn(result, ("{ 'a' , 'b' }", "{ 'b' , 'a' }"))
5758

5859
def test_map_collection(self):
5960
vals = OrderedDict()
6061
vals['a'] = 'a'
6162
vals['b'] = 'b'
6263
vals['c'] = 'c'
63-
result = bind_params("%s", (vals,))
64+
result = bind_params("%s", (vals,), cql_encoders)
6465
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
6566

6667
def test_quote_escaping(self):
67-
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
68+
result = bind_params("%s", ("""'ef''ef"ef""ef'""",), cql_encoders)
6869
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
6970

7071

0 commit comments

Comments
 (0)