This repository was archived by the owner on Jan 3, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathvexecute.py
More file actions
308 lines (249 loc) · 10.3 KB
/
Copy pathvexecute.py
File metadata and controls
308 lines (249 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import logging
import socket
import sys
import sqlparse
import vertica_python as vertica
from sqlparse.tokens import Token as _Token
from sqlparse.sql import Token
from .packages import vspecial as special
from .encodingutils import PY2
_logger = logging.getLogger(__name__)
class VExecute(object):
    """Run SQL statements and catalog look-ups over a Vertica connection.

    The connection parameters are kept on the instance so that
    :meth:`connect` can be called again overriding only some of them.
    """

    # The boolean argument to the current_schemas function indicates whether
    # implicit schemas (e.g. pg_catalog) are included in the result.
    search_path_query = '''
SELECT current_schemas(true)'''

    schemata_query = '''
SELECT schema_name
FROM v_catalog.schemata
ORDER BY 1'''

    tables_query = '''
SELECT table_schema, table_name
FROM v_catalog.tables
ORDER BY 1, 2'''

    views_query = '''
SELECT table_schema, table_name
FROM v_catalog.views
ORDER BY 1, 2'''

    table_columns_query = '''
SELECT table_schema, table_name, column_name
FROM v_catalog.columns
ORDER BY 1, 2, 3'''

    view_columns_query = '''
SELECT table_schema, table_name, column_name
FROM v_catalog.view_columns
ORDER BY 1, 2, 3'''

    functions_query = '''
SELECT schema_name, function_name
FROM v_catalog.user_functions
WHERE schema_name NOT IN ('v_catalog', 'v_monitor', 'v_internal')
ORDER BY 1, 2'''

    databases_query = '''
SELECT database_name, owner_id, 'UTF8' AS encoding,
'en_US.utf8' AS collate, 'en_US.utf8' AS ctype
FROM v_catalog.databases
ORDER BY 1'''

    # NOTE(review): ``types`` and ``schemata`` are listed without a join
    # condition, so every type is paired with every non-system schema.
    # Confirm that this is intended before changing the query.
    datatypes_query = '''
SELECT schema_name, type_name
FROM v_catalog.types, v_catalog.schemata
WHERE schema_name NOT IN ('v_catalog', 'v_monitor', 'v_internal')
ORDER BY 1, 2'''

    def __init__(self, database, user, password, host, port):
        """Remember the connection parameters and connect immediately."""
        self.dbname = database
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.connect()

    def connect(self, database=None, user=None, password=None, host=None,
                port=None):
        """Open a new connection and make it the active one.

        Every argument that is ``None`` (or otherwise falsy) falls back to
        the value currently stored on the instance.  When the new connection
        is ready, a previously stored connection is closed and the effective
        parameters are saved back on the instance.

        :raises: whatever ``vertica.connect`` or the keep-alive set-up
                 raises; in the latter case the new connection is closed
                 first so it does not leak.
        """
        db = database or self.dbname
        user = user or self.user
        password = password or self.password
        host = host or self.host
        port = port or self.port

        conn = vertica.connect(database=db, user=user, password=password,
                               host=host, port=int(port))

        # Print notice message for PROFILE (#42)
        def print_notice(message):
            print('%(Severity)s: %(Message)s' % message.values)
            hint = message.values.get('Hint')
            if hint:
                print('HINT: ' + hint)

        conn.notice_handler = print_notice

        try:
            self._enable_keepalive(conn)
        except Exception:
            # Do not leave a half-configured connection open behind us.
            conn.close()
            raise

        # Only drop the old connection once the new one is fully usable.
        if hasattr(self, 'conn'):
            self.conn.close()
        self.conn = conn
        # self.conn.autocommit = True
        self.dbname = db
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        # register_json_typecasters(self.conn, self._json_typecaster)
        # register_hstore_typecaster(self.conn)

    @staticmethod
    def _enable_keepalive(conn):
        """Switch on TCP keep-alive on the connection's socket.

        HACK: Modify vertica_python's connection socket to do keep alive
        TODO: Keep alive for Windows and other platforms
        http://stackoverflow.com/questions/12248132/how-to-change-tcp-keepalive-timer-using-python-script
        """
        sock = conn._socket()
        if sys.platform == 'darwin':  # Mac OS X
            # Raw option number for the idle time; presumably TCP_KEEPALIVE,
            # which the socket module may not expose -- TODO confirm.
            idle_option = 0x10
        elif sys.platform.startswith('linux'):
            idle_option = socket.TCP_KEEPIDLE
        else:
            return

        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        sock.setsockopt(socket.IPPROTO_TCP, idle_option, 60)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)

    def _json_typecaster(self, json_data):
        """Interpret incoming JSON data as a string.

        On Python 2 the raw data is decoded using ``self.conn.encoding``;
        otherwise it is returned unchanged.

        NOTE(review): the code that registers this typecaster is commented
        out in :meth:`connect`, so this looks unused -- confirm.
        See http://initd.org/psycopg/docs/connection.html#connection.encoding
        """
        if PY2:
            return json_data.decode(self.conn.encoding)
        else:
            return json_data

    def run(self, statement, vspecial=None):
        """Execute the sql in the database and yield the results.

        :param statement: A string containing one or more sql statements
        :param vspecial: VSpecial object
        :return: Generator of tuples containing (title, rows, headers,
                 status, force_stdout)
        """
        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None, True)
            return

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove trailing semi-colons.
            sql = sql.rstrip(';')

            if vspecial:
                # First try to run each query as special
                try:
                    _logger.debug('Trying a vspecial command. sql: %r', sql)
                    cur = self.conn.cursor()
                    for result in vspecial.execute(cur, sql):
                        yield result
                    # Handled as a special command: go on with the next
                    # statement instead of dropping the remaining ones.
                    continue
                except special.CommandNotFound:
                    pass

            yield self.execute_normal_sql(sql)

    def execute_normal_sql(self, split_sql):
        """Execute one regular (non-special) statement.

        :param split_sql: a single sql statement
        :return: ``(title, cursor, headers, status, False)`` when the cursor
                 has a description and the statement starts with one of the
                 row-producing keywords, else
                 ``(title, None, None, status, True)``
        """
        _logger.debug('Regular sql statement. sql: %r', split_sql)
        cur = self.conn.cursor()

        try:
            tree = sqlparse.parse(split_sql)[0]
        except IndexError:
            # Nothing to parse, hence nothing to run
            return (None, None, None, None, True)

        if _is_copy_from_local(tree):
            _execute_copy_from_local_sql(tree, cur)
        else:
            cur.execute(split_sql)

        title = None
        statusmessage = None
        words = split_sql.split()
        first_token = words[0].lower() if words else ''

        if cur.description and first_token in ('select', 'update', 'delete',
                                               'with', 'insert', 'explain',
                                               'profile'):
            headers = [x[0] for x in cur.description]
            return (title, cur, headers, statusmessage, False)
        else:
            _logger.debug('No rows in result.')
            return (title, None, None, statusmessage, True)

    def search_path(self):
        """Returns the current search path as a list of schema names"""
        with self.conn.cursor() as cur:
            _logger.debug('Search path query. sql: %r',
                          self.search_path_query)
            cur.execute(self.search_path_query)
            names = cur.fetchone()[0]
        # Split with a separator of the same type as the value: splitting
        # text with a bytes separator raises TypeError on Python 3.
        separator = b',' if isinstance(names, bytes) else u','
        return names.split(separator)

    def schemata(self):
        """Returns an iterable of schema names in the database"""
        with self.conn.cursor() as cur:
            _logger.debug('Schemata Query. sql: %r', self.schemata_query)
            cur.execute(self.schemata_query)
            return (x[0] for x in cur.fetchall())

    def _iter_rows(self, log_message, query):
        """Run ``query`` and yield each result row as a tuple."""
        with self.conn.cursor() as cur:
            _logger.debug(log_message, query)
            cur.execute(query)
            for row in cur.iterate():
                yield tuple(row)

    def tables(self):
        """Yields (schema_name, table_name) tuples"""
        return self._iter_rows('Tables Query. sql: %r', self.tables_query)

    def views(self):
        """Yields (schema_name, view_name) tuples read from v_catalog.views"""
        return self._iter_rows('Views Query. sql: %r', self.views_query)

    def table_columns(self):
        """Yields (schema_name, table_name, column_name) tuples"""
        return self._iter_rows('Columns Query. sql: %r',
                               self.table_columns_query)

    def view_columns(self):
        """Yields (schema_name, view_name, column_name) tuples"""
        return self._iter_rows('Columns Query. sql: %r',
                               self.view_columns_query)

    def databases(self):
        """Returns the list of database names"""
        with self.conn.cursor() as cur:
            _logger.debug('Databases Query. sql: %r', self.databases_query)
            cur.execute(self.databases_query)
            return [x[0] for x in cur.fetchall()]

    def functions(self):
        """Yields tuples of (schema_name, function_name)"""
        return self._iter_rows('Functions Query. sql: %r',
                               self.functions_query)

    def datatypes(self):
        """Yields tuples of (schema_name, type_name)"""
        return self._iter_rows('Datatypes Query. sql: %r',
                               self.datatypes_query)
def _is_copy_from_local(sql_tree):
    """Tell whether a parsed statement is a ``COPY ... LOCAL '<path>'``.

    :param sql_tree: parsed statement exposing a ``tokens`` list
    :return: True only if the first token is the keyword COPY, a LOCAL
             keyword follows, and the token two positions after LOCAL is a
             single-quoted string literal; False otherwise
    """
    tokens = sql_tree.tokens
    # A statement without tokens cannot be a COPY
    if not tokens:
        return False

    first_token = tokens[0]
    if not (first_token.is_keyword and first_token.value.lower() == 'copy'):
        return False

    # Search for 'LOCAL' keyword
    local_index = next(
        (index for index, token in enumerate(tokens)
         if token.is_keyword and token.value.lower() == 'local'),
        None)
    if local_index is None:
        return False

    # After 'LOCAL', there should be a whitespace then a file path
    try:
        path_token = tokens[local_index + 2]
    except IndexError:
        return False

    return path_token.ttype is _Token.Literal.String.Single
def _execute_copy_from_local_sql(sql_tree, cursor):
    """Run a ``COPY ... LOCAL '<path>'`` statement by streaming the file.

    The ``LOCAL <path>`` part of the statement is replaced with ``stdin``
    and the file at ``<path>`` is opened in binary mode and handed to
    ``cursor.copy`` together with the rewritten statement.

    :param sql_tree: parsed statement; its ``tokens`` list is rewritten
    :param cursor: cursor providing ``copy`` and ``flush_to_query_ready``
    :raises ValueError: if the statement contains no LOCAL keyword
    """
    tokens = sql_tree.tokens

    # Search for 'LOCAL' keyword
    local_index = next(
        (index for index, token in enumerate(tokens)
         if token.is_keyword and token.value.lower() == 'local'),
        None)
    if local_index is None:
        # Without this check the statement would be rewritten at an
        # arbitrary position.
        raise ValueError('COPY statement has no LOCAL keyword')

    # The path is expected two tokens after LOCAL (one separator token in
    # between); surrounding quotes are not part of the path.
    file_path = tokens[local_index + 2].value.strip('\'"')

    # Replace "LOCAL <file_path>" with "stdin"
    sql_tree.tokens = (tokens[0:local_index]
                       + [Token(_Token.Keyword, 'stdin')]
                       + tokens[local_index + 3:])
    new_sql = sql_tree.to_unicode()

    cursor.flush_to_query_ready()
    with open(file_path, 'rb') as f:
        cursor.copy(new_sql, f)
    cursor.flush_to_query_ready()