@@ -285,33 +285,54 @@ class ContextQuery(object):
285285 with ContextQuery(Automobile, keyspace='test4') as A:
286286 print len(A.objects.all()) # 0 result
287287
288+ # Multiple models
289+ with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2):
290+ print len(A.objects.all())
291+ print len(A2.objects.all())
292+
288293 """
289294
290- def __init__ (self , model , keyspace = None , connection = None ):
295+ def __init__ (self , * args , ** kwargs ):
291296 """
292- :param model: A model. This should be a class type, not an instance.
293- :param keyspace : (optional) A keyspace name
297+ :param *args: One or more models. A model should be a class type, not an instance.
298+ :param **kwargs : (optional) Context parameters: can be keyspace or connection
294299 """
295300 from cassandra .cqlengine import models
296301
297- if not issubclass (model , models .Model ):
298- raise CQLEngineException ("Models must be derived from base Model." )
302+ self .models = []
299303
300- self .model = model
304+ if len (args ) < 1 :
305+ raise CQLEngineException ("No model provided." )
301306
302- if keyspace :
303- from cassandra .cqlengine .models import _copy_model_class
304- ks = keyspace
305- self .model = _copy_model_class (model , {'__keyspace__' : ks })
307+ keyspace = kwargs .pop ('keyspace' , None )
308+ connection = kwargs .pop ('connection' , None )
306309
307- if connection :
308- self .model ._connection = connection
310+ if kwargs :
311+ raise CQLEngineException ("Unknown keyword argument(s): {0}" .format (
312+ ',' .join (kwargs .keys ())))
313+
314+ for model in args :
315+ if not issubclass (model , models .Model ):
316+ raise CQLEngineException ("Models must be derived from base Model." )
317+
318+ m = copy .deepcopy (model ) if not keyspace else None
319+
320+ if keyspace :
321+ from cassandra .cqlengine .models import _copy_model_class
322+ ks = keyspace
323+ m = _copy_model_class (model , {'__keyspace__' : ks })
324+
325+ if connection :
326+ m ._connection = connection
327+
328+ self .models .append (m )
309329
310330 def __enter__ (self ):
311- return self .model
331+ if len (self .models ) > 1 :
332+ return tuple (self .models )
333+ return self .models [0 ]
312334
313335 def __exit__ (self , exc_type , exc_val , exc_tb ):
314- self .model ._connection = None
315336 return
316337
317338
0 commit comments