Skip to content

Commit 10a079f

Browse files
committed
Initial DriverContext and registry
1 parent ab241e7 commit 10a079f

19 files changed

Lines changed: 319 additions & 135 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
Features
55
--------
6+
* Introduce the DriverContext (PYTHON-958)
67
* Make protocol messages pluggable (PYTHON-956)
78
* cqlengine: Remove default limit on QuerySets (PYTHON-517)
89
* cqlengine: asynchronous execution support (PYTHON-605)

cassandra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def emit(self, record):
2222

2323
logging.getLogger('cassandra').addHandler(NullHandler())
2424

25-
__version_info__ = (3, 14, 0)
25+
__version_info__ = (4, 0, 'dev0')
2626
__version__ = '.'.join(map(str, __version_info__))
2727

2828

cassandra/cluster.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
BatchStatement, bind_params, QueryTrace, TraceUnavailable,
7171
named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET)
7272
from cassandra.timestamps import MonotonicTimestampGenerator
73+
from cassandra.context import DriverContext
7374

7475

7576
def _is_eventlet_monkey_patched():
@@ -651,6 +652,8 @@ def _default_load_balancing_policy(self):
651652
documentation for :meth:`Session.timestamp_generator`.
652653
"""
653654

655+
_context = None
656+
654657
@property
655658
def schema_metadata_enabled(self):
656659
"""
@@ -734,13 +737,16 @@ def __init__(self,
734737
allow_beta_protocol_version=False,
735738
timestamp_generator=None,
736739
idle_heartbeat_timeout=30,
737-
no_compact=False):
740+
no_compact=False,
741+
context=None):
738742
"""
739743
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
740744
extablishing connection pools or refreshing metadata.
741745
742746
Any of the mutable Cluster attributes may be set as keyword arguments to the constructor.
743747
"""
748+
self._context = context or DriverContext()
749+
744750
if contact_points is not None:
745751
if contact_points is _NOT_SET:
746752
self._contact_points_explicit = False
@@ -964,11 +970,14 @@ def connection_factory(self, address, *args, **kwargs):
964970
Intended for internal use only.
965971
"""
966972
kwargs = self._make_connection_kwargs(address, kwargs)
967-
return self.connection_class.factory(address, self.connect_timeout, *args, **kwargs)
973+
return self.connection_class.factory(
974+
self._context.protocol_handler, address, self.connect_timeout,
975+
*args, **kwargs)
968976

969977
def _make_connection_factory(self, host, *args, **kwargs):
970978
kwargs = self._make_connection_kwargs(host.address, kwargs)
971-
return partial(self.connection_class.factory, host.address, self.connect_timeout, *args, **kwargs)
979+
return partial(self.connection_class.factory, self._context.protocol_handler, host.address,
980+
self.connect_timeout, *args, **kwargs)
972981

973982
def _make_connection_kwargs(self, address, kwargs_dict):
974983
if self._auth_provider_callable:
@@ -1098,7 +1107,7 @@ def __exit__(self, *args):
10981107
self.shutdown()
10991108

11001109
def _new_session(self, keyspace):
1101-
session = Session(self, self.metadata.all_hosts(), keyspace)
1110+
session = Session(self, self.metadata.all_hosts(), keyspace, context=self._context)
11021111
self._session_register_user_types(session)
11031112
self.sessions.add(session)
11041113
return session
@@ -1710,24 +1719,18 @@ class Session(object):
17101719
.. versionadded:: 2.1.0
17111720
"""
17121721

1713-
client_protocol_handler = ProtocolHandler
1714-
"""
1715-
Specifies a protocol handler that will be used for client-initiated requests (i.e. no
1716-
internal driver requests). This can be used to override or extend features such as
1717-
message or type ser/des.
1718-
1719-
The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`.
1720-
1721-
When compiled with Cython, there are also built-in faster alternatives. See :ref:`faster_deser`
1722-
"""
1723-
1722+
_protocol_handler_class = ProtocolHandler
1723+
_protocol_handler = None
17241724
_lock = None
17251725
_pools = None
17261726
_profile_manager = None
17271727
_metrics = None
17281728
_request_init_callbacks = None
1729+
_context = None
1730+
1731+
def __init__(self, cluster, hosts, keyspace=None, context=None):
1732+
self._context = context or DriverContext()
17291733

1730-
def __init__(self, cluster, hosts, keyspace=None):
17311734
self.cluster = cluster
17321735
self.hosts = hosts
17331736
self.keyspace = keyspace
@@ -1757,6 +1760,26 @@ def __init__(self, cluster, hosts, keyspace=None):
17571760
msg += " using keyspace '%s'" % self.keyspace
17581761
raise NoHostAvailable(msg, [h.address for h in hosts])
17591762

1763+
@property
1764+
def protocol_handler_class(self):
1765+
"""
1766+
Specifies a protocol handler that will be used for client-initiated requests (i.e. no
1767+
internal driver requests). This can be used to override or extend features such as
1768+
message or type ser/des.
1769+
1770+
The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`.
1771+
1772+
When compiled with Cython, there are also built-in faster alternatives. See :ref:`faster_deser`
1773+
"""
1774+
return self._protocol_handler_class
1775+
1776+
@protocol_handler_class.setter
1777+
def protocol_handler_class(self, value):
1778+
self._protocol_handler_class = value
1779+
self._protocol_handler = self._protocol_handler_class(
1780+
self._context.message_codec_registry.encoders,
1781+
self._context.message_codec_registry.decoders)
1782+
17601783
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None):
17611784
"""
17621785
Execute the given query and synchronously wait for the response.
@@ -1829,7 +1852,6 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None
18291852
18301853
"""
18311854
future = self._create_response_future(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state)
1832-
future._protocol_handler = self.client_protocol_handler
18331855
self._on_request(future)
18341856
future.send_request()
18351857
return future
@@ -3074,13 +3096,14 @@ class ResponseFuture(object):
30743096
_custom_payload = None
30753097
_warnings = None
30763098
_timer = None
3077-
_protocol_handler = ProtocolHandler
3099+
30783100
_spec_execution_plan = NoSpeculativeExecutionPlan()
30793101

30803102
_warned_timeout = False
30813103

30823104
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None,
3083-
retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None):
3105+
retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None,
3106+
speculative_execution_plan=None):
30843107
self.session = session
30853108
# TODO: normalize handling of retry policy and row factory
30863109
self.row_factory = row_factory or session.cluster._default_row_factory
@@ -3237,10 +3260,9 @@ def _query(self, host, message=None, cb=None):
32373260
if cb is None:
32383261
cb = partial(self._set_result, host, connection, pool)
32393262

3240-
self.request_encoded_size = connection.send_msg(message, request_id, cb=cb,
3241-
encoder=self._protocol_handler.encode_message,
3242-
decoder=self._protocol_handler.decode_message,
3243-
result_metadata=result_meta)
3263+
self.request_encoded_size = connection.send_msg(
3264+
message, request_id, cb=cb, result_metadata=result_meta,
3265+
protocol_handler=self.session._protocol_handler)
32443266
self.attempted_hosts.append(host)
32453267
return request_id
32463268
except NoConnectionsAvailable as exc:

cassandra/connection.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class Connection(object):
204204
cql_version = None
205205
no_compact = False
206206
protocol_version = ProtocolVersion.MAX_SUPPORTED
207+
protocol_handler = None
207208

208209
keyspace = None
209210
compression = True
@@ -256,17 +257,19 @@ class Connection(object):
256257

257258
_check_hostname = False
258259

259-
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
260+
def __init__(self, protocol_handler, host='127.0.0.1', port=9042, authenticator=None,
260261
ssl_options=None, sockopts=None, compression=True,
261-
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
262-
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False):
262+
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED,
263+
is_control_connection=False, user_type_map=None, connect_timeout=None,
264+
allow_beta_protocol_version=False, no_compact=False):
263265
self.host = host
264266
self.port = port
265267
self.authenticator = authenticator
266268
self.ssl_options = ssl_options.copy() if ssl_options else None
267269
self.sockopts = sockopts
268270
self.compression = compression
269271
self.cql_version = cql_version
272+
self.protocol_handler = protocol_handler
270273
self.protocol_version = protocol_version
271274
self.is_control_connection = is_control_connection
272275
self.user_type_map = user_type_map
@@ -315,15 +318,15 @@ def create_timer(cls, timeout, callback):
315318
raise NotImplementedError()
316319

317320
@classmethod
318-
def factory(cls, host, timeout, *args, **kwargs):
321+
def factory(cls, protocol_handler, host, timeout, *args, **kwargs):
319322
"""
320323
A factory function which returns connections which have
321324
succeeded in connecting and are ready for service (or
322325
raises an exception otherwise).
323326
"""
324327
start = time.time()
325328
kwargs['connect_timeout'] = timeout
326-
conn = cls(host, *args, **kwargs)
329+
conn = cls(protocol_handler, host, *args, **kwargs)
327330
elapsed = time.time() - start
328331
conn.connected_event.wait(timeout - elapsed)
329332
if conn.last_error:
@@ -452,17 +455,20 @@ def handle_pushed(self, response):
452455
except Exception:
453456
log.exception("Pushed event handler errored, ignoring:")
454457

455-
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None):
458+
def send_msg(self, msg, request_id, cb, result_metadata=None, protocol_handler=None):
456459
if self.is_defunct:
457460
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
458461
elif self.is_closed:
459462
raise ConnectionShutdown("Connection to %s is closed" % self.host)
460463

464+
protocol_handler = protocol_handler or self.protocol_handler
461465
# queue the decoder function with the request
462466
# this allows us to inject custom functions per request to encode, decode messages
463-
self._requests[request_id] = (cb, decoder, result_metadata)
464-
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
465-
allow_beta_protocol_version=self.allow_beta_protocol_version)
467+
468+
self._requests[request_id] = (cb, protocol_handler.decode_message, result_metadata)
469+
msg = protocol_handler.encode_message(
470+
msg, request_id, self.protocol_version, compressor=self.compressor,
471+
allow_beta_protocol_version=self.allow_beta_protocol_version)
466472
self.push(msg)
467473
return len(msg)
468474

@@ -587,7 +593,7 @@ def process_msg(self, header, body):
587593
stream_id = header.stream
588594
if stream_id < 0:
589595
callback = None
590-
decoder = ProtocolHandler.decode_message
596+
decoder = self.protocol_handler.decode_message
591597
result_metadata = None
592598
else:
593599
try:

cassandra/context.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from cassandra.registry import MessageCodecRegistry
16+
from cassandra.protocol import ProtocolHandler
17+
18+
__all__ = ['DriverContext']
19+
20+
21+
class SingletonProvider(object):
22+
"""
23+
Providers are strategies of accessing objects. The SingletonProvider
24+
returns the same object instance on each call. The instance is also
25+
lazy-initialized.
26+
27+
:param provider: a callable that is used to create the object instance.
28+
:param *args: the provider callable args
29+
:param *kwargs: the provider callable kwargs
30+
"""
31+
_obj = None
32+
_provider = None
33+
_args = None
34+
_kwargs = None
35+
36+
def __init__(self, provider, *args, **kwargs):
37+
self._provider = provider
38+
self._args = args
39+
self._kwargs = kwargs
40+
41+
def __call__(self):
42+
if self._obj is None:
43+
self._obj = self._provider(*self._args, **self._kwargs)
44+
return self._obj
45+
46+
47+
class DriverContext(object):
48+
49+
_message_codec_registry = None
50+
# the default protocol handler
51+
_protocol_handler = None
52+
53+
def __init__(self):
54+
self._message_codec_registry = SingletonProvider(MessageCodecRegistry.factory)
55+
self._protocol_handler = SingletonProvider(
56+
ProtocolHandler,
57+
self._message_codec_registry().encoders,
58+
self._message_codec_registry().decoders)
59+
60+
@property
61+
def message_codec_registry(self):
62+
return self._message_codec_registry()
63+
64+
@property
65+
def protocol_handler(self):
66+
return self._protocol_handler()

0 commit comments

Comments
 (0)