@@ -1075,9 +1075,9 @@ def first(self, *parameters):
10751075 return None
10761076 return cm .extract_count () or cm .extract_command ()
10771077
1078- def copy (self ,
1078+ def _copy_data_in (self ,
10791079 iterable ,
1080- tps : "tuples per *set*" = 1000 ,
1080+ tps : "tuples per *set*" = None ,
10811081 ):
10821082 """
10831083 Given an iterable, execute the COPY ... FROM STDIN statement and
@@ -1086,39 +1086,45 @@ def copy(self,
10861086 `tps` is the number of tuples to buffer prior to giving the data
10871087 to the socket's send.
10881088 """
1089+ tps = tps or 500
10891090 iterable = iter (iterable )
10901091 x = pq .Transaction ((
10911092 pq .element .Bind (
10921093 b'' ,
10931094 self ._pq_statement_id ,
1094- (), () (),
1095+ (), (), (),
10951096 ),
10961097 pq .element .Execute (b'' , 1 ),
1097-
1098+ pq . element . SynchronizeMessage ,
10981099 ))
10991100 self .ife_descend (x )
11001101 self .database ._pq_push (x )
1102+
1103+ # Get the COPY started.
11011104 while x .state is not pq .Complete :
1105+ self .database ._pq_step ()
1106+ if hasattr (x , 'CopyFailSequence' ) and x .messages is x .CopyFailSequence :
1107+ break
1108+ else :
1109+ # Oh, it's not a COPY at all.
1110+ e = pg_exc .OperationError (
1111+ "load() used on a non-COPY FROM STDIN query" ,
1112+ )
1113+ x .ife_descend (e )
1114+ e .raise_exception ()
1115+
1116+ while x .messages :
11021117 # Process any messages setup for sending.
11031118 while x .messages is not x .CopyFailSequence :
11041119 self .database ._pq_step ()
1105- if x .state is pq .Complete :
1106- e = pg_exc .OperationError (
1107- "load() used on a non-COPY FROM STDIN query"
1108- )
1109- x .ife_descend (e )
1110- e .raise_exception ()
1111-
1112- copyseq = [
1113- pq .element .CopyData (l ) for l in islice (iterable , tps )
1114- ]
1115- x .messages = copyseq
1116- x .messages = self ._pq_xact .CopyDoneSequence
1120+ x .messages = [
1121+ pq .element .CopyData (l ) for l in islice (iterable , tps )
1122+ ]
1123+ x .messages = x .CopyDoneSequence
11171124 self .database ._pq_complete ()
11181125
1119- def load (self , tupleseq , tps = 40 ):
1120- if self .closed :
1121- self .prepare ()
1126+ def _load_bulk_tuples (self , tupleseq , tps = None ):
1127+ tps = tps or 64
11221128 last = pq .element .FlushMessage
11231129 tupleseqiter = iter (tupleseq )
11241130 try :
@@ -1164,6 +1170,19 @@ def load(self, tupleseq, tps = 40):
11641170 self .database .synchronize ()
11651171 raise
11661172
1173+ def load (self , iterable , tps = None ):
1174+ """
1175+ Execute the query for each parameter set in `iterable`.
1176+
1177+ In cases of ``COPY ... FROM STDIN``, iterable must be an iterable `bytes`.
1178+ """
1179+ if self .closed :
1180+ self .prepare ()
1181+ if not self ._input :
1182+ return self ._copy_data_in (iterable , tps = tps )
1183+ else :
1184+ return self ._load_bulk_tuples (iterable , tps = tps )
1185+
11671186class StoredProcedure (pg_api .StoredProcedure ):
11681187 ife_ancestor = None
11691188 procedure_id = None
0 commit comments