Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions cassandra/cqlengine/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ def __init__(self, name, hosts, consistency=None,
self.cluster_options = cluster_options if cluster_options else {}
self.lazy_connect_lock = threading.RLock()

@classmethod
def from_session(cls, name, session):
instance = cls(name=name, hosts=session.hosts)
instance.cluster, instance.session = session.cluster, session
instance.setup_session()
return instance

def setup(self):
"""Setup the connection"""
global cluster, session
Expand Down Expand Up @@ -132,21 +139,67 @@ def handle_lazy_connect(self):
self.setup()


def register_connection(name, hosts, consistency=None, lazy_connect=False,
retry_connect=False, cluster_options=None, default=False):
def register_connection(name, hosts=None, consistency=None, lazy_connect=False,
retry_connect=False, cluster_options=None, default=False,
session=None):
"""
Add a connection to the connection registry. ``hosts`` and ``session`` are
mutually exclusive, and ``consistency``, ``lazy_connect``,
``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using
``hosts`` will create a new :class:`cassandra.cluster.Cluster` and
:class:`cassandra.cluster.Session`.

:param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`).
:param int consistency: The default :class:`~.ConsistencyLevel` for the
registered connection's new session. Default is the same as
:attr:`.Session.default_consistency_level`. For use with ``hosts`` only;
will fail when used with ``session``.
:param bool lazy_connect: True if should not connect until first use. For
use with ``hosts`` only; will fail when used with ``session``.
:param bool retry_connect: True if we should retry to connect even if there
was a connection failure initially. For use with ``hosts`` only; will
fail when used with ``session``.
:param dict cluster_options: A dict of options to be used as keyword
arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts``
only; will fail when used with ``session``.
:param bool default: If True, set the new connection as the cqlengine
default
:param Session session: A :class:`cassandra.cluster.Session` to be used in
the created connection.
"""

if name in _connections:
log.warning("Registering connection '{0}' when it already exists.".format(name))

conn = Connection(name, hosts, consistency=consistency,lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=cluster_options)
hosts_xor_session_passed = (hosts is None) ^ (session is None)
if not hosts_xor_session_passed:
raise CQLEngineException(
"Must pass exactly one of 'hosts' or 'session' arguments"
)
elif session is not None:
invalid_config_args = (consistency is not None or
lazy_connect is not False or
retry_connect is not False or
cluster_options is not None)
if invalid_config_args:
raise CQLEngineException(
"Session configuration arguments and 'session' argument are mutually exclusive"
)
conn = Connection.from_session(name, session=session)
conn.setup_session()
elif hosts is not None:
conn = Connection(
name, hosts=hosts,
consistency=consistency, lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=cluster_options
)
conn.setup()

_connections[name] = conn

if default:
set_default_connection(name)

conn.setup()
return conn


Expand Down Expand Up @@ -222,7 +275,12 @@ def set_session(s):
This may be relaxed in the future
"""

conn = get_connection()
try:
conn = get_connection()
except CQLEngineException:
# no default connection set; initalize one
register_connection('default', session=s, default=True)
conn = get_connection()

if conn.session:
log.warning("configuring new default connection for cqlengine when one was already set")
Expand Down Expand Up @@ -304,7 +362,11 @@ def get_cluster(connection=None):
def register_udt(keyspace, type_name, klass, connection=None):
udt_by_keyspace[keyspace][type_name] = klass

cluster = get_cluster(connection)
try:
cluster = get_cluster(connection)
except CQLEngineException:
cluster = None

if cluster:
try:
cluster.register_user_type(keyspace, type_name, klass)
Expand Down
13 changes: 12 additions & 1 deletion docs/cqlengine/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Connections are experimental and aimed to ease the use of multiple sessions with
Register a new connection
=========================

To use cqlengine, you need at least a default connection. This is currently done automatically under the hood with :func:`connection.setup <.connection.setup>`. If you want to use another cluster/session, you need to register a new cqlengine connection. You register a connection with :func:`~.connection.register_connection`
To use cqlengine, you need at least a default connection. If you initialize cqlengine's connections with with :func:`connection.setup <.connection.setup>`, a connection will be created automatically. If you want to use another cluster/session, you need to register a new cqlengine connection. You register a connection with :func:`~.connection.register_connection`:

.. code-block:: python

Expand All @@ -17,6 +17,17 @@ To use cqlengine, you need at least a default connection. This is currently done
connection.setup(['127.0.0.1')
connection.register_connection('cluster2', ['127.0.0.2'])

:func:`~.connection.register_connection` can take a list of hosts, as shown above, in which case it will create a connection with a new session. It can also take a `session` argument if you've already created a session:

.. code-block:: python

from cassandra.cqlengine import connection
from cassandra.cluster import Cluster

session = Cluster(['127.0.0.1']).connect()
connection.register_connection('cluster3', session=session)


Change the default connection
=============================

Expand Down
7 changes: 7 additions & 0 deletions tests/integration/cqlengine/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cluster import NoHostAvailable
from cassandra.cqlengine import columns, CQLEngineException
from cassandra.cqlengine import connection as conn
Expand Down Expand Up @@ -217,6 +218,12 @@ def test_create_drop_table(self):
for ks in self.keyspaces:
drop_keyspace(ks, connections=self.conns)

def test_connection_creation_from_session(self):
session = Cluster(['127.0.0.1']).connect()
connection_name = 'from_session'
conn.register_connection(connection_name, session=session)
self.addCleanup(conn.unregister_connection, connection_name)


class BatchQueryConnectionTests(BaseCassEngTestCase):

Expand Down
59 changes: 59 additions & 0 deletions tests/unit/cqlengine/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import unittest2 as unittest
except ImportError:
import unittest # noqa

from cassandra.cqlengine import connection
from cassandra.query import dict_factory

from mock import Mock


class ConnectionTest(unittest.TestCase):

no_registered_connection_msg = "doesn't exist in the registry"

def setUp(self):
super(ConnectionTest, self).setUp()
self.assertFalse(
connection._connections,
'Test precondition not met: connections are registered: {cs}'.format(cs=connection._connections)
)

def test_set_session_without_existing_connection(self):
"""
Users can set the default session without having a default connection set.
"""
mock_session = Mock(
row_factory=dict_factory,
encoder=Mock(mapping={})
)
connection.set_session(mock_session)

def test_get_session_fails_without_existing_connection(self):
"""
Users can't get the default session without having a default connection set.
"""
with self.assertRaisesRegexp(connection.CQLEngineException, self.no_registered_connection_msg):
connection.get_session(connection=None)

def test_get_cluster_fails_without_existing_connection(self):
"""
Users can't get the default cluster without having a default connection set.
"""
with self.assertRaisesRegexp(connection.CQLEngineException, self.no_registered_connection_msg):
connection.get_cluster(connection=None)
41 changes: 41 additions & 0 deletions tests/unit/cqlengine/test_udt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import unittest2 as unittest
except ImportError:
import unittest # noqa

from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.usertype import UserType


class UDTTest(unittest.TestCase):

def test_initialization_without_existing_connection(self):
"""
Test that users can define models with UDTs without initializing
connections.

Written to reproduce PYTHON-649.
"""

class Value(UserType):
t = columns.Text()

class DummyUDT(Model):
__keyspace__ = 'ks'
primary_key = columns.Integer(primary_key=True)
value = columns.UserDefinedType(Value)