|
10 | 10 | apilevel = '2.0' |
11 | 11 |
|
12 | 12 | from operator import itemgetter |
| 13 | +from functools import partial |
13 | 14 | import re |
14 | 15 | import postgresql.driver as pg_driver |
15 | 16 | import postgresql.types as pg_type |
16 | 17 | import postgresql.string as pg_str |
17 | 18 | import datetime, time |
18 | 19 |
|
19 | | -find_parameters = re.compile(r'%\(([^)]+)\)s') |
| 20 | +## |
| 21 | +# Basically, is it a mapping, or is it a sequence? |
| 22 | +# If findall()'s first index is 's', it's a sequence. |
| 23 | +# If it starts with '(', it's mapping. |
| 24 | +# The pain here is due to a need to recognize any %% escapes. |
| 25 | +parameters_re = re.compile( |
| 26 | + r'(?:%%)+|%(s|[(][^)]*[)]s)' |
| 27 | +) |
| 28 | +def percent_parameters(sql): |
| 29 | + # filter any %% matches(empty strings). |
| 30 | + return [ |
| 31 | + x for x in parameters_re.findall(sql) if x |
| 32 | + ] |
| 33 | + |
| 34 | +def convert_keywords(keys, mapping): |
| 35 | + return [ |
| 36 | + mapping[k] for k in keys |
| 37 | + ] |
20 | 38 |
|
21 | 39 | from postgresql.exceptions import \ |
22 | 40 | Error, DataError, InternalError, \ |
@@ -61,15 +79,6 @@ def dbapi_type(typid): |
61 | 79 | elif typid == pg_type.OIDOID: |
62 | 80 | return ROWID |
63 | 81 |
|
64 | | -def convert_keyword_parameters(nseq, seq): |
65 | | - """ |
66 | | - Given a sequence of keywords, `nseq`, yield each mapping object in `seq` |
67 | | - as a tuple whose objects are the values of the keys specified in `nseq` in |
68 | | - an order consistent with that in `nseq` |
69 | | - """ |
70 | | - for x in seq: |
71 | | - yield [x[y] for y in nseq] |
72 | | - |
73 | 82 | class Cursor(object): |
74 | 83 | rowcount = -1 |
75 | 84 | arraysize = 1 |
@@ -117,76 +126,85 @@ def nextset(self): |
117 | 126 | del self._portal |
118 | 127 | return len(self.__portals) or None |
119 | 128 |
|
120 | | - def execute(self, query, parameters = None): |
121 | | - if parameters: |
122 | | - parameters = list(parameters.items()) |
123 | | - pnmap = {} |
124 | | - plist = [] |
125 | | - for x in range(len(parameters)): |
126 | | - pnmap[parameters[x][0]] = '$' + str(x + 1) |
127 | | - plist.append(parameters[x][1]) |
128 | | - # Substitute %(key)s with the $x positional parameter number |
129 | | - rqparts = [] |
130 | | - for qpart in pg_str.split(query): |
131 | | - if type(qpart) is type(()): |
132 | | - # quoted section |
133 | | - rqparts.append(qpart) |
| 129 | + def _convert_query(self, string): |
| 130 | + parts = list(pg_str.split(string)) |
| 131 | + style = None |
| 132 | + count = 0 |
| 133 | + keys = [] |
| 134 | + kmap = {} |
| 135 | + transformer = tuple |
| 136 | + rparts = [] |
| 137 | + for part in parts: |
| 138 | + if type(part) is type(()): |
| 139 | + # skip quoted portions |
| 140 | + rparts.append(part) |
| 141 | + else: |
| 142 | + r = percent_parameters(part) |
| 143 | + pcount = 0 |
| 144 | + for x in r: |
| 145 | + if x == 's': |
| 146 | + pcount += 1 |
| 147 | + else: |
| 148 | + x = x[1:-2] |
| 149 | + if x not in keys: |
| 150 | + kmap[x] = '$' + str(len(keys) + 1) |
| 151 | + keys.append(x) |
| 152 | + if r: |
| 153 | + if pcount: |
| 154 | + # format |
| 155 | + params = tuple([ |
| 156 | + '$' + str(i+1) for i in range(count, count + pcount) |
| 157 | + ]) |
| 158 | + count += pcount |
| 159 | + rparts.append(part % params) |
| 160 | + else: |
| 161 | + # pyformat |
| 162 | + rparts.append(part % kmap) |
134 | 163 | else: |
135 | | - rqparts.append(qpart % pnmap) |
136 | | - q = self.database.prepare(pg_str.unsplit(rqparts)) |
137 | | - r = q(*plist) |
138 | | - else: |
139 | | - q = self.database.prepare(query) |
140 | | - r = q() |
141 | | - |
142 | | - if q._output is not None and len(q._output) > 0: |
| 164 | + # no parameters identified in string |
| 165 | + rparts.append(part) |
| 166 | + |
| 167 | + if keys: |
| 168 | + if count: |
| 169 | + raise TypeError( |
| 170 | + "keyword parameters and positional parameters used in query" |
| 171 | + ) |
| 172 | + transformer = partial(convert_keywords, keys) |
| 173 | + count = len(keys) |
| 174 | + |
| 175 | + return (pg_str.unsplit(rparts) if rparts else string, transformer, count) |
| 176 | + |
| 177 | + def execute(self, statement, parameters = ()): |
| 178 | + sql, pxf, nparams = self._convert_query(statement) |
| 179 | + if nparams != -1 and len(parameters) != nparams: |
| 180 | + raise TypeError( |
| 181 | + "statement require %d parameters, given %d" %( |
| 182 | + nparams, len(parameters) |
| 183 | + ) |
| 184 | + ) |
| 185 | + ps = self.database.prepare(sql) |
| 186 | + c = ps(*pxf(parameters)) |
| 187 | + if ps._output is not None and len(ps._output) > 0: |
143 | 188 | # name, relationId, columnNumber, typeId, typlen, typmod, format |
144 | 189 | self.description = tuple([ |
145 | 190 | (self.database.typio.decode(x[0]), dbapi_type(x[3]), |
146 | 191 | None, None, None, None, None) |
147 | | - for x in q._output |
| 192 | + for x in ps._output |
148 | 193 | ]) |
149 | | - self.__portals.insert(0, r) |
| 194 | + self.__portals.insert(0, c) |
150 | 195 | else: |
151 | 196 | self.description = None |
152 | 197 | if self.__portals: |
153 | 198 | del self._portal |
154 | 199 | return self |
155 | 200 |
|
156 | | - def _convert_query(self, string, map): |
157 | | - rqparts = [] |
158 | | - for qpart in pg_str.split(string): |
159 | | - if type(qpart) is type(()): |
160 | | - rqparts.append(qpart) |
161 | | - else: |
162 | | - rqparts.append(qpart % map) |
163 | | - return pg_str.unsplit(rqparts) |
164 | | - |
165 | | - def _statement_params(self, string): |
166 | | - map = {} |
167 | | - param_num = 1 |
168 | | - for qpart in pg_str.split(string): |
169 | | - if type(qpart) is not type(()): |
170 | | - for x in find_parameters.finditer(qpart): |
171 | | - pname = x.group(1) |
172 | | - if pname not in map: |
173 | | - map[pname] = param_num |
174 | | - param_num += 1 |
175 | | - return map |
176 | | - |
177 | | - def executemany(self, query, param_iter): |
178 | | - mapseq = list(self._statement_params(query).items()) |
179 | | - realquery = self._convert_query(query, { |
180 | | - k : '$' + str(v) for k,v in mapseq |
181 | | - }) |
182 | | - mapseq.sort(key = itemgetter(1)) |
183 | | - nseq = [x[0] for x in mapseq] |
184 | | - q = self.database.prepare(realquery) |
185 | | - q.prepare() |
186 | | - if q._input is not None: |
187 | | - q.load(convert_keyword_parameters(nseq, param_iter)) |
| 201 | + def executemany(self, statement, parameters): |
| 202 | + sql, pxf, nparams = self._convert_query(statement) |
| 203 | + ps = self.database.prepare(sql) |
| 204 | + if ps._input is not None: |
| 205 | + ps.load(map(pxf, parameters)) |
188 | 206 | else: |
189 | | - q.load(param_iter) |
| 207 | + ps.load(parameters) |
190 | 208 | return self |
191 | 209 |
|
192 | 210 | def close(self): |
|
0 commit comments