Skip to content

Commit e771d71

Browse files
committed
Allow specifying types as Python classes (PyGreSQL#38)
1 parent cc55191 commit e771d71

5 files changed

Lines changed: 197 additions & 48 deletions

File tree

docs/contents/changelog.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ Version 5.2 (to be released)
1616
instead of Exception, as required by the DB-API 2 compliance test.
1717
- Connection arguments containing single quotes caused problems
1818
(reported and fixed by Tyler Ramer and Jamie McAtamney).
19-
- The `types` parameer of `format_query` can now be passed as a string
20-
that will be split on whitespace when values are passed as a sequence.
19+
- The `types` parameter of `format_query` can now be passed as a string
20+
that will be split on whitespace when values are passed as a sequence,
21+
and the types can now also be specified using actual Python types
22+
instead of type names (#38, suggested by Justin Pryzby).
2123

2224
Version 5.1.2 (2020-04-19)
2325
--------------------------

docs/contents/pg/db_wrapper.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,24 @@ Example::
514514
"update employees set phone=%(phone)s where name=%(name)s",
515515
dict(name=name, phone=phone)).getresult()[0][0]
516516

517+
Example with specification of types::
518+
519+
db.query_formatted(
520+
"update orders set info=%s where id=%s",
521+
({'customer': 'Joe', 'product': 'beer'}, 'id': 7),
522+
types=('json', 'int'))
523+
# or
524+
db.query_formatted(
525+
"update orders set info=%s where id=%s",
526+
({'customer': 'Joe', 'product': 'beer'}, 'id': 7),
527+
types=('json int'))
528+
# or
529+
db.query_formatted(
530+
"update orders set info=%(info)s where id=%(id)s",
531+
{'info': {'customer': 'Joe', 'product': 'beer'}, 'id': 7},
532+
types={'info': 'json', 'id': 'int'})
533+
534+
517535
query_prepared -- execute a prepared statement
518536
----------------------------------------------
519537

pg.py

Lines changed: 88 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@
8787
except NameError: # Python >= 3.0
8888
long = int
8989

90+
try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
91+
unicode
92+
except NameError: # Python >= 3.0
93+
unicode = str
94+
9095
try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable
9196
basestring
9297
except NameError: # Python >= 3.0
@@ -253,10 +258,51 @@ def _oid_key(table):
253258
return 'oid(%s)' % table
254259

255260

261+
class Bytea(bytes):
262+
"""Wrapper class for marking Bytea values."""
263+
264+
265+
class Hstore(dict):
266+
"""Wrapper class for marking hstore values."""
267+
268+
_re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
269+
270+
@classmethod
271+
def _quote(cls, s):
272+
if s is None:
273+
return 'NULL'
274+
if not isinstance(s, basestring):
275+
s = str(s)
276+
if not s:
277+
return '""'
278+
s = s.replace('"', '\\"')
279+
if cls._re_quote.search(s):
280+
s = '"%s"' % s
281+
return s
282+
283+
def __str__(self):
284+
q = self._quote
285+
return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
286+
287+
288+
class Json:
289+
"""Wrapper class for marking Json values."""
290+
291+
def __init__(self, obj, encode=None):
292+
self.obj = obj
293+
self.encode = encode or jsonencode
294+
295+
def __str__(self):
296+
obj = self.obj
297+
if isinstance(obj, basestring):
298+
return obj
299+
return self.encode(obj)
300+
301+
256302
class _SimpleTypes(dict):
257303
"""Dictionary mapping pg_type names to simple type names."""
258304

259-
_types = {
305+
_type_strings = {
260306
'bool': 'bool',
261307
'bytea': 'bytea',
262308
'date': 'date interval time timetz timestamp timestamptz'
@@ -267,9 +313,20 @@ class _SimpleTypes(dict):
267313
'num': 'numeric', 'money': 'money',
268314
'text': 'bpchar char name text varchar'}
269315

316+
_type_classes = {
317+
bool: 'bool', float: 'float', int: 'int',
318+
bytes: 'text' if bytes is str else 'bytea', unicode: 'text',
319+
date: 'date', time: 'date', datetime: 'date', timedelta: 'date',
320+
Decimal: 'num', Bytea: 'bytea', Json: 'json', Hstore: 'hstore',
321+
}
322+
323+
if long is not int:
324+
_type_classes[long] = 'num'
325+
270326
# noinspection PyMissingConstructor
271327
def __init__(self):
272-
for typ, keys in self._types.items():
328+
self.update(self._type_classes)
329+
for typ, keys in self._type_strings.items():
273330
for key in keys.split():
274331
self[key] = typ
275332
self['_%s' % key] = '%s[]' % typ
@@ -312,38 +369,6 @@ def add(self, value, typ=None):
312369
return '$%d' % len(self)
313370

314371

315-
class Bytea(bytes):
316-
"""Wrapper class for marking Bytea values."""
317-
318-
319-
class Hstore(dict):
320-
"""Wrapper class for marking hstore values."""
321-
322-
_re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
323-
324-
@classmethod
325-
def _quote(cls, s):
326-
if s is None:
327-
return 'NULL'
328-
if not s:
329-
return '""'
330-
s = s.replace('"', '\\"')
331-
if cls._re_quote.search(s):
332-
s = '"%s"' % s
333-
return s
334-
335-
def __str__(self):
336-
q = self._quote
337-
return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
338-
339-
340-
class Json:
341-
"""Wrapper class for marking Json values."""
342-
343-
def __init__(self, obj):
344-
self.obj = obj
345-
346-
347372
class Literal(str):
348373
"""Wrapper class for marking literal SQL values."""
349374

@@ -427,8 +452,22 @@ def _adapt_json(self, v):
427452
return None
428453
if isinstance(v, basestring):
429454
return v
455+
if isinstance(v, Json):
456+
return str(v)
430457
return self.db.encode_json(v)
431458

459+
def _adapt_hstore(self, v):
460+
"""Adapt a hstore parameter."""
461+
if not v:
462+
return None
463+
if isinstance(v, basestring):
464+
return v
465+
if isinstance(v, Hstore):
466+
return str(v)
467+
if isinstance(v, dict):
468+
return str(Hstore(v))
469+
raise TypeError('Hstore parameter %s has wrong type' % v)
470+
432471
@classmethod
433472
def _adapt_text_array(cls, v):
434473
"""Adapt a text type array parameter."""
@@ -588,8 +627,6 @@ def guess_simple_type(cls, value):
588627
return cls._frequent_simple_types[type(value)]
589628
except KeyError:
590629
pass
591-
if isinstance(value, Bytea):
592-
return 'bytea'
593630
if isinstance(value, basestring):
594631
return 'text'
595632
if isinstance(value, bool):
@@ -602,6 +639,12 @@ def guess_simple_type(cls, value):
602639
return 'num'
603640
if isinstance(value, (date, time, datetime, timedelta)):
604641
return 'date'
642+
if isinstance(value, Bytea):
643+
return 'bytea'
644+
if isinstance(value, Json):
645+
return 'json'
646+
if isinstance(value, Hstore):
647+
return 'hstore'
605648
if isinstance(value, list):
606649
return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
607650
if isinstance(value, tuple):
@@ -638,12 +681,6 @@ def adapt_inline(self, value, nested=False):
638681
value = self.db.escape_bytea(value)
639682
if bytes is not str: # Python >= 3.0
640683
value = value.decode('ascii')
641-
elif isinstance(value, Json):
642-
# noinspection PyUnresolvedReferences
643-
if value.encode:
644-
# noinspection PyUnresolvedReferences
645-
return value.encode()
646-
value = self.db.encode_json(value)
647684
elif isinstance(value, (datetime, date, time, timedelta)):
648685
value = str(value)
649686
if isinstance(value, basestring):
@@ -666,6 +703,12 @@ def adapt_inline(self, value, nested=False):
666703
if isinstance(value, tuple):
667704
q = self.adapt_inline
668705
return '(%s)' % ','.join(str(q(v)) for v in value)
706+
if isinstance(value, Json):
707+
value = self.db.escape_string(str(value))
708+
return "'%s'::json" % value
709+
if isinstance(value, Hstore):
710+
value = self.db.escape_string(str(value))
711+
return "'%s'::hstore" % value
669712
pg_repr = getattr(value, '__pg_repr__', None)
670713
if not pg_repr:
671714
raise InterfaceError(
@@ -691,6 +734,9 @@ def format_query(self, command, values=None, types=None, inline=False):
691734
The optional types describe the values and must be passed as a list,
692735
tuple or string (that will be split on whitespace) when values are
693736
passed as a list or tuple, or as a dict if values are passed as a dict.
737+
738+
If inline is set to True, then parameters will be passed inline
739+
together with the query string.
694740
"""
695741
if not values:
696742
return command, []

pgdb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,6 +1855,8 @@ class Hstore(dict):
18551855
def _quote(cls, s):
18561856
if s is None:
18571857
return 'NULL'
1858+
if not isinstance(s, basestring):
1859+
s = str(s)
18581860
if not s:
18591861
return '""'
18601862
quote = cls._re_quote.search(s)

tests/test_classic_dbwrapper.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4343,17 +4343,64 @@ def testAdaptQueryTypedList(self):
43434343
self.assertEqual(sql, 'select $1')
43444344
self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
43454345

4346-
def testAdaptQueryTypedListWithString(self):
4346+
def testAdaptQueryTypedListWithTypesAsString(self):
43474347
format_query = self.adapter.format_query
4348-
self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',))
4348+
self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), 'int2')
43494349
self.assertRaises(
4350-
TypeError, format_query, '%s,%s', (1,), ('int2', 'int2'))
4350+
TypeError, format_query, '%s,%s', (1,), 'int2 int2')
43514351
values = (3, 7.5, 'hello', True)
4352-
types = 'int4 float4 text bool' # pass types as list
4352+
types = 'int4 float4 text bool' # pass types as string
43534353
sql, params = format_query("select %s,%s,%s,%s", values, types)
43544354
self.assertEqual(sql, 'select $1,$2,$3,$4')
43554355
self.assertEqual(params, [3, 7.5, 'hello', 't'])
43564356

4357+
def testAdaptQueryTypedListWithTypesAsClasses(self):
4358+
format_query = self.adapter.format_query
4359+
self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), (int,))
4360+
self.assertRaises(
4361+
TypeError, format_query, '%s,%s', (1,), (int, int))
4362+
values = (3, 7.5, 'hello', True)
4363+
types = (int, float, str, bool) # pass types as classes
4364+
sql, params = format_query("select %s,%s,%s,%s", values, types)
4365+
self.assertEqual(sql, 'select $1,$2,$3,$4')
4366+
self.assertEqual(params, [3, 7.5, 'hello', 't'])
4367+
4368+
def testAdaptQueryTypedListWithJson(self):
4369+
format_query = self.adapter.format_query
4370+
value = {'test': [1, "it's fine", 3]}
4371+
sql, params = format_query("select %s", (value,), 'json')
4372+
self.assertEqual(sql, 'select $1')
4373+
self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}'])
4374+
value = pg.Json({'test': [1, "it's fine", 3]})
4375+
sql, params = format_query("select %s", (value,), 'json')
4376+
self.assertEqual(sql, 'select $1')
4377+
self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}'])
4378+
value = {'test': [1, "it's fine", 3]}
4379+
sql, params = format_query("select %s", [value], [pg.Json])
4380+
self.assertEqual(sql, 'select $1')
4381+
self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}'])
4382+
4383+
def testAdaptQueryTypedWithHstore(self):
4384+
format_query = self.adapter.format_query
4385+
value = {'one': "it's fine", 'two': 2}
4386+
sql, params = format_query("select %s", (value,), 'hstore')
4387+
self.assertEqual(sql, "select $1")
4388+
if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict
4389+
params[0] = ','.join(sorted(params[0].split(',')))
4390+
self.assertEqual(params, ['one=>"it\'s fine\",two=>2'])
4391+
value = pg.Hstore({'one': "it's fine", 'two': 2})
4392+
sql, params = format_query("select %s", (value,), 'hstore')
4393+
self.assertEqual(sql, "select $1")
4394+
if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict
4395+
params[0] = ','.join(sorted(params[0].split(',')))
4396+
self.assertEqual(params, ['one=>"it\'s fine\",two=>2'])
4397+
value = pg.Hstore({'one': "it's fine", 'two': 2})
4398+
sql, params = format_query("select %s", [value], [pg.Hstore])
4399+
self.assertEqual(sql, "select $1")
4400+
if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict
4401+
params[0] = ','.join(sorted(params[0].split(',')))
4402+
self.assertEqual(params, ['one=>"it\'s fine\",two=>2'])
4403+
43574404
def testAdaptQueryTypedDict(self):
43584405
format_query = self.adapter.format_query
43594406
self.assertRaises(
@@ -4423,6 +4470,22 @@ def testAdaptQueryUntypedList(self):
44234470
self.assertEqual(sql, 'select $1')
44244471
self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
44254472

4473+
def testAdaptQueryUntypedListWithJson(self):
4474+
format_query = self.adapter.format_query
4475+
value = pg.Json({'test': [1, "it's fine", 3]})
4476+
sql, params = format_query("select %s", (value,))
4477+
self.assertEqual(sql, 'select $1')
4478+
self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}'])
4479+
4480+
def testAdaptQueryUntypedWithHstore(self):
4481+
format_query = self.adapter.format_query
4482+
value = pg.Hstore({'one': "it's fine", 'two': 2})
4483+
sql, params = format_query("select %s", (value,))
4484+
self.assertEqual(sql, "select $1")
4485+
if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict
4486+
params[0] = ','.join(sorted(params[0].split(',')))
4487+
self.assertEqual(params, ['one=>"it\'s fine\",two=>2'])
4488+
44264489
def testAdaptQueryUntypedDict(self):
44274490
format_query = self.adapter.format_query
44284491
values = dict(i=3, f=7.5, t='hello', b=True)
@@ -4478,6 +4541,24 @@ def testAdaptQueryInlineList(self):
44784541
sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
44794542
self.assertEqual(params, [])
44804543

4544+
def testAdaptQueryInlineListWithJson(self):
4545+
format_query = self.adapter.format_query
4546+
value = pg.Json({'test': [1, "it's fine", 3]})
4547+
sql, params = format_query("select %s", (value,), inline=True)
4548+
self.assertEqual(
4549+
sql, "select '{\"test\": [1, \"it''s fine\", 3]}'::json")
4550+
self.assertEqual(params, [])
4551+
4552+
def testAdaptQueryInlineListWithHstore(self):
4553+
format_query = self.adapter.format_query
4554+
value = pg.Hstore({'one': "it's fine", 'two': 2})
4555+
sql, params = format_query("select %s", (value,), inline=True)
4556+
if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict
4557+
sql = sql[:8] + ','.join(sorted(sql[8:-9].split(','))) + sql[-9:]
4558+
self.assertEqual(
4559+
sql, "select 'one=>\"it''s fine\",two=>2'::hstore")
4560+
self.assertEqual(params, [])
4561+
44814562
def testAdaptQueryInlineDict(self):
44824563
format_query = self.adapter.format_query
44834564
values = dict(i=3, f=7.5, t='hello', b=True)

0 commit comments

Comments
 (0)