@@ -143,8 +143,12 @@ class BatchQuery(object):
143143
144144 _consistency = None
145145
146+ _connection = None
147+ _connection_explicit = False
148+
149+
146150 def __init__ (self , batch_type = None , timestamp = None , consistency = None , execute_on_exception = False ,
147- timeout = conn .NOT_SET ):
151+ timeout = conn .NOT_SET , connection = None ):
148152 """
149153 :param batch_type: (optional) One of batch type values available through BatchType enum
150154 :type batch_type: str or None
@@ -161,6 +165,7 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on
161165 :param timeout: (optional) Timeout for the entire batch (in seconds), if not specified fallback
162166 to default session timeout
163167 :type timeout: float or None
168+ :param str connection: Connection name to use for the batch execution
164169 """
165170 self .queries = []
166171 self .batch_type = batch_type
@@ -173,6 +178,9 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on
173178 self ._callbacks = []
174179 self ._executed = False
175180 self ._context_entered = False
181+ self ._connection = connection
182+ if connection :
183+ self ._connection_explicit = True
176184
177185 def add_query (self , query ):
178186 if not isinstance (query , BaseCQLStatement ):
@@ -244,7 +252,7 @@ def execute(self):
244252
245253 query_list .append ('APPLY BATCH;' )
246254
247- tmp = conn .execute ('\n ' .join (query_list ), parameters , self ._consistency , self ._timeout )
255+ tmp = conn .execute ('\n ' .join (query_list ), parameters , self ._consistency , self ._timeout , connection = self . _connection )
248256 check_applied (tmp )
249257
250258 self .queries = []
@@ -544,6 +552,9 @@ def batch(self, batch_obj):
544552
545553 Note: running a select query with a batch object will raise an exception
546554 """
555+ if self ._connection :
556+ raise CQLEngineException ("Cannot specify the connection on model in batch mode." )
557+
547558 if batch_obj is not None and not isinstance (batch_obj , BatchQuery ):
548559 raise CQLEngineException ('batch_obj must be a BatchQuery instance or None' )
549560 clone = copy .deepcopy (self )
@@ -974,6 +985,9 @@ def using(self, keyspace=None, connection=None):
974985 Change the context on-the-fly of the Model class (connection, keyspace)
975986 """
976987
988+ if connection and self ._batch :
989+ raise CQLEngineException ("Cannot specify a connection on model in batch mode." )
990+
977991 clone = copy .deepcopy (self )
978992 if keyspace :
979993 new_type = type (self .model .__name__ , (self .model ,), {'__keyspace__' : keyspace })
@@ -1282,10 +1296,17 @@ def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None,
12821296 self ._timeout = timeout
12831297
12841298 def _execute (self , statement ):
1299+ connection = self .instance ._get_connection () if self .instance else self .model ._get_connection ()
12851300 if self ._batch :
1301+ if self ._batch ._connection :
1302+ if not self ._batch ._connection_explicit and connection and \
1303+ connection != self ._batch ._connection :
1304+ raise CQLEngineException ('BatchQuery queries must be executed on the same connection' )
1305+ else :
1306+ # set the BatchQuery connection from the model
1307+ self ._batch ._connection = connection
12861308 return self ._batch .add_query (statement )
12871309 else :
1288- connection = self .instance ._get_connection () if self .instance else self .model ._get_connection ()
12891310 results = _execute_statement (self .model , statement , self ._consistency , self ._timeout , connection = connection )
12901311 if self ._if_not_exists or self ._if_exists or self ._conditional :
12911312 check_applied (results )
0 commit comments