|
2 | 2 |
|
3 | 3 | import re |
4 | 4 |
|
5 | | -from sqlparse import tokens as T |
| 5 | +from os.path import abspath, join |
| 6 | + |
6 | 7 | from sqlparse import sql |
| 8 | +from sqlparse import tokens as T |
| 9 | +from sqlparse.engine import FilterStack |
| 10 | +from sqlparse.tokens import ( |
| 11 | + Comment, Keyword, Name, |
| 12 | + Punctuation, String, Whitespace, |
| 13 | +) |
7 | 14 |
|
8 | 15 |
|
9 | 16 | class Filter(object): |
@@ -52,6 +59,81 @@ def process(self, stack, stream): |
52 | 59 | yield ttype, value |
53 | 60 |
|
54 | 61 |
|
| 62 | +class GetComments(Filter): |
| 63 | + """Get the comments from a stack""" |
| 64 | + def process(self, stack, stream): |
| 65 | + for token_type, value in stream: |
| 66 | + if token_type in Comment: |
| 67 | + yield token_type, value |
| 68 | + |
| 69 | + |
| 70 | +class StripComments(Filter): |
| 71 | + """Strip the comments from a stack""" |
| 72 | + def process(self, stack, stream): |
| 73 | + for token_type, value in stream: |
| 74 | + if token_type not in Comment: |
| 75 | + yield token_type, value |
| 76 | + |
| 77 | + |
| 78 | +class IncludeStatement(Filter): |
| 79 | + """Filter that enable a INCLUDE statement""" |
| 80 | + |
| 81 | + def __init__(self, dirpath=".", maxRecursive=10): |
| 82 | + self.dirpath = abspath(dirpath) |
| 83 | + self.maxRecursive = maxRecursive |
| 84 | + |
| 85 | + self.detected = False |
| 86 | + |
| 87 | + def process(self, stack, stream): |
| 88 | + # Run over all tokens in the stream |
| 89 | + for token_type, value in stream: |
| 90 | + # INCLUDE statement found, set detected mode |
| 91 | + if token_type in Name and value.upper() == 'INCLUDE': |
| 92 | + self.detected = True |
| 93 | + continue |
| 94 | + |
| 95 | + # INCLUDE statement was found, parse it |
| 96 | + elif self.detected: |
| 97 | + # Omit whitespaces |
| 98 | + if token_type in Whitespace: |
| 99 | + pass |
| 100 | + |
| 101 | + # Get path of file to include |
| 102 | + path = None |
| 103 | + |
| 104 | + if token_type in String.Symbol: |
| 105 | +# if token_type in tokens.String.Symbol: |
| 106 | + path = join(self.dirpath, value[1:-1]) |
| 107 | + |
| 108 | + # Include file if path was found |
| 109 | + if path: |
| 110 | + try: |
| 111 | + f = open(path) |
| 112 | + raw_sql = f.read() |
| 113 | + f.close() |
| 114 | + except IOError, err: |
| 115 | + yield Comment, u'-- IOError: %s\n' % err |
| 116 | + |
| 117 | + else: |
| 118 | + # Create new FilterStack to parse readed file |
| 119 | + # and add all its tokens to the main stack recursively |
| 120 | + # [ToDo] Add maximum recursive iteration value |
| 121 | + stack = FilterStack() |
| 122 | + stack.preprocess.append(IncludeStatement(self.dirpath)) |
| 123 | + |
| 124 | + for tv in stack.run(raw_sql): |
| 125 | + yield tv |
| 126 | + |
| 127 | + # Set normal mode |
| 128 | + self.detected = False |
| 129 | + |
| 130 | + # Don't include any token while in detected mode |
| 131 | + continue |
| 132 | + |
| 133 | + # Normal token |
| 134 | + yield token_type, value |
| 135 | + |
| 136 | + |
55 | 137 | # ---------------------- |
56 | 138 | # statement process |
57 | 139 |
|
@@ -146,13 +228,14 @@ def _split_kwds(self, tlist): |
146 | 228 | split_words = ('FROM', 'JOIN$', 'AND', 'OR', |
147 | 229 | 'GROUP', 'ORDER', 'UNION', 'VALUES', |
148 | 230 | 'SET', 'BETWEEN') |
| 231 | + |
149 | 232 | def _next_token(i): |
150 | 233 | t = tlist.token_next_match(i, T.Keyword, split_words, |
151 | 234 | regex=True) |
152 | 235 | if t and t.value.upper() == 'BETWEEN': |
153 | | - t = _next_token(tlist.token_index(t)+1) |
| 236 | + t = _next_token(tlist.token_index(t) + 1) |
154 | 237 | if t and t.value.upper() == 'AND': |
155 | | - t = _next_token(tlist.token_index(t)+1) |
| 238 | + t = _next_token(tlist.token_index(t) + 1) |
156 | 239 | return t |
157 | 240 |
|
158 | 241 | idx = 0 |
@@ -316,6 +399,57 @@ def process(self, stack, group): |
316 | 399 | group.tokens = self._process(stack, group, group.tokens) |
317 | 400 |
|
318 | 401 |
|
| 402 | +class ColumnsSelect(Filter): |
| 403 | + """Get the columns names of a SELECT query""" |
| 404 | + def process(self, stack, stream): |
| 405 | + mode = 0 |
| 406 | + oldValue = "" |
| 407 | + parenthesis = 0 |
| 408 | + |
| 409 | + for token_type, value in stream: |
| 410 | + # Ignore comments |
| 411 | + if token_type in Comment: |
| 412 | + continue |
| 413 | + |
| 414 | + # We have not detected a SELECT statement |
| 415 | + if mode == 0: |
| 416 | + if token_type in Keyword and value == 'SELECT': |
| 417 | + mode = 1 |
| 418 | + |
| 419 | + # We have detected a SELECT statement |
| 420 | + elif mode == 1: |
| 421 | + if value == 'FROM': |
| 422 | + if oldValue: |
| 423 | + yield Name, oldValue |
| 424 | + |
| 425 | + mode = 3 # Columns have been checked |
| 426 | + |
| 427 | + elif value == 'AS': |
| 428 | + oldValue = "" |
| 429 | + mode = 2 |
| 430 | + |
| 431 | + elif (token_type == Punctuation |
| 432 | + and value == ',' and not parenthesis): |
| 433 | + if oldValue: |
| 434 | + yield Name, oldValue |
| 435 | + oldValue = "" |
| 436 | + |
| 437 | + elif token_type not in Whitespace: |
| 438 | + if value == '(': |
| 439 | + parenthesis += 1 |
| 440 | + elif value == ')': |
| 441 | + parenthesis -= 1 |
| 442 | + |
| 443 | + oldValue += value |
| 444 | + |
| 445 | + # We are processing an AS keyword |
| 446 | + elif mode == 2: |
| 447 | + # We check also for Keywords because a bug in SQLParse |
| 448 | + if token_type == Name or token_type == Keyword: |
| 449 | + yield Name, value |
| 450 | + mode = 1 |
| 451 | + |
| 452 | + |
319 | 453 | # --------------------------- |
320 | 454 | # postprocess |
321 | 455 |
|
@@ -422,3 +556,24 @@ def process(self, stack, stmt): |
422 | 556 | varname = self.varname |
423 | 557 | stmt.tokens = tuple(self._process(stmt.tokens, varname)) |
424 | 558 | return stmt |
| 559 | + |
| 560 | + |
| 561 | +class Limit(Filter): |
| 562 | + """Get the LIMIT of a query. |
| 563 | +
|
| 564 | + If not defined, return -1 (SQL specification for no LIMIT query) |
| 565 | + """ |
| 566 | + def process(self, stack, stream): |
| 567 | + index = 7 |
| 568 | + stream = list(stream) |
| 569 | + stream.reverse() |
| 570 | + |
| 571 | + # Run over all tokens in the stream from the end |
| 572 | + for token_type, value in stream: |
| 573 | + index -= 1 |
| 574 | + |
| 575 | +# if index and token_type in Keyword: |
| 576 | + if index and token_type in Keyword and value == 'LIMIT': |
| 577 | + return stream[4 - index][1] |
| 578 | + |
| 579 | + return -1 |
0 commit comments