Skip to content

Commit bade0e3

Browse files
committed
Major code refactoring - centralized all kb.dbms* info for both retrieval and set.
1 parent 4bdc19d commit bade0e3

39 files changed

Lines changed: 926 additions & 821 deletions

lib/controller/action.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"""
99

1010
from lib.controller.handler import setHandler
11-
from lib.core.common import getErrorParsedDBMSesFormatted
11+
from lib.core.common import backend
12+
from lib.core.common import format
1213
from lib.core.common import dataToStdout
1314
from lib.core.data import conf
1415
from lib.core.data import kb
@@ -30,8 +31,8 @@ def action():
3031
# system to be able to go ahead with the injection
3132
setHandler()
3233

33-
if not kb.dbmsDetected or not conf.dbmsHandler:
34-
htmlParsed = getErrorParsedDBMSesFormatted()
34+
if not backend.getDbms() or not conf.dbmsHandler:
35+
htmlParsed = format.getErrorParsedDBMSes()
3536

3637
errMsg = "sqlmap was not able to fingerprint the "
3738
errMsg += "back-end database management system"

lib/controller/checks.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313

1414
from lib.core.agent import agent
1515
from lib.core.common import aliasToDbmsEnum
16+
from lib.core.common import backend
1617
from lib.core.common import beep
1718
from lib.core.common import extractRegexResult
1819
from lib.core.common import findDynamicContent
20+
from lib.core.common import format
1921
from lib.core.common import getComparePageRatio
2022
from lib.core.common import getCompiledRegex
21-
from lib.core.common import getErrorParsedDBMSes
22-
from lib.core.common import getErrorParsedDBMSesFormatted
23-
from lib.core.common import getIdentifiedDBMS
24-
from lib.core.common import getInjectionTests
23+
from lib.core.common import getSortedInjectionTests
2524
from lib.core.common import getUnicode
2625
from lib.core.common import popValue
2726
from lib.core.common import pushValue
@@ -50,6 +49,7 @@
5049
from lib.core.exception import sqlmapUserQuitException
5150
from lib.core.session import setDynamicMarkings
5251
from lib.core.settings import CONSTANT_RATIO
52+
from lib.core.settings import UNKNOWN_DBMS_VERSION
5353
from lib.core.settings import UPPER_RATIO_BOUND
5454
from lib.core.threads import getCurrentThreadData
5555
from lib.core.unescaper import unescaper
@@ -78,8 +78,8 @@ def unescapeDbms(payload, injection, dbms):
7878
payload = unescape(payload, dbms=dbms)
7979
elif conf.dbms is not None:
8080
payload = unescape(payload, dbms=conf.dbms)
81-
elif getIdentifiedDBMS() is not None:
82-
payload = unescape(payload, dbms=getIdentifiedDBMS())
81+
elif backend.getIdentifiedDbms() is not None:
82+
payload = unescape(payload, dbms=backend.getIdentifiedDbms())
8383

8484
return payload
8585

@@ -91,7 +91,7 @@ def checkSqlInjection(place, parameter, value):
9191
# Set the flag for sql injection test mode
9292
kb.testMode = True
9393

94-
for test in getInjectionTests():
94+
for test in getSortedInjectionTests():
9595
try:
9696
if kb.endDetection:
9797
break
@@ -164,19 +164,19 @@ def checkSqlInjection(place, parameter, value):
164164

165165
continue
166166

167-
if len(getErrorParsedDBMSes()) > 0 and dbms not in getErrorParsedDBMSes() and kb.skipOthersDbms is None:
167+
if len(backend.getErrorParsedDBMSes()) > 0 and dbms not in backend.getErrorParsedDBMSes() and kb.skipOthersDbms is None:
168168
msg = "parsed error message(s) showed that the "
169-
msg += "back-end DBMS could be '%s'. " % getErrorParsedDBMSesFormatted()
169+
msg += "back-end DBMS could be %s. " % format.getErrorParsedDBMSes()
170170
msg += "Do you want to skip test payloads specific for other DBMSes? [Y/n]"
171171

172172
if conf.realTest or readInput(msg, default="Y") in ("y", "Y"):
173-
kb.skipOthersDbms = getErrorParsedDBMSes()
173+
kb.skipOthersDbms = backend.getErrorParsedDBMSes()
174174

175175
if kb.skipOthersDbms and dbms not in kb.skipOthersDbms:
176176
debugMsg = "skipping test '%s' because " % title
177177
debugMsg += "the parsed error message(s) showed "
178178
debugMsg += "that the back-end DBMS could be "
179-
debugMsg += "%s" % getErrorParsedDBMSesFormatted()
179+
debugMsg += "%s" % format.getErrorParsedDBMSes()
180180
logger.debug(debugMsg)
181181

182182
continue
@@ -395,7 +395,7 @@ def checkSqlInjection(place, parameter, value):
395395

396396
# Force back-end DBMS according to the current
397397
# test value for proper payload unescaping
398-
kb.misc.forcedDbms = dbms
398+
backend.forceDbms(dbms)
399399

400400
# Skip test if the user provided custom column
401401
# range and this is not a custom UNION test
@@ -407,7 +407,7 @@ def checkSqlInjection(place, parameter, value):
407407

408408
configUnion(test.request.char, test.request.columns)
409409

410-
if not getIdentifiedDBMS():
410+
if not backend.getIdentifiedDbms():
411411
warnMsg = "using unescaped version of the test "
412412
warnMsg += "because of zero knowledge of the "
413413
warnMsg += "back-end DBMS"
@@ -426,8 +426,8 @@ def checkSqlInjection(place, parameter, value):
426426
# by unionTest() directly
427427
where = vector[6]
428428

429-
# Reset back-end DBMS value
430-
kb.misc.forcedDbms = None
429+
# Reset forced back-end DBMS value
430+
backend.flushForcedDbms()
431431

432432
# If the injection test was successful feed the injection
433433
# object with the test's details
@@ -481,18 +481,18 @@ def checkSqlInjection(place, parameter, value):
481481
if inp == injection.dbms:
482482
break
483483
elif inp == dValue:
484-
kb.dbms = aliasToDbmsEnum(inp)
484+
backend.setDbms(inp)
485485
injection.dbms = aliasToDbmsEnum(inp)
486486
injection.dbms_version = None
487487
break
488488
else:
489489
warnMsg = "invalid value"
490490
logger.warn(warnMsg)
491491
elif dKey == "dbms" and injection.dbms is None:
492-
kb.dbms = aliasToDbmsEnum(dValue)
492+
backend.setDbms(dValue)
493493
injection.dbms = aliasToDbmsEnum(dValue)
494494
elif dKey == "dbms_version" and injection.dbms_version is None:
495-
kb.dbmsVersion = [ dValue ]
495+
backend.setVersion(dValue)
496496
injection.dbms_version = dValue
497497
elif dKey == "os" and injection.os is None:
498498
injection.os = dValue
@@ -558,7 +558,7 @@ def heuristicCheckSqlInjection(place, parameter):
558558
infoMsg += "parameter '%s' might " % parameter
559559

560560
if result:
561-
infoMsg += "be injectable (possible DBMS: %s)" % (getErrorParsedDBMSesFormatted() or 'Unknown')
561+
infoMsg += "be injectable (possible DBMS: %s)" % (format.getErrorParsedDBMSes() or UNKNOWN_DBMS_VERSION)
562562
logger.info(infoMsg)
563563
else:
564564
infoMsg += "not be injectable"

lib/controller/handler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
See the file 'doc/COPYING' for copying permission
88
"""
99

10-
from lib.core.common import getIdentifiedDBMS
10+
from lib.core.common import backend
1111
from lib.core.common import popValue
1212
from lib.core.common import pushValue
1313
from lib.core.data import conf
@@ -63,11 +63,11 @@ def setHandler():
6363
( SYBASE_ALIASES, SybaseMap, SybaseConn ),
6464
]
6565

66-
if getIdentifiedDBMS() is not None:
66+
if backend.getIdentifiedDbms() is not None:
6767
for i in xrange(len(dbmsObj)):
6868
dbmsAliases, _, _ = dbmsObj[i]
6969

70-
if getIdentifiedDBMS().lower() in dbmsAliases:
70+
if backend.getIdentifiedDbms().lower() in dbmsAliases:
7171
if i > 0:
7272
pushValue(dbmsObj[i])
7373
dbmsObj.remove(dbmsObj[i])
@@ -94,12 +94,12 @@ def setHandler():
9494
conf.dbmsConnector.connect()
9595

9696
if handler.checkDbms():
97-
kb.dbmsDetected = True
9897
conf.dbmsHandler = handler
9998

10099
break
101100
else:
102101
conf.dbmsConnector = None
103102

104-
# At this point proper back-end DBMS is fingerprinted (kb.dbms)
105-
kb.misc.forcedDbms = None
103+
# At this point back-end DBMS is correctly fingerprinted, no need
104+
# to enforce it anymore
105+
backend.flushForcedDbms()

lib/core/agent.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111

1212
from xml.etree import ElementTree as ET
1313

14+
from lib.core.common import backend
1415
from lib.core.common import getCompiledRegex
15-
from lib.core.common import getErrorParsedDBMSes
16-
from lib.core.common import getIdentifiedDBMS
1716
from lib.core.common import isDBMSVersionAtLeast
1817
from lib.core.common import isTechniqueAvailable
1918
from lib.core.common import randomInt
@@ -206,8 +205,8 @@ def cleanupPayload(self, payload, origvalue=None, query=None):
206205
payload = payload.replace("[ORIGVALUE]", origvalue)
207206

208207
if "[INFERENCE]" in payload:
209-
if getIdentifiedDBMS() is not None:
210-
inference = queries[getIdentifiedDBMS()].inference
208+
if backend.getIdentifiedDbms() is not None:
209+
inference = queries[backend.getIdentifiedDbms()].inference
211210

212211
if "dbms_version" in inference:
213212
if isDBMSVersionAtLeast(inference.dbms_version):
@@ -265,17 +264,17 @@ def nullAndCastField(self, field):
265264

266265
# SQLite version 2 does not support neither CAST() nor IFNULL(),
267266
# introduced only in SQLite version 3
268-
if getIdentifiedDBMS() == DBMS.SQLITE:
267+
if backend.getIdentifiedDbms() == DBMS.SQLITE:
269268
return field
270269

271270
if field.startswith("(CASE"):
272271
nulledCastedField = field
273272
else:
274-
nulledCastedField = queries[getIdentifiedDBMS()].cast.query % field
275-
if getIdentifiedDBMS() == DBMS.ACCESS:
276-
nulledCastedField = queries[getIdentifiedDBMS()].isnull.query % (nulledCastedField, nulledCastedField)
273+
nulledCastedField = queries[backend.getIdentifiedDbms()].cast.query % field
274+
if backend.getIdentifiedDbms() == DBMS.ACCESS:
275+
nulledCastedField = queries[backend.getIdentifiedDbms()].isnull.query % (nulledCastedField, nulledCastedField)
277276
else:
278-
nulledCastedField = queries[getIdentifiedDBMS()].isnull.query % nulledCastedField
277+
nulledCastedField = queries[backend.getIdentifiedDbms()].isnull.query % nulledCastedField
279278

280279
return nulledCastedField
281280

@@ -309,12 +308,12 @@ def nullCastConcatFields(self, fields):
309308
@rtype: C{str}
310309
"""
311310

312-
if not kb.dbmsDetected:
311+
if not backend.getDbms():
313312
return fields
314313

315314
fields = fields.replace(", ", ",")
316315
fieldsSplitted = fields.split(",")
317-
dbmsDelimiter = queries[getIdentifiedDBMS()].delimiter.query
316+
dbmsDelimiter = queries[backend.getIdentifiedDbms()].delimiter.query
318317
nulledCastedFields = []
319318

320319
for field in fieldsSplitted:
@@ -377,13 +376,13 @@ def getFields(self, query):
377376
def simpleConcatQuery(self, query1, query2):
378377
concatenatedQuery = ""
379378

380-
if getIdentifiedDBMS() == DBMS.MYSQL:
379+
if backend.getIdentifiedDbms() == DBMS.MYSQL:
381380
concatenatedQuery = "CONCAT(%s,%s)" % (query1, query2)
382381

383-
elif getIdentifiedDBMS() in (DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE):
382+
elif backend.getIdentifiedDbms() in (DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE):
384383
concatenatedQuery = "%s||%s" % (query1, query2)
385384

386-
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
385+
elif backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
387386
concatenatedQuery = "%s+%s" % (query1, query2)
388387

389388
return concatenatedQuery
@@ -425,7 +424,7 @@ def concatQuery(self, query, unpack=True):
425424
concatenatedQuery = query
426425
fieldsSelectFrom, fieldsSelect, fieldsNoSelect, fieldsSelectTop, fieldsSelectCase, _, fieldsToCastStr, fieldsExists = self.getFields(query)
427426

428-
if getIdentifiedDBMS() == DBMS.MYSQL:
427+
if backend.getIdentifiedDbms() == DBMS.MYSQL:
429428
if fieldsExists:
430429
concatenatedQuery = concatenatedQuery.replace("SELECT ", "CONCAT('%s'," % kb.misc.start, 1)
431430
concatenatedQuery += ",'%s')" % kb.misc.stop
@@ -438,7 +437,7 @@ def concatQuery(self, query, unpack=True):
438437
elif fieldsNoSelect:
439438
concatenatedQuery = "CONCAT('%s',%s,'%s')" % (kb.misc.start, concatenatedQuery, kb.misc.stop)
440439

441-
elif getIdentifiedDBMS() in (DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE):
440+
elif backend.getIdentifiedDbms() in (DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE):
442441
if fieldsExists:
443442
concatenatedQuery = concatenatedQuery.replace("SELECT ", "'%s'||" % kb.misc.start, 1)
444443
concatenatedQuery += "||'%s'" % kb.misc.stop
@@ -451,10 +450,10 @@ def concatQuery(self, query, unpack=True):
451450
elif fieldsNoSelect:
452451
concatenatedQuery = "'%s'||%s||'%s'" % (kb.misc.start, concatenatedQuery, kb.misc.stop)
453452

454-
if getIdentifiedDBMS() == DBMS.ORACLE and " FROM " not in concatenatedQuery and (fieldsSelect or fieldsNoSelect):
453+
if backend.getIdentifiedDbms() == DBMS.ORACLE and " FROM " not in concatenatedQuery and (fieldsSelect or fieldsNoSelect):
455454
concatenatedQuery += " FROM DUAL"
456455

457-
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
456+
elif backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
458457
if fieldsExists:
459458
concatenatedQuery = concatenatedQuery.replace("SELECT ", "'%s'+" % kb.misc.start, 1)
460459
concatenatedQuery += "+'%s'" % kb.misc.stop
@@ -520,8 +519,8 @@ def forgeInbandQuery(self, query, position, count, comment, prefix, suffix, char
520519
intoRegExp = intoRegExp.group(1)
521520
query = query[:query.index(intoRegExp)]
522521

523-
if getIdentifiedDBMS() in FROM_TABLE and inbandQuery.endswith(FROM_TABLE[getIdentifiedDBMS()]):
524-
inbandQuery = inbandQuery[:-len(FROM_TABLE[getIdentifiedDBMS()])]
522+
if backend.getIdentifiedDbms() in FROM_TABLE and inbandQuery.endswith(FROM_TABLE[backend.getIdentifiedDbms()]):
523+
inbandQuery = inbandQuery[:-len(FROM_TABLE[backend.getIdentifiedDbms()])]
525524

526525
for element in range(count):
527526
if element > 0:
@@ -540,9 +539,9 @@ def forgeInbandQuery(self, query, position, count, comment, prefix, suffix, char
540539
conditionIndex = query.index(" FROM ")
541540
inbandQuery += query[conditionIndex:]
542541

543-
if getIdentifiedDBMS() in FROM_TABLE:
542+
if backend.getIdentifiedDbms() in FROM_TABLE:
544543
if " FROM " not in inbandQuery:
545-
inbandQuery += FROM_TABLE[getIdentifiedDBMS()]
544+
inbandQuery += FROM_TABLE[backend.getIdentifiedDbms()]
546545

547546
if intoRegExp:
548547
inbandQuery += intoRegExp
@@ -559,8 +558,8 @@ def forgeInbandQuery(self, query, position, count, comment, prefix, suffix, char
559558
else:
560559
inbandQuery += char
561560

562-
if getIdentifiedDBMS() in FROM_TABLE:
563-
inbandQuery += FROM_TABLE[getIdentifiedDBMS()]
561+
if backend.getIdentifiedDbms() in FROM_TABLE:
562+
inbandQuery += FROM_TABLE[backend.getIdentifiedDbms()]
564563

565564
inbandQuery = self.suffixQuery(inbandQuery, comment, suffix)
566565

@@ -589,21 +588,21 @@ def limitQuery(self, num, query, field=None):
589588
"""
590589

591590
limitedQuery = query
592-
limitStr = queries[getIdentifiedDBMS()].limit.query
591+
limitStr = queries[backend.getIdentifiedDbms()].limit.query
593592
fromIndex = limitedQuery.index(" FROM ")
594593
untilFrom = limitedQuery[:fromIndex]
595594
fromFrom = limitedQuery[fromIndex+1:]
596595
orderBy = False
597596

598-
if getIdentifiedDBMS() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.SQLITE):
599-
limitStr = queries[getIdentifiedDBMS()].limit.query % (num, 1)
597+
if backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.SQLITE):
598+
limitStr = queries[backend.getIdentifiedDbms()].limit.query % (num, 1)
600599
limitedQuery += " %s" % limitStr
601600

602-
elif getIdentifiedDBMS() == DBMS.FIREBIRD:
603-
limitStr = queries[getIdentifiedDBMS()].limit.query % (num+1, num+1)
601+
elif backend.getIdentifiedDbms() == DBMS.FIREBIRD:
602+
limitStr = queries[backend.getIdentifiedDbms()].limit.query % (num+1, num+1)
604603
limitedQuery += " %s" % limitStr
605604

606-
elif getIdentifiedDBMS() == DBMS.ORACLE:
605+
elif backend.getIdentifiedDbms() == DBMS.ORACLE:
607606
if " ORDER BY " in limitedQuery and "(SELECT " in limitedQuery:
608607
orderBy = limitedQuery[limitedQuery.index(" ORDER BY "):]
609608
limitedQuery = limitedQuery[:limitedQuery.index(" ORDER BY ")]
@@ -615,7 +614,7 @@ def limitQuery(self, num, query, field=None):
615614
limitedQuery = limitedQuery % fromFrom
616615
limitedQuery += "=%d" % (num + 1)
617616

618-
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
617+
elif backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
619618
forgeNotIn = True
620619

621620
if " ORDER BY " in limitedQuery:
@@ -629,7 +628,7 @@ def limitQuery(self, num, query, field=None):
629628
limitedQuery = limitedQuery.replace("DISTINCT %s" % notDistinct, notDistinct)
630629

631630
if limitedQuery.startswith("SELECT TOP ") or limitedQuery.startswith("TOP "):
632-
topNums = re.search(queries[getIdentifiedDBMS()].limitregexp.query, limitedQuery, re.I)
631+
topNums = re.search(queries[backend.getIdentifiedDbms()].limitregexp.query, limitedQuery, re.I)
633632

634633
if topNums:
635634
topNums = topNums.groups()
@@ -675,8 +674,8 @@ def forgeCaseStatement(self, expression):
675674
@rtype: C{str}
676675
"""
677676

678-
if getIdentifiedDBMS() is not None and hasattr(queries[getIdentifiedDBMS()], "case"):
679-
return queries[getIdentifiedDBMS()].case.query % expression
677+
if backend.getIdentifiedDbms() is not None and hasattr(queries[backend.getIdentifiedDbms()], "case"):
678+
return queries[backend.getIdentifiedDbms()].case.query % expression
680679
else:
681680
return expression
682681

0 commit comments

Comments
 (0)