|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# |
| 3 | +# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com |
| 4 | +# |
| 5 | +# This module is part of python-sqlparse and is released under |
| 6 | +# the BSD License: http://www.opensource.org/licenses/bsd-license.php |
| 7 | + |
| 8 | +from sqlparse import sql, tokens as T |
| 9 | + |
| 10 | + |
class AlignedIndentFilter(object):
    """Re-indent a parsed SQL statement so that clause keywords line up.

    Every keyword listed in ``split_words`` is moved onto its own line and
    padded on the left so that it ends in the same column as a ``select``
    keyword would (``_max_kwd_len`` characters plus the current base
    indent).  Identifier lists, parenthesised sub-selects and CASE
    expressions get dedicated handlers, looked up by the lower-cased class
    name of the token list being processed (see ``_process``).
    """

    # Regular expression matching a bare JOIN as well as qualified joins
    # such as "LEFT OUTER JOIN" or "CROSS JOIN".
    join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?'
                  r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|'
                  r'(CROSS\s+|NATURAL\s+)?)?JOIN\b')
    # Keywords that start a new, aligned line.  Matched as regular
    # expressions against keyword tokens in ``_process_substatement``.
    split_words = ('FROM',
                   join_words, 'ON',
                   'WHERE', 'AND', 'OR',
                   'GROUP', 'HAVING', 'LIMIT',
                   'ORDER', 'UNION', 'VALUES',
                   'SET', 'BETWEEN', 'EXCEPT')

    def __init__(self, char=' ', line_width=None):
        """Create the filter.

        :param char: string repeated to build indentation.
        :param line_width: accepted for interface compatibility; it is not
            used by this class.
        """
        self.char = char
        # All keywords are right-aligned to the width of "select".
        self._max_kwd_len = len('select')

    def newline(self):
        """Return a new newline token."""
        return sql.Token(T.Newline, '\n')

    def whitespace(self, chars=0, newline_before=False, newline_after=False):
        """Return a whitespace token of ``chars`` indent characters.

        :param chars: number of times ``self.char`` is repeated.
        :param newline_before: prefix the padding with a newline.
        :param newline_after: suffix the padding with a newline.
        """
        prefix = '\n' if newline_before else ''
        suffix = '\n' if newline_after else ''
        return sql.Token(T.Whitespace, prefix + self.char * chars + suffix)

    def _process_statement(self, tlist, base_indent=0):
        """Handle a whole statement.

        A leading whitespace token is dropped when the statement is not
        nested (``base_indent == 0``); the remaining tokens are then
        processed as a generic token list.
        """
        # Guard against an empty token list before looking at tokens[0].
        if (tlist.tokens and base_indent == 0
                and tlist.tokens[0].is_whitespace()):
            tlist.tokens.pop(0)

        # process the main query body
        return self._process(sql.TokenList(tlist.tokens),
                             base_indent=base_indent)

    def _process_parenthesis(self, tlist, base_indent=0):
        """Handle a parenthesised group, re-indenting it if it is a subquery.

        Groups that contain no SELECT are returned unchanged.
        """
        if not tlist.token_next_by(m=(T.DML, 'SELECT')):
            # if this isn't a subquery, don't re-indent
            return tlist

        # add two for the space and parens
        sub_indent = base_indent + self._max_kwd_len + 2
        tlist.insert_after(tlist.tokens[0],
                           self.whitespace(sub_indent, newline_before=True))
        # de-indent the last parenthesis
        tlist.insert_before(tlist.tokens[-1],
                            self.whitespace(sub_indent - 1,
                                            newline_before=True))

        # process the inside of the parentheses, keeping the opening and
        # closing tokens untouched around the processed body
        inner = self._process(sql.TokenList(tlist._groupable_tokens),
                              base_indent=sub_indent)
        tlist.tokens = [tlist.tokens[0]] + inner.tokens + [tlist.tokens[-1]]
        return tlist

    def _process_identifierlist(self, tlist, base_indent=0):
        """Handle a list of identifiers by putting one identifier per line.

        Existing punctuation, whitespace and newline tokens are discarded
        and the list is rebuilt with a comma after every identifier except
        the last, each following identifier indented one column past the
        keyword column.
        """
        # columns being selected
        skipped = (T.Punctuation, T.Whitespace, T.Newline)
        identifiers = [t for t in tlist.tokens if t.ttype not in skipped]
        last = len(identifiers) - 1

        new_tokens = []
        for i, token in enumerate(identifiers):
            if i > 0:
                new_tokens.append(self.newline())
                new_tokens.append(
                    self.whitespace(self._max_kwd_len + base_indent + 1))
            new_tokens.append(token)
            if i < last:
                # if not last column in select, add a comma separator
                new_tokens.append(sql.Token(T.Punctuation, ','))
        tlist.tokens = new_tokens

        # process any sub-sub statements (like case statements)
        for sgroup in tlist.get_sublists():
            self._process(sgroup, base_indent=base_indent)
        return tlist

    def _process_case(self, tlist, base_indent=0):
        """Handle a CASE expression by aligning its branches.

        Each branch after the first is indented to a common column, every
        condition is padded to the width of the widest condition, and a
        newline is added after every entry except the last one.
        """
        base_offset = base_indent + self._max_kwd_len + len('case ')
        case_offset = len('when ')
        cases = tlist.get_cases(skip_ws=True)
        # align the end as well (only if the expression actually has one,
        # otherwise there is no token to anchor the padding on)
        end_token = tlist.token_next_by(m=(T.Keyword, 'END'))
        if end_token is not None:
            cases.append((None, [end_token]))

        # Entries without a condition (ELSE/END) count as width 0 so that a
        # CASE made only of such entries does not call max() on nothing.
        cond_widths = [len(' '.join(map(str, cond))) if cond else 0
                       for cond, _ in cases]
        condition_width = max(cond_widths) if cond_widths else 0

        last = len(cases) - 1
        for i, (cond, value) in enumerate(cases):
            if cond is None:  # else or end
                stmt = value[0]
                line = value
            else:
                stmt = cond[0]
                line = cond + value
            if i > 0:
                tlist.insert_before(stmt, self.whitespace(
                    base_offset + case_offset - len(str(stmt))))
            if cond:
                tlist.insert_after(cond[-1], self.whitespace(
                    condition_width - len(' '.join(map(str, cond)))))

            if i < last:
                # if not the END add a newline
                tlist.insert_after(line[-1], self.newline())
        return tlist

    def _process_substatement(self, tlist, base_indent=0):
        """Generic handler: break the token list before each split keyword.

        A newline plus padding is inserted before every keyword matching
        ``split_words`` so that the keyword is right-aligned to the keyword
        column; nested groups are then processed recursively.
        """
        def _next_token(i):
            t = tlist.token_next_by(m=(T.Keyword, self.split_words, True),
                                    idx=i)
            # treat "BETWEEN x and y" as a single statement
            if t and t.value.upper() == 'BETWEEN':
                t = _next_token(tlist.token_index(t) + 1)
                if t and t.value.upper() == 'AND':
                    t = _next_token(tlist.token_index(t) + 1)
            return t

        token = _next_token(0)
        while token:
            # joins are special case. only consider the first word as aligner
            if token.match(T.Keyword, self.join_words, regex=True):
                token_indent = len(token.value.split()[0])
            else:
                token_indent = len(str(token))
            tlist.insert_before(token, self.whitespace(
                self._max_kwd_len - token_indent + base_indent,
                newline_before=True))
            # continue searching after the keyword just handled
            token = _next_token(tlist.token_index(token) + 1)

        # process any sub-sub statements
        for sgroup in tlist.get_sublists():
            prev_token = tlist.token_prev(tlist.token_index(sgroup))
            indent_offset = 0
            # HACK: make "group/order by" work. Longer than _max_kwd_len.
            if prev_token and prev_token.match(T.Keyword, 'BY'):
                # TODO: generalize this
                indent_offset = 3
            self._process(sgroup, base_indent=base_indent + indent_offset)
        return tlist

    def _process(self, tlist, base_indent=0):
        """Dispatch ``tlist`` to the handler named after its class.

        The handler is ``_process_<lower-cased class name>``; when no such
        attribute exists ``_process_substatement`` is used.  Returns
        whatever the handler returns.
        """
        token_name = tlist.__class__.__name__.lower()
        func_name = '_process_%s' % token_name
        func = getattr(self, func_name, self._process_substatement)
        return func(tlist, base_indent=base_indent)

    def process(self, stmt):
        """Run the filter on ``stmt`` and return ``stmt``."""
        self._process(stmt)
        return stmt
0 commit comments