-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathdbtest.py
More file actions
363 lines (306 loc) · 11 KB
/
dbtest.py
File metadata and controls
363 lines (306 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"""
The framework for making database tests.
"""
from __future__ import print_function
import logging
import os
import sys
from pytest import raises, skip
import sqlobject
from sqlobject.col import use_microseconds
import sqlobject.conftest as conftest
# On Windows the working directory is reported with forward slashes
# instead of backslashes; everywhere else os.getcwd is used unchanged.
if sys.platform.startswith("win"):
    def getcwd():
        """Return the current working directory with '/' separators."""
        return os.getcwd().replace('\\', '/')
else:
    getcwd = os.getcwd
"""
supportsMatrix defines which database backends support which features.
Each feature has a name.  If you see a key like '+featureName' then
only the databases listed support the feature.  Conversely,
'-featureName' means all databases *except* the ones listed support
the feature.  The databases are given by their SQLObject string name,
separated by spaces.

The function supports(featureName) returns True or False based on this,
and you can use it like::

    def test_featureX():
        if not supports('featureX'):
            pytest.skip("Doesn't support featureX")
"""
# Feature table consulted by supports().
#   "+name": space-separated backends that DO support the feature.
#   "-name": space-separated backends that do NOT support the feature.
supportsMatrix = {
    "-blobData": "mssql",
    "-decimalColumn": "mssql",
    "-dropTableCascade": "sybase mssql mysql",
    "-emptyTable": "mssql",
    "+exceptions": "mysql postgres sqlite",
    "-expressionIndex": "mysql sqlite firebird mssql",
    "-limitSelect": "mssql",
    "+memorydb": "sqlite",
    "+rlike": "mysql postgres sqlite",
    "+schema": "postgres",
    # A single space: no backend is excluded, yet the value stays truthy.
    "-transactions": " ",
}
def setupClass(soClasses, force=False):
    """
    Bind the given SQLObject class(es) to the test connection and make
    sure each one has a matching, empty table.

    `soClasses` may be a single class or a list/tuple of classes.  A
    single class is wrapped in a list.  Tables are set up in the order
    given and torn down in the opposite order, so when class A depends
    on class B call ``setupClass([B, A])``: B is not dropped or cleared
    until A has been.

    The work is delegated to installOrClear(), which reuses an existing
    table when its definition is unchanged.  Pass a true `force` to have
    existing tables dropped and recreated regardless.

    Returns the list (or the tuple/list that was passed in).
    """
    if isinstance(soClasses, (list, tuple)):
        classes = soClasses
    else:
        classes = [soClasses]
    sharedConnection = getConnection()
    # Every class gets the same connection object.  (Registering the
    # connection through a sqlobject.dbconnection.ConnectionHub would be
    # an alternative way of doing this.)
    for tableClass in classes:
        tableClass._connection = sharedConnection
    installOrClear(classes, force=force)
    return classes
def speedupSQLiteConnection(connection):
    """
    Run a fixed list of PRAGMA statements on `connection`.

    The statements are sent one at a time, in order, through
    ``connection.query``.
    """
    pragmas = (
        "PRAGMA synchronous=OFF",
        "PRAGMA count_changes=OFF",
        "PRAGMA journal_mode=MEMORY",
        "PRAGMA temp_store=MEMORY",
    )
    for statement in pragmas:
        connection.query(statement)
# Tracker database: a SQLite file named 'dbs_data.tmp' in the current
# working directory.  The connection is created and tuned with
# speedupSQLiteConnection() as a side effect of importing this module.
installedDBFilename = os.path.join(getcwd(), 'dbs_data.tmp')
installedDBTracker = sqlobject.connectionForURI(
    'sqlite:///' + installedDBFilename)
speedupSQLiteConnection(installedDBTracker)
def getConnection(**kw):
    """
    Open a new connection to the database selected for the test run.

    The URI comes from getConnectionURI(); any keyword arguments are
    passed through to ``sqlobject.connectionForURI``.  The connection is
    then adjusted:

    * ``debug`` / ``debugOutput`` are switched on when the matching
      ``conftest.option`` flags (show_sql / show_sql_output) are set;
    * for postgres with the pg8000 driver an existing ``_pool`` is
      replaced by None;
    * a SQLite connection that is not in-memory gets
      speedupSQLiteConnection() applied.

    Returns the connection.
    """
    conn = sqlobject.connectionForURI(getConnectionURI(), **kw)

    if conftest.option.show_sql:
        conn.debug = True
    if conftest.option.show_sql_output:
        conn.debugOutput = True

    if conn.dbName == 'postgres' and conn.driver == 'pg8000':
        if conn._pool is not None:
            conn._pool = None

    if conn.dbName == 'sqlite':
        if not conn._memory:
            speedupSQLiteConnection(conn)

    return conn
def getConnectionURI():
    """
    Return the connection URI for the test run.

    ``conftest.option.Database`` is returned as is, unless it is a key
    of ``conftest.connectionShortcuts``, in which case the value stored
    under that key is returned instead.
    """
    requested = conftest.option.Database
    shortcuts = conftest.connectionShortcuts
    return shortcuts[requested] if requested in shortcuts else requested
# Open the module-wide connection at import time.  A failure must not
# stop this module from being imported: it is only reported on stderr,
# and not even that when Sphinx has been loaded (documentation build).
try:
    connection = getConnection()
except Exception as error:
    if 'sphinx' not in sys.modules:
        print("Could not open database: %s" % error, file=sys.stderr)
else:
    # Microsecond support is switched off for the firebird backend.
    if connection.dbName == 'firebird':
        use_microseconds(False)
class InstalledTestDatabase(sqlobject.SQLObject):
    """
    Bookkeeping table for the tables created in the 'real' test database.

    This table always lives in the SQLite tracker connection, whatever
    --Database selects.  Each row records the SQL a test table was
    created with, so that tests sharing a table do not recreate it over
    and over.
    """

    _connection = installedDBTracker
    table_name = sqlobject.StringCol(notNull=True)
    createSQL = sqlobject.StringCol(notNull=True)
    connectionURI = sqlobject.StringCol(notNull=True)

    @classmethod
    def installOrClear(cls, soClasses, force=False):
        """
        Leave every class in `soClasses` with an existing, empty table.

        Existing tables are cleared when their recorded creation SQL
        still matches; as soon as one does not match (or `force` is
        true) all existing tables are dropped instead.  Missing tables
        are then created in the order given.
        """
        cls.setup()
        teardownOrder = list(soClasses)[::-1]

        # If anything needs to be dropped, everything must be dropped;
        # forcing means we always drop.
        mustDrop = bool(force)
        for soClass in teardownOrder:
            tableName = soClass.sqlmeta.table
            if not soClass._connection.tableExists(tableName):
                continue
            records = list(cls.selectBy(
                table_name=tableName,
                connectionURI=soClass._connection.uri()))
            record = records[0] if records else None
            recordedSQL = record.createSQL if records else None
            currentSQL, _constraints = soClass.createTableSQL()
            if recordedSQL == currentSQL:
                continue
            # Definition changed (or was never recorded): forget the
            # stale record and stop looking, everything gets dropped.
            if recordedSQL is not None:
                record.destroySelf()
            mustDrop = True
            break

        resetTable = cls.drop if mustDrop else cls.clear
        for soClass in teardownOrder:
            if soClass._connection.tableExists(soClass.sqlmeta.table):
                resetTable(soClass)

        for soClass in soClasses:
            if not soClass._connection.tableExists(soClass.sqlmeta.table):
                cls.install(soClass)

    @classmethod
    def install(cls, soClass):
        """
        Create the table for `soClass` in its database and record it.

        A class attribute named ``<dbName>Create`` (if set and non-empty)
        is run as the creation SQL; otherwise the table is created from
        the class definition and its extra SQL statements are run after
        the tracking row has been written.
        """
        createSQL = getattr(
            soClass, soClass._connection.dbName + 'Create', None)
        deferred = []
        if createSQL:
            soClass._connection.query(createSQL)
        else:
            createSQL, extraSQL = soClass.createTableSQL()
            soClass.createTable(applyConstraints=False)
            deferred.extend(extraSQL)
        cls(table_name=soClass.sqlmeta.table,
            createSQL=createSQL,
            connectionURI=soClass._connection.uri())
        for statement in deferred:
            soClass._connection.query(statement)

    @classmethod
    def drop(cls, soClass):
        """
        Drop the table of `soClass` from its database.

        A class attribute named ``<dbName>Drop`` (if set and non-empty)
        is run as the drop SQL; otherwise ``dropTable()`` is used.
        """
        dropSQL = getattr(
            soClass, soClass._connection.dbName + 'Drop', None)
        if not dropSQL:
            soClass.dropTable()
            return
        soClass._connection.query(dropSQL)

    @classmethod
    def clear(cls, soClass):
        """
        Remove every row from the table of `soClass`.
        """
        soClass.clearTable()

    @classmethod
    def setup(cls):
        """
        Create *this* tracking table if it does not exist yet.
        """
        if cls._connection.tableExists(cls.sqlmeta.table):
            return
        cls.createTable()


# Module-level shortcut for InstalledTestDatabase.installOrClear.
installOrClear = InstalledTestDatabase.installOrClear
class Dummy(object):
    """
    A very poor man's mock object: every keyword argument given to the
    constructor becomes an attribute of the instance.
    """

    def __init__(self, **kw):
        for attribute in kw:
            setattr(self, attribute, kw[attribute])
def inserts(cls, data, schema=None):
    """
    Create one `cls` instance per entry of `data` and return them as a
    list.

    Without a schema every entry must be a dict of keyword arguments::

        inserts(Person, [{'fname': 'blah', 'lname': 'doe'}, ...])

    With a schema every entry is a sequence of values that is paired
    up, position by position, with the column names::

        inserts(Person, [('blah', 'doe')], schema=
                ['fname', 'lname'])

    A `schema` given as a single string is split on whitespace to get
    the list of column names.
    """
    if schema:
        columns = schema.split() if isinstance(schema, str) else schema
        data = [dict(zip(columns, row)) for row in data]
    return [cls(**fields) for fields in data]
def supports(feature):
    """
    Tell whether the database under test supports `feature`.

    The answer is looked up in supportsMatrix for the ``dbName`` of the
    module-level `connection`:

    * a '+feature' entry lists the only databases that support the
      feature; when present it decides the answer on its own;
    * otherwise a '-feature' entry lists the only databases that do
      *not* support the feature.

    An entry with no database names is honoured as an empty list ('+'
    means no database supports the feature, '-' means every database
    does), so the result is always a real boolean.

    Raises AssertionError when the feature is not listed at all.
    """
    dbName = connection.dbName
    support = supportsMatrix.get('+' + feature)
    notSupport = supportsMatrix.get('-' + feature)
    # Compare against None rather than relying on truthiness: an empty
    # string is a valid (empty) list of databases.
    if support is not None:
        return dbName in support.split()
    if notSupport is not None:
        return dbName not in notSupport.split()
    # Raised explicitly rather than through `assert` so that the check
    # is still performed under `python -O`.
    raise AssertionError(
        "The supportMatrix does not list this feature: %r"
        % feature)
# Underscore-prefixed alias of inserts(), provided to avoid name clashes
# (presumably with code that uses `inserts` as a name of its own).
_inserts = inserts
def setSQLiteConnectionFactory(TableClass, factory):
    """
    Rebind `TableClass` to a new SQLiteConnection that uses `factory`.

    The new connection copies its settings from the connection the
    class currently has.  It is then tuned with
    speedupSQLiteConnection() and the table of `TableClass` is
    installed or cleared through it.
    """
    from sqlobject.sqlite.sqliteconnection import SQLiteConnection

    current = TableClass._connection
    settings = dict(
        filename=current.filename,
        name=current.name,
        debug=current.debug,
        debugOutput=current.debugOutput,
        cache=current.cache,
        style=current.style,
        autoCommit=current.autoCommit,
        debugThreading=current.debugThreading,
        registry=current.registry,
    )
    TableClass._connection = SQLiteConnection(factory=factory, **settings)
    speedupSQLiteConnection(TableClass._connection)
    installOrClear([TableClass])
def deprecated_module():
    """
    Set both ``sqlobject.main.warnings_level`` and
    ``sqlobject.main.exception_level`` to None.
    """
    main = sqlobject.main
    main.warnings_level = None
    main.exception_level = None
def setup_module(mod):
    """
    Per-module setup hook: configure SQLObject's deprecation levels.

    Modules whose name ends in '_old' test backward compatible methods,
    so for them both ``sqlobject.main.warnings_level`` and
    ``sqlobject.main.exception_level`` are set to None.  For every other
    module ``warnings_level`` is None and ``exception_level`` is 0.

    `mod` is the module object; only its ``__name__`` is used.
    """
    mod_name = str(mod.__name__)
    # Drop a trailing file extension before looking for the '_old'
    # suffix.  The check used to look for '/py', which is not a file
    # extension; '.py' is what the three stripped characters correspond
    # to.  '/py' is still accepted so no previously handled name changes.
    if mod_name.endswith(('.py', '/py')):
        mod_name = mod_name[:-3]
    # Warnings are off in both cases; only the exception level differs.
    sqlobject.main.warnings_level = None
    if mod_name.endswith('_old'):
        sqlobject.main.exception_level = None
    else:
        sqlobject.main.exception_level = 0
def teardown_module(mod=None):
    """
    Per-module teardown hook: set ``sqlobject.main.warnings_level`` to
    None and ``sqlobject.main.exception_level`` to 0.

    `mod` is accepted but not used.
    """
    main = sqlobject.main
    main.warnings_level = None
    main.exception_level = 0
def setupLogging():
    """
    Attach a stderr handler to the root logger.

    The handler's level is NOTSET and its records are formatted as
    ``[time] logger-name LEVEL: message``.  Each call adds another
    handler.
    """
    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s'))
    handler.setLevel(logging.NOTSET)
    logging.getLogger().addHandler(handler)
def setupCyclicClasses(*classes):
    """
    (Re)create the tables of classes that reference each other.

    The test is skipped when the backend does not support
    'dropTableCascade'.  Otherwise all classes are bound to one fresh
    connection and their tables dropped with cascade; then every table
    is created without constraints, and the collected constraint
    statements are run once all tables exist.
    """
    if not supports('dropTableCascade'):
        skip("dropTableCascade isn't supported")

    conn = getConnection()
    for tableClass in classes:
        tableClass.setConnection(conn)
        tableClass.dropTable(ifExists=True, cascade=True)

    deferred = []
    for tableClass in classes:
        deferred.extend(tableClass.createTable(ifNotExists=True,
                                               applyConstraints=False))

    for statement in deferred:
        conn.query(statement)
# Names exported by ``from ... import *``.
__all__ = [
    'Dummy',
    'deprecated_module',
    'getConnection',
    'getConnectionURI',
    'inserts',
    'raises',
    'setupClass',
    'setupCyclicClasses',
    'setupLogging',
    'setup_module',
    'supports',
    'teardown_module',
]