Skip to content

Commit ab5ed57

Browse files
author
James William Pye
committed
Fix copy.
load wasn't detecting non-parameterized queries properly and copy wasn't even working. :P
1 parent 25b12e3 commit ab5ed57

2 files changed

Lines changed: 61 additions & 19 deletions

File tree

postgresql/driver/pq3.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,9 +1075,9 @@ def first(self, *parameters):
10751075
return None
10761076
return cm.extract_count() or cm.extract_command()
10771077

1078-
def copy(self,
1078+
def _copy_data_in(self,
10791079
iterable,
1080-
tps : "tuples per *set*" = 1000,
1080+
tps : "tuples per *set*" = None,
10811081
):
10821082
"""
10831083
Given an iterable, execute the COPY ... FROM STDIN statement and
@@ -1086,39 +1086,45 @@ def copy(self,
10861086
`tps` is the number of tuples to buffer prior to giving the data
10871087
to the socket's send.
10881088
"""
1089+
tps = tps or 500
10891090
iterable = iter(iterable)
10901091
x = pq.Transaction((
10911092
pq.element.Bind(
10921093
b'',
10931094
self._pq_statement_id,
1094-
(), () (),
1095+
(), (), (),
10951096
),
10961097
pq.element.Execute(b'', 1),
1097-
1098+
pq.element.SynchronizeMessage,
10981099
))
10991100
self.ife_descend(x)
11001101
self.database._pq_push(x)
1102+
1103+
# Get the COPY started.
11011104
while x.state is not pq.Complete:
1105+
self.database._pq_step()
1106+
if hasattr(x, 'CopyFailSequence') and x.messages is x.CopyFailSequence:
1107+
break
1108+
else:
1109+
# Oh, it's not a COPY at all.
1110+
e = pg_exc.OperationError(
1111+
"load() used on a non-COPY FROM STDIN query",
1112+
)
1113+
x.ife_descend(e)
1114+
e.raise_exception()
1115+
1116+
while x.messages:
11021117
# Process any messages setup for sending.
11031118
while x.messages is not x.CopyFailSequence:
11041119
self.database._pq_step()
1105-
if x.state is pq.Complete:
1106-
e = pg_exc.OperationError(
1107-
"load() used on a non-COPY FROM STDIN query"
1108-
)
1109-
x.ife_descend(e)
1110-
e.raise_exception()
1111-
1112-
copyseq = [
1113-
pq.element.CopyData(l) for l in islice(iterable, tps)
1114-
]
1115-
x.messages = copyseq
1116-
x.messages = self._pq_xact.CopyDoneSequence
1120+
x.messages = [
1121+
pq.element.CopyData(l) for l in islice(iterable, tps)
1122+
]
1123+
x.messages = x.CopyDoneSequence
11171124
self.database._pq_complete()
11181125

1119-
def load(self, tupleseq, tps = 40):
1120-
if self.closed:
1121-
self.prepare()
1126+
def _load_bulk_tuples(self, tupleseq, tps = None):
1127+
tps = tps or 64
11221128
last = pq.element.FlushMessage
11231129
tupleseqiter = iter(tupleseq)
11241130
try:
@@ -1164,6 +1170,19 @@ def load(self, tupleseq, tps = 40):
11641170
self.database.synchronize()
11651171
raise
11661172

1173+
def load(self, iterable, tps = None):
1174+
"""
1175+
Execute the query for each parameter set in `iterable`.
1176+
1177+
In cases of ``COPY ... FROM STDIN``, iterable must be an iterable `bytes`.
1178+
"""
1179+
if self.closed:
1180+
self.prepare()
1181+
if not self._input:
1182+
return self._copy_data_in(iterable, tps = tps)
1183+
else:
1184+
return self._load_bulk_tuples(iterable, tps = tps)
1185+
11671186
class StoredProcedure(pg_api.StoredProcedure):
11681187
ife_ancestor = None
11691188
procedure_id = None

postgresql/test/test_driver.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ def raise_exc(l):
9595
e, v, tb = rl[0]
9696
raise v
9797
self.failUnlessRaises(pg_exc.QueryCanceledError, raise_exc, rl)
98+
99+
def testCopyToSTDOUT(self):
100+
with self.db.xact:
101+
self.db.execute("CREATE TABLE foo (i int)")
102+
foo = self.db.prepare('insert into foo values ($1)')
103+
foo.load(((x,) for x in range(500)))
104+
105+
copy_foo = self.db.prepare('copy foo to stdout')
106+
foo_content = set(copy_foo)
107+
expected = set((str(i).encode('ascii') + b'\n' for i in range(500)))
108+
self.failUnlessEqual(expected, foo_content)
109+
self.db.execute("DROP TABLE foo")
110+
111+
def testCopyFromSTDIN(self):
112+
with self.db.xact:
113+
self.db.execute("CREATE TABLE foo (i int)")
114+
foo = self.db.prepare('copy foo from stdin')
115+
foo.load((str(i).encode('ascii') + b'\n' for i in range(1000)))
116+
foo_content = list((
117+
x for (x,) in self.db.prepare('select * from foo order by 1 ASC')
118+
))
119+
self.failUnlessEqual(foo_content, list(range(1000)))
120+
self.db.execute("DROP TABLE foo")
98121

99122
def testLookupProcByName(self):
100123
self.db.execute(

0 commit comments

Comments
 (0)