diff --git a/pg/adapt.py b/pg/adapt.py index 6f65be0d..68a690d7 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -90,7 +90,7 @@ class _SimpleTypes(dict): _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ 'bool': [bool], - 'bytea': [Bytea], + 'bytea': [Bytea, bytes], 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', 'abstime', 'reltime', # these are very old 'datetime', 'timedelta', # these do not really exist @@ -99,7 +99,7 @@ class _SimpleTypes(dict): 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], 'num': ['numeric', Decimal], 'money': [], - 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] + 'text': ['bpchar', 'char', 'name', 'varchar', str] }) # noinspection PyMissingConstructor diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 7dc24053..3fe59209 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1714,7 +1714,7 @@ def setUpClass(cls): "i2 smallint, i4 integer, i8 bigint," "b boolean, dt date, ti time, ts timestamp, td interval," "d numeric, f4 real, f8 double precision, m money," - "c char(1), v4 varchar(4), c4 char(4), t text)") + "c char(1), v4 varchar(4), c4 char(4), t text, by bytea)") # Check whether the test database uses SQL_ASCII - this means # that it does not consider encoding when calculating lengths. c.query("set client_encoding=utf8") @@ -1748,16 +1748,17 @@ def tearDown(self): data: Sequence[tuple] = [ (-1, -1, -1, True, '1492-10-12', '08:30:00', '1492-10-12 08:30:00', '-3 days', - -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), + -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz', b'aaa'), (0, 0, 0, False, '1607-04-14', '09:00:00', '1607-04-14 09:00:00', '7 days', - 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), + 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890', b'bbb'), (1, 1, 1, True, '1801-03-04', '03:45:00', '1801-03-04 03:45:00', '3 mons', - 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), + 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g', b'ccc'), (2, 2, 2, False, '1903-12-17', '11:22:00', '1903-12-17 11:22:00', '1 year', - 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] + 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!', + b'ddd')] @classmethod def db_len(cls, s, encoding): @@ -1819,6 +1820,8 @@ def get_back(self, encoding='utf-8'): row[14] = row[14].rstrip() if row[15] is not None: # text self.assertIsInstance(row[15], str) + if row[16] is not None: # bytes + self.assertIsInstance(row[16], bytes) row = tuple(row) data.append(row) return data @@ -1901,7 +1904,7 @@ def test_inserttable_multiple_calls(self): self.assertEqual(r, num_rows) def test_inserttable_null_values(self): - data = [(None,) * 16] * 100 + data = [(None,) * 17] * 100 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -1919,10 +1922,16 @@ def test_inserttable_datetime_adapt(self): self.assertEqual(back, [( '1999-01-02', '11:12:13', '1999-01-02 11:12:13', '123 days')]) + def test_inserttable_bytea_adapt(self): + data = [(b'123',)] + self.c.inserttable('test', data, ['by']) + back = [row[16:17] for row in self.get_back()] + self.assertEqual(back, data) + def test_inserttable_only_one_column(self): data: list[tuple] = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) - data = [tuple([42 if i == 1 else None for i in range(16)])] * 50 + data = [tuple([42 if i == 1 else None for i in range(17)])] * 50 self.assertEqual(self.get_back(), data) def test_inserttable_only_two_columns(self): @@ -1930,7 +1939,7 @@ def test_inserttable_only_two_columns(self): self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker data = [(None,) * 3 + (bool(i % 2),) + (None,) * 5 + (i * .5,) - + (None,) * 6 for i in range(20)] + + (None,) * 7 for i in range(20)] self.assertEqual(self.get_back(), data) def test_inserttable_with_dotted_table_name(self): @@ -2003,7 +2012,7 @@ def test_inserttable_max_values(self): data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, '2999-12-31', '11:59:59', '2999-12-31 23:59:59', '9999 years', 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, - "1", "1234", "1234", "1234" * 100)] + "1", "1234", "1234", "1234" * 100, b'1'*100)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -2018,7 +2027,7 @@ def test_inserttable_byte_values(self): 0, 0, 0, False, '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', - c, 'bäd', 'bäd', "käse сыр pont-l'évêque") + c, 'bäd', 'bäd', "käse сыр pont-l'évêque", b'f') row_bytes = tuple( s.encode() if isinstance(s, str) else s for s in row_unicode) @@ -2038,7 +2047,7 @@ def test_inserttable_unicode_utf8(self): 0, 0, 0, False, '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', - c, 'bäd', 'bäd', "käse сыр pont-l'évêque") + c, 'bäd', 'bäd', "käse сыр pont-l'évêque", b'g') data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -2055,7 +2064,7 @@ def test_inserttable_unicode_latin1(self): 0, 0, 0, False, '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', - c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €", b'') data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) @@ -2079,7 +2088,7 @@ def test_inserttable_unicode_latin9(self): 0, 0, 0, False, '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', - c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €", b'') data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back('latin9'), data) @@ -2102,11 +2111,12 @@ def test_inserttable_from_query(self): "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," "null as dt, null as ti, null as ts, null as td, null as d," "4.5::float as float4, 8.5::float8 as f8," - "null as m, 'c' as c, 'v4' as v4, null as c4, 'text' as text") + "null as m, 'c' as c, 'v4' as v4, null as c4, 'text' as text," + "'bytes'::bytea AS by") self.c.inserttable('test', data) self.assertEqual(self.get_back(), [ (2, 4, 8, True, None, None, None, None, None, 4.5, 8.5, - None, 'c', 'v4', None, 'text')]) + None, 'c', 'v4', None, 'text', b'bytes')]) def test_inserttable_special_chars(self): class S: diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 61a06ed9..44b8ae04 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4308,8 +4308,9 @@ def tearDown(self): def test_guess_simple_type(self): f = self.adapter.guess_simple_type self.assertEqual(f(pg.Bytea(b'test')), 'bytea') + self.assertEqual(f(b'test'), 'bytea') self.assertEqual(f('string'), 'text') - self.assertEqual(f(b'string'), 'text') + self.assertEqual(f(b'string'), 'bytea') self.assertEqual(f(True), 'bool') self.assertEqual(f(3), 'int') self.assertEqual(f(2.75), 'float')