Skip to content

Commit e5d27eb

Browse files
committed
Add multiple models support to ContextQuery
1 parent da10c63 commit e5d27eb

1 file changed

Lines changed: 35 additions & 14 deletions

File tree

cassandra/cqlengine/query.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,33 +285,54 @@ class ContextQuery(object):
285285
with ContextQuery(Automobile, keyspace='test4') as A:
286286
print len(A.objects.all()) # 0 result
287287
288+
# Multiple models
289+
with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2):
290+
print len(A.objects.all())
291+
print len(A2.objects.all())
292+
288293
"""
289294

290-
def __init__(self, model, keyspace=None, connection=None):
295+
def __init__(self, *args, **kwargs):
291296
"""
292-
:param model: A model. This should be a class type, not an instance.
293-
:param keyspace: (optional) A keyspace name
297+
:param *args: One or more models. A model should be a class type, not an instance.
298+
:param **kwargs: (optional) Context parameters: can be keyspace or connection
294299
"""
295300
from cassandra.cqlengine import models
296301

297-
if not issubclass(model, models.Model):
298-
raise CQLEngineException("Models must be derived from base Model.")
302+
self.models = []
299303

300-
self.model = model
304+
if len(args) < 1:
305+
raise CQLEngineException("No model provided.")
301306

302-
if keyspace:
303-
from cassandra.cqlengine.models import _copy_model_class
304-
ks = keyspace
305-
self.model = _copy_model_class(model, {'__keyspace__': ks})
307+
keyspace = kwargs.pop('keyspace', None)
308+
connection = kwargs.pop('connection', None)
306309

307-
if connection:
308-
self.model._connection = connection
310+
if kwargs:
311+
raise CQLEngineException("Unknown keyword argument(s): {0}".format(
312+
','.join(kwargs.keys())))
313+
314+
for model in args:
315+
if not issubclass(model, models.Model):
316+
raise CQLEngineException("Models must be derived from base Model.")
317+
318+
m = copy.deepcopy(model) if not keyspace else None
319+
320+
if keyspace:
321+
from cassandra.cqlengine.models import _copy_model_class
322+
ks = keyspace
323+
m = _copy_model_class(model, {'__keyspace__': ks})
324+
325+
if connection:
326+
m._connection = connection
327+
328+
self.models.append(m)
309329

310330
def __enter__(self):
311-
return self.model
331+
if len(self.models) > 1:
332+
return tuple(self.models)
333+
return self.models[0]
312334

313335
def __exit__(self, exc_type, exc_val, exc_tb):
314-
self.model._connection = None
315336
return
316337

317338

0 commit comments

Comments
 (0)