Skip to content

Commit b6e602e

Browse files
author
James William Pye
committed
Implement subjective paramstyles.
This allows users to make use of %s formats for sequence-based parameters. The prior functionality appears to have been compliant, but this functionality is supported by other drivers. And people, apparently, use it.
1 parent 0a70580 commit b6e602e

2 files changed

Lines changed: 99 additions & 67 deletions

File tree

postgresql/driver/dbapi20.py

Lines changed: 85 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,31 @@
1010
apilevel = '2.0'
1111

1212
from operator import itemgetter
13+
from functools import partial
1314
import re
1415
import postgresql.driver as pg_driver
1516
import postgresql.types as pg_type
1617
import postgresql.string as pg_str
1718
import datetime, time
1819

19-
find_parameters = re.compile(r'%\(([^)]+)\)s')
20+
##
21+
# Basically, is it a mapping, or is it a sequence?
22+
# If findall()'s first index is 's', it's a sequence.
23+
# If it starts with '(', it's mapping.
24+
# The pain here is due to a need to recognize any %% escapes.
25+
parameters_re = re.compile(
26+
r'(?:%%)+|%(s|[(][^)]*[)]s)'
27+
)
28+
def percent_parameters(sql):
29+
# filter any %% matches(empty strings).
30+
return [
31+
x for x in parameters_re.findall(sql) if x
32+
]
33+
34+
def convert_keywords(keys, mapping):
35+
return [
36+
mapping[k] for k in keys
37+
]
2038

2139
from postgresql.exceptions import \
2240
Error, DataError, InternalError, \
@@ -61,15 +79,6 @@ def dbapi_type(typid):
6179
elif typid == pg_type.OIDOID:
6280
return ROWID
6381

64-
def convert_keyword_parameters(nseq, seq):
65-
"""
66-
Given a sequence of keywords, `nseq`, yield each mapping object in `seq`
67-
as a tuple whose objects are the values of the keys specified in `nseq` in
68-
an order consistent with that in `nseq`
69-
"""
70-
for x in seq:
71-
yield [x[y] for y in nseq]
72-
7382
class Cursor(object):
7483
rowcount = -1
7584
arraysize = 1
@@ -117,76 +126,85 @@ def nextset(self):
117126
del self._portal
118127
return len(self.__portals) or None
119128

120-
def execute(self, query, parameters = None):
121-
if parameters:
122-
parameters = list(parameters.items())
123-
pnmap = {}
124-
plist = []
125-
for x in range(len(parameters)):
126-
pnmap[parameters[x][0]] = '$' + str(x + 1)
127-
plist.append(parameters[x][1])
128-
# Substitute %(key)s with the $x positional parameter number
129-
rqparts = []
130-
for qpart in pg_str.split(query):
131-
if type(qpart) is type(()):
132-
# quoted section
133-
rqparts.append(qpart)
129+
def _convert_query(self, string):
130+
parts = list(pg_str.split(string))
131+
style = None
132+
count = 0
133+
keys = []
134+
kmap = {}
135+
transformer = tuple
136+
rparts = []
137+
for part in parts:
138+
if type(part) is type(()):
139+
# skip quoted portions
140+
rparts.append(part)
141+
else:
142+
r = percent_parameters(part)
143+
pcount = 0
144+
for x in r:
145+
if x == 's':
146+
pcount += 1
147+
else:
148+
x = x[1:-2]
149+
if x not in keys:
150+
kmap[x] = '$' + str(len(keys) + 1)
151+
keys.append(x)
152+
if r:
153+
if pcount:
154+
# format
155+
params = tuple([
156+
'$' + str(i+1) for i in range(count, count + pcount)
157+
])
158+
count += pcount
159+
rparts.append(part % params)
160+
else:
161+
# pyformat
162+
rparts.append(part % kmap)
134163
else:
135-
rqparts.append(qpart % pnmap)
136-
q = self.database.prepare(pg_str.unsplit(rqparts))
137-
r = q(*plist)
138-
else:
139-
q = self.database.prepare(query)
140-
r = q()
141-
142-
if q._output is not None and len(q._output) > 0:
164+
# no parameters identified in string
165+
rparts.append(part)
166+
167+
if keys:
168+
if count:
169+
raise TypeError(
170+
"keyword parameters and positional parameters used in query"
171+
)
172+
transformer = partial(convert_keywords, keys)
173+
count = len(keys)
174+
175+
return (pg_str.unsplit(rparts) if rparts else string, transformer, count)
176+
177+
def execute(self, statement, parameters = ()):
178+
sql, pxf, nparams = self._convert_query(statement)
179+
if nparams != -1 and len(parameters) != nparams:
180+
raise TypeError(
181+
"statement require %d parameters, given %d" %(
182+
nparams, len(parameters)
183+
)
184+
)
185+
ps = self.database.prepare(sql)
186+
c = ps(*pxf(parameters))
187+
if ps._output is not None and len(ps._output) > 0:
143188
# name, relationId, columnNumber, typeId, typlen, typmod, format
144189
self.description = tuple([
145190
(self.database.typio.decode(x[0]), dbapi_type(x[3]),
146191
None, None, None, None, None)
147-
for x in q._output
192+
for x in ps._output
148193
])
149-
self.__portals.insert(0, r)
194+
self.__portals.insert(0, c)
150195
else:
151196
self.description = None
152197
if self.__portals:
153198
del self._portal
154199
return self
155200

156-
def _convert_query(self, string, map):
157-
rqparts = []
158-
for qpart in pg_str.split(string):
159-
if type(qpart) is type(()):
160-
rqparts.append(qpart)
161-
else:
162-
rqparts.append(qpart % map)
163-
return pg_str.unsplit(rqparts)
164-
165-
def _statement_params(self, string):
166-
map = {}
167-
param_num = 1
168-
for qpart in pg_str.split(string):
169-
if type(qpart) is not type(()):
170-
for x in find_parameters.finditer(qpart):
171-
pname = x.group(1)
172-
if pname not in map:
173-
map[pname] = param_num
174-
param_num += 1
175-
return map
176-
177-
def executemany(self, query, param_iter):
178-
mapseq = list(self._statement_params(query).items())
179-
realquery = self._convert_query(query, {
180-
k : '$' + str(v) for k,v in mapseq
181-
})
182-
mapseq.sort(key = itemgetter(1))
183-
nseq = [x[0] for x in mapseq]
184-
q = self.database.prepare(realquery)
185-
q.prepare()
186-
if q._input is not None:
187-
q.load(convert_keyword_parameters(nseq, param_iter))
201+
def executemany(self, statement, parameters):
202+
sql, pxf, nparams = self._convert_query(statement)
203+
ps = self.database.prepare(sql)
204+
if ps._input is not None:
205+
ps.load(map(pxf, parameters))
188206
else:
189-
q.load(param_iter)
207+
ps.load(parameters)
190208
return self
191209

192210
def close(self):

postgresql/test/test_dbapi20.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,13 @@ def test_execute(self):
326326
finally:
327327
con.close()
328328

329+
def test_format_execute(self):
330+
self.driver.paramstyle = 'format'
331+
try:
332+
self.test_execute()
333+
finally:
334+
self.driver.paramstyle = 'pyformat'
335+
329336
def _paraminsert(self,cur):
330337
self.executeDDL1(cur)
331338
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
@@ -428,6 +435,13 @@ def test_executemany(self):
428435
finally:
429436
con.close()
430437

438+
def test_format_executemany(self):
439+
self.driver.paramstyle = 'format'
440+
try:
441+
self.test_executemany()
442+
finally:
443+
self.driver.paramstyle = 'pyformat'
444+
431445
def test_fetchone(self):
432446
con = self._connect()
433447
try:

0 commit comments

Comments
 (0)