diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c2a19417..bafb696a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -17,12 +17,22 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - name: Set up a Vertica server + - name: Set up a Keycloak docker container + timeout-minutes: 5 + run: | + docker network create -d bridge my-network + docker run -d -p 8080:8080 \ + --name keycloak --network my-network \ + -e KEYCLOAK_ADMIN=admin -e KEYCLOAK_ADMIN_PASSWORD=admin \ + quay.io/keycloak/keycloak:21.0.1 start-dev + docker container ls + + - name: Set up a Vertica server docker container timeout-minutes: 15 run: | docker run -d -p 5433:5433 -p 5444:5444 \ - --name vertica_docker \ - vertica/vertica-ce:12.0.4-0 + --name vertica_docker --network my-network \ + vertica/vertica-ce:23.3.0-0 echo "Vertica startup ..." until docker exec vertica_docker test -f /data/vertica/VMart/agent_start.out; do \ echo "..."; \ @@ -31,9 +41,59 @@ jobs: echo "Vertica is up" docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "\l" docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "select version()" + + - name: Configure Keycloak + run: | + echo "Wait for keycloak ready ..." 
+ bash -c 'while true; do curl -s localhost:8080 &>/dev/null; ret=$?; [[ $ret -eq 0 ]] && break; echo "..."; sleep 3; done' + + REALM="test" + USER="oauth_user" + PASSWORD="password" + CLIENT_ID="vertica" + CLIENT_SECRET="P9f8350QQIUhFfK1GF5sMhq4Dm3P6Sbs" + + docker exec -i keycloak /bin/bash < access_token.txt + + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "CREATE AUTHENTICATION v_oauth METHOD 'oauth' HOST '0.0.0.0/0';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_oauth SET client_id = '${CLIENT_ID}';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_oauth SET client_secret = '${CLIENT_SECRET}';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_oauth SET discovery_url = 'http://`hostname`:8080/realms/${REALM}/.well-known/openid-configuration';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_oauth SET introspect_url = 'http://`hostname`:8080/realms/${REALM}/protocol/openid-connect/token/introspect';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "SELECT * FROM client_auth WHERE auth_name='v_oauth';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "CREATE USER ${USER};" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "GRANT AUTHENTICATION v_oauth TO ${USER};" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "GRANT ALL ON SCHEMA PUBLIC TO ${USER};" + # A dbadmin-specific authentication record (connect remotely) is needed after setting up an OAuth user + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "CREATE AUTHENTICATION v_dbadmin_hash METHOD 'hash' HOST '0.0.0.0/0';" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_dbadmin_hash PRIORITY 10000;" + docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "GRANT AUTHENTICATION v_dbadmin_hash TO dbadmin;" + 
- name: Install dependencies run: pip install tox - name: Run tests run: | export VP_TEST_USER=dbadmin + export VP_TEST_OAUTH_ACCESS_TOKEN=`cat access_token.txt` + export VP_TEST_OAUTH_USER=oauth_user tox -e py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9072a7fd..a66c9844 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,7 @@ This document will guide you through the contribution process. There are a numbe If you find a bug, submit an [issue](https://github.com/vertica/vertica-python/issues) with a complete and reproducible bug report. If the issue can't be reproduced, it will be closed. If you opened an issue, but figured out the answer later on your own, comment on the issue to let people know, then close the issue. -For issues (e.g. security related issues) that are **not suitable** to be reported publicly on the GitHub issue system, report your issues to [Vertica open source team](mailto:vertica-opensrc@microfocus.com) directly or file a case with Vertica support if you have a support account. +For issues (e.g. security related issues) that are **not suitable** to be reported publicly on the GitHub issue system, report your issues to [Vertica open source team](mailto:vertica-opensrc@opentext.com) directly or file a case with Vertica support if you have a support account. 
# Feature Requests diff --git a/README.md b/README.md index 021822ea..0e175a07 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # vertica-python -[![PyPI version](https://badge.fury.io/py/vertica-python.svg)](https://badge.fury.io/py/vertica-python) +[![PyPI version](https://img.shields.io/pypi/v/vertica-python?color=brightgreen&label=PyPI%20package)](https://pypi.org/project/vertica-python/) +[![Conda Version](https://img.shields.io/conda/vn/conda-forge/vertica-python?color=yellowgreen)](https://anaconda.org/conda-forge/vertica-python) [![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://opensource.org/licenses/Apache-2.0) [![Python Version](https://img.shields.io/pypi/pyversions/vertica-python.svg)](https://www.python.org/downloads/) [![Downloads](https://pepy.tech/badge/vertica-python/week)](https://pepy.tech/project/vertica-python) @@ -11,7 +12,7 @@ Please check out [release notes](https://github.com/vertica/vertica-python/releases) to learn about the latest improvements. -vertica-python has been tested with Vertica 12.0.4 and Python 3.7/3.8/3.9/3.10/3.11. Feel free to submit issues and/or pull requests (Read up on our [contributing guidelines](#contributing-guidelines)). +vertica-python has been tested with Vertica 23.3.0 and Python 3.7/3.8/3.9/3.10/3.11. Feel free to submit issues and/or pull requests (Read up on our [contributing guidelines](#contributing-guidelines)). ## Installation @@ -102,12 +103,14 @@ with vertica_python.connect(**conn_info) as connection: | kerberos_service_name | See [Kerberos Authentication](#kerberos-authentication).
**_Default_**: "vertica" | | log_level | See [Logging](#logging). | | log_path | See [Logging](#logging). | +| oauth_access_token | To authenticate via OAuth, provide an OAuth Access Token that authorizes a user to the database.
**_Default_**: "" | +| request_complex_types | See [SQL Data conversion to Python objects](#sql-data-conversion-to-python-objects).
**_Default_**: True | | session_label | Sets a label for the connection on the server. This value appears in the client_label column of the _v_monitor.sessions_ system table.
**_Default_**: an auto-generated label with format of `vertica-python-{version}-{random_uuid}` | | ssl | See [TLS/SSL](#tlsssl).
**_Default_**: False (disabled) | | unicode_error | See [UTF-8 encoding issues](#utf-8-encoding-issues).
**_Default_**: 'strict' (throw error on invalid UTF-8 results) | | use_prepared_statements | See [Passing parameters to SQL queries](#passing-parameters-to-sql-queries).
**_Default_**: False | +| workload | Sets the workload name associated with this session. Valid values are workload names that already exist in a workload routing rule on the server. If the given workload name does not exist, the server rejects it and the session's workload is set to the default.
**_Default_**: "" | | dsn | See [Set Properties with Connection String](#set-properties-with-connection-string). | -| request_complex_types | See [SQL Data conversion to Python objects](#sql-data-conversion-to-python-objects).
**_Default_**: True | Below are a few important connection topics you may deal with, or you can skip and jump to the next section: [Send Queries and Retrieve Results](#send-queries-and-retrieve-results) @@ -138,16 +141,65 @@ with vertica_python.connect(dsn=connection_str, **additional_info) as conn: ``` #### TLS/SSL -You can pass an `ssl.SSLContext` to `ssl` to customize the SSL connection options. For example, +You can pass `True` to `ssl` to enable TLS/SSL connection (Internally [ssl.wrap_socket(sock)](https://docs.python.org/3/library/ssl.html#ssl.wrap_socket) is called). + +```python +import vertica_python + +# [TLSMode: require] +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'user': 'some_user', + 'password': 'some_password', + 'database': 'a_database', + 'ssl': True} +connection = vertica_python.connect(**conn_info) +``` + +You can pass an `ssl.SSLContext` to `ssl` to customize the SSL connection options. Server mode TLS examples: ```python import vertica_python import ssl -ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) +# [TLSMode: require] +# Ensure connection is encrypted. +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'user': 'some_user', + 'password': 'some_password', + 'database': 'a_database', + 'ssl': ssl_context} +connection = vertica_python.connect(**conn_info) + + +# [TLSMode: verify-ca] +# Ensure connection is encrypted, and client trusts server certificate. 
+ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.verify_mode = ssl.CERT_REQUIRED +ssl_context.check_hostname = False +ssl_context.load_verify_locations(cafile='/path/to/ca_file.pem') # CA certificate used to verify server certificate + +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'user': 'some_user', + 'password': 'some_password', + 'database': 'a_database', + 'ssl': ssl_context} +connection = vertica_python.connect(**conn_info) + + +# [TLSMode: verify-full] +# Ensure connection is encrypted, client trusts server certificate, +# and server hostname matches the one listed in the server certificate. +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.verify_mode = ssl.CERT_REQUIRED ssl_context.check_hostname = True -ssl_context.load_verify_locations(cafile='/path/to/ca_file.pem') +ssl_context.load_verify_locations(cafile='/path/to/ca_file.pem') # CA certificate used to verify server certificate conn_info = {'host': '127.0.0.1', 'port': 5433, @@ -156,7 +208,32 @@ conn_info = {'host': '127.0.0.1', 'database': 'a_database', 'ssl': ssl_context} connection = vertica_python.connect(**conn_info) +``` +Mutual mode TLS example: +```python +import vertica_python +import ssl + +# [TLSMode: verify-full] +# Ensure connection is encrypted, client trusts server certificate, +# and server hostname matches the one listed in the server certificate. +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.verify_mode = ssl.CERT_REQUIRED +ssl_context.check_hostname = True +ssl_context.load_verify_locations(cafile='/path/to/ca_file.pem') # CA certificate used to verify server certificate + +# For Mutual mode, provide client certificate and client private key to ssl_context. +# CA certificate used to verify client certificate should be set at the server side. 
+ssl_context.load_cert_chain(certfile='/path/to/client.pem', keyfile='/path/to/client.key') + +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'user': 'some_user', + 'password': 'some_password', + 'database': 'a_database', + 'ssl': ssl_context} +connection = vertica_python.connect(**conn_info) ``` See more on SSL options [here](https://docs.python.org/3/library/ssl.html). @@ -576,6 +653,10 @@ cur.execute("INSERT INTO table VALUES (%s, %s)", [100, value], use_prepared_stat cur.execute("INSERT INTO table VALUES (%s, %s::ARRAY[DATE])", [100, value], use_prepared_statements=False) # correct # converted into a SQL command: INSERT INTO vptest VALUES (100, ARRAY['2021-06-10','2021-06-12','2021-06-30']::ARRAY[DATE]) + +# Client-side binding of cursor.executemany is different from cursor.execute internally +# But it also supports some of complex types mapping +cur.executemany("INSERT INTO table (a, b) VALUES (%s, %s)", [[100, value]], use_prepared_statements=False) ``` ##### Register new SQL literal adapters @@ -746,6 +827,13 @@ with vertica_python.connect(**conn_info) as connection: print("Rows loaded 1:", cur.fetchall()) cur.nextset() print("Rows loaded 2:", cur.fetchall()) + + # Copy from local stdin (StringIO) + from io import StringIO + data = "Anna|123-456-789\nBrown|555-444-3333\nCindy|555-867-53093453453\nDodd|123-456-789\nEd|123-456-789" + cur.execute("COPY customers (firstNames, phoneNumbers) FROM LOCAL STDIN ENFORCELENGTH RETURNREJECTED AUTO", + copy_stdin=StringIO(data)) + ``` When connection option `disable_copy_local` set to True, disables COPY LOCAL operations, including copying data from local files/stdin and using local files to store data and exceptions. You can use this property to prevent users from writing to and copying from files on a Vertica host, including an MC host. Note that this property doesn't apply to `Cursor.copy()`. 
diff --git a/setup.py b/setup.py index 9c5d84ed..46b1b9f8 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ # version should use the format 'x.x.x' (instead of 'vx.x.x') setup( name='vertica-python', - version='1.3.2', + version='1.3.3', description='Official native Python client for the Vertica database.', long_description="vertica-python is the official Vertica database client for the Python programming language. Please check the [project homepage](https://github.com/vertica/vertica-python) for the details.", long_description_content_type='text/markdown', diff --git a/vertica_python/__init__.py b/vertica_python/__init__.py index 44a9e75e..e4277e3f 100644 --- a/vertica_python/__init__.py +++ b/vertica_python/__init__.py @@ -56,11 +56,11 @@ 'OperationalError', 'ProgrammingError'] # The version number of this library. -version_info = (1, 3, 2) +version_info = (1, 3, 3) __version__ = '.'.join(map(str, version_info)) -# The protocol version (3.14) implemented in this library. -PROTOCOL_VERSION = 3 << 16 | 14 +# The protocol version (3.15) implemented in this library. +PROTOCOL_VERSION = 3 << 16 | 15 apilevel = 2.0 threadsafety = 1 # Threads may share the module, but not connections! 
diff --git a/vertica_python/tests/common/base.py b/vertica_python/tests/common/base.py index a5b978dd..597c03c3 100644 --- a/vertica_python/tests/common/base.py +++ b/vertica_python/tests/common/base.py @@ -54,6 +54,9 @@ 'port': 5433, 'user': getpass.getuser(), 'password': '', + 'database': '', + 'oauth_access_token': '', + 'oauth_user': '', } @@ -68,8 +71,7 @@ def _load_test_config(cls, config_list): # load default configurations for key in config_list: - if key != 'database': - test_config[key] = default_configs[key] + test_config[key] = default_configs[key] # override with the configuration file confparser = ConfigParser() @@ -94,8 +96,6 @@ def _load_test_config(cls, config_list): # value is string when loaded from configuration file and environment variable if 'port' in test_config: test_config['port'] = int(test_config['port']) - if 'database' in config_list and 'user' in test_config: - test_config.setdefault('database', test_config['user']) if 'log_level' in test_config: levels = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] if isinstance(test_config['log_level'], str): diff --git a/vertica_python/tests/common/vp_test.conf.example b/vertica_python/tests/common/vp_test.conf.example index 285c8c53..9ca6d174 100644 --- a/vertica_python/tests/common/vp_test.conf.example +++ b/vertica_python/tests/common/vp_test.conf.example @@ -14,3 +14,8 @@ VP_TEST_USER=dbadmin # Valid VP_TEST_LOG_LEVEL options: 'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' VP_TEST_LOG_LEVEL=DEBUG VP_TEST_LOG_DIR=mylog/vp_tox_tests_log + +# OAuth authentication information +#VP_TEST_OAUTH_USER= +#VP_TEST_OAUTH_ACCESS_TOKEN=****** + diff --git a/vertica_python/tests/integration_tests/base.py b/vertica_python/tests/integration_tests/base.py index 7cf25328..72cd39a3 100644 --- a/vertica_python/tests/integration_tests/base.py +++ b/vertica_python/tests/integration_tests/base.py @@ -54,7 +54,8 @@ class VerticaPythonIntegrationTestCase(VerticaPythonTestCase): @classmethod def 
setUpClass(cls): config_list = ['log_dir', 'log_level', 'host', 'port', - 'user', 'password', 'database'] + 'user', 'password', 'database', + 'oauth_access_token', 'oauth_user',] cls.test_config = cls._load_test_config(config_list) # Test logger @@ -73,6 +74,10 @@ def setUpClass(cls): 'log_level': cls.test_config['log_level'], 'log_path': logfile, } + cls._oauth_info = { + 'access_token': cls.test_config['oauth_access_token'], + 'user': cls.test_config['oauth_user'], + } cls.db_node_num = cls._get_node_num() cls.logger.info("Number of database node(s) = {}".format(cls.db_node_num)) diff --git a/vertica_python/tests/integration_tests/test_authentication.py b/vertica_python/tests/integration_tests/test_authentication.py index 4aca45ce..c63f05e0 100644 --- a/vertica_python/tests/integration_tests/test_authentication.py +++ b/vertica_python/tests/integration_tests/test_authentication.py @@ -26,6 +26,8 @@ def setUp(self): def tearDown(self): self._conn_info['user'] = self._user self._conn_info['password'] = self._password + if 'oauth_access_token' in self._conn_info: + del self._conn_info['oauth_access_token'] super(AuthenticationTestCase, self).tearDown() def test_SHA512(self): @@ -107,5 +109,18 @@ def test_password_expire(self): cur.execute("DROP AUTHENTICATION IF EXISTS testIPv6hostHash CASCADE") cur.execute("DROP AUTHENTICATION IF EXISTS testlocalHash CASCADE") + def test_oauth(self): + self.require_protocol_at_least(3 << 16 | 11) + if not self._oauth_info['access_token']: + self.skipTest('OAuth not set') + + self._conn_info['user'] = self._oauth_info['user'] + self._conn_info['oauth_access_token'] = self._oauth_info['access_token'] + with self._connect() as conn: + cur = conn.cursor() + cur.execute("SELECT authentication_method FROM sessions WHERE session_id=(SELECT current_session())") + res = cur.fetchone() + self.assertEqual(res[0], 'OAuth') + exec(AuthenticationTestCase.createPrepStmtClass()) diff --git 
a/vertica_python/tests/integration_tests/test_connection.py b/vertica_python/tests/integration_tests/test_connection.py index a286b6ac..de23069f 100644 --- a/vertica_python/tests/integration_tests/test_connection.py +++ b/vertica_python/tests/integration_tests/test_connection.py @@ -36,6 +36,7 @@ from __future__ import print_function, division, absolute_import import getpass +import socket import uuid from .base import VerticaPythonIntegrationTestCase @@ -48,9 +49,14 @@ def tearDown(self): del self._conn_info['session_label'] if 'autocommit' in self._conn_info: del self._conn_info['autocommit'] + if 'workload' in self._conn_info: + del self._conn_info['workload'] def test_client_os_user_name_metadata(self): - value = getpass.getuser() + try: + value = getpass.getuser() + except Exception as e: + value = '' # Metadata client_os_user_name sent from client should be captured into system tables query = 'SELECT client_os_user_name FROM v_monitor.current_session' @@ -69,6 +75,30 @@ def test_client_os_user_name_metadata(self): res = self._query_and_fetchone(query) self.assertEqual(res[0], value) + def test_client_os_hostname_metadata(self): + self.require_protocol_at_least(3 << 16 | 14) + try: + value = socket.gethostname() + except Exception as e: + value = '' + + # Metadata client_os_hostname sent from client should be captured into system tables + query = 'SELECT client_os_hostname FROM v_monitor.current_session' + res = self._query_and_fetchone(query) + self.assertEqual(res[0], value) + + query = 'SELECT client_os_hostname FROM v_monitor.sessions WHERE session_id=(SELECT current_session())' + res = self._query_and_fetchone(query) + self.assertEqual(res[0], value) + + query = 'SELECT client_os_hostname FROM v_monitor.user_sessions WHERE session_id=(SELECT current_session())' + res = self._query_and_fetchone(query) + self.assertEqual(res[0], value) + + query = 'SELECT client_os_hostname FROM v_internal.dc_session_starts WHERE session_id=(SELECT current_session())' + res 
= self._query_and_fetchone(query) + self.assertEqual(res[0], value) + def test_session_label(self): label = str(uuid.uuid1()) self._conn_info['session_label'] = label @@ -106,5 +136,26 @@ def test_autocommit_off(self): # Set with attribute setter conn.autocommit = True self.assertTrue(conn.autocommit) + + def test_workload_default(self): + self.require_protocol_at_least(3 << 16 | 15) + with self._connect() as conn: + query = "SHOW WORKLOAD" + res = self._query_and_fetchone(query) + self.assertEqual(res[1], '') + + def test_workload_set_property(self): + self.require_protocol_at_least(3 << 16 | 15) + self._conn_info['workload'] = 'python_test_workload' + with self._connect() as conn: + # we use dc_client_server_messages to test that the client is working properly. + # We do not regularly test on a multi subcluster database and the server will reject this + # workload from the startup packet, returning a parameter status message with an empty string. + query = ("SELECT contents FROM dc_client_server_messages" + " WHERE session_id = current_session()" + " AND message_type = '^+'" + " AND contents LIKE '%workload%'") + res = self._query_and_fetchone(query) + self.assertEqual(res[0], 'workload: python_test_workload') exec(ConnectionTestCase.createPrepStmtClass()) diff --git a/vertica_python/tests/integration_tests/test_datatypes.py b/vertica_python/tests/integration_tests/test_datatypes.py index 9f7fb43c..8fd6aad6 100644 --- a/vertica_python/tests/integration_tests/test_datatypes.py +++ b/vertica_python/tests/integration_tests/test_datatypes.py @@ -90,108 +90,119 @@ def tearDown(self): cur.execute(f"DROP TABLE IF EXISTS {self._table}") super(InsertComplexTypeTestCase, self).tearDown() - def _test_insert_complex_type(self, col_type, values, expected=None): + def _test_insert_complex_type(self, col_type, values, expected=None, test_executemany=False): if expected is None: expected = values with self._connect() as conn: cur = conn.cursor() cur.execute(f"DROP TABLE IF 
EXISTS {self._table}") cur.execute(f"CREATE TABLE {self._table} (a INT, b {col_type})") - a = 1 - for value in values: + seq_of_values = [(i, values[i]) for i in range(len(values))] + for value in seq_of_values: # Some cases need explicit typecasting - cur.execute(f"INSERT INTO {self._table} (a, b) VALUES (%s, %s::{col_type})", [a, value], use_prepared_statements=False) - a += 1 + cur.execute(f"INSERT INTO {self._table} (a, b) VALUES (%s, %s::{col_type})", value, use_prepared_statements=False) rows = cur.execute(f"SELECT b FROM {self._table} ORDER BY a").fetchall() results = [row[0] for row in rows] self.assertEqual(results, expected) + if not test_executemany: + return + # test cursor.executemany + cur.execute(f"TRUNCATE TABLE {self._table}") + cur.executemany(f"INSERT INTO {self._table} (a, b) VALUES (%s, %s)", seq_of_values, use_prepared_statements=False) + rows = cur.execute(f"SELECT b FROM {self._table} ORDER BY a").fetchall() + results = [row[0] for row in rows] + self.assertEqual(results, expected) + + ####################### # tests for ARRAY type ####################### def test_Array_boolean_type(self): - self._test_insert_complex_type('ARRAY[BOOL]', [[True, False, None], None, [], [None]]) + self._test_insert_complex_type('ARRAY[BOOL]', [[True, False, None], None, [], [None]], test_executemany=True) def test_Array_integer_type(self): - self._test_insert_complex_type('ARRAY[INT]', [[1,-2,3], [4,None,5], None, [], [None]]) + self._test_insert_complex_type('ARRAY[INT]', [[1,-2,3], [4,None,5], None, [], [None]], test_executemany=True) self._test_insert_complex_type('ARRAY[ARRAY[INT]]', [[[1,2], [3,4], None, [5,None], []], None, [], [None]]) self._test_insert_complex_type('ARRAY[ARRAY[ARRAY[ARRAY[INT]]]]', [[[[None,[1,2,3],None,[1,None,3],[None,None,None],[4,5],[],None]]], None, [], [None]]) def test_Array_float_type(self): - self._test_insert_complex_type('ARRAY[FLOAT]', [[1.23456e-18,float('Inf'),float('-Inf'),None,-1.234,0.0], None, [], [None]]) + 
self._test_insert_complex_type('ARRAY[FLOAT]', [[1.23456e-18,float('Inf'),float('-Inf'),None,-1.234,0.0], None, [], [None]], test_executemany=True) def test_Array_numeric_type(self): self._test_insert_complex_type('ARRAY[NUMERIC]', [[Decimal('-1.1200000000'), Decimal('0E-10'), None, Decimal('1234567890123456789.0123456789')], - None, [], [None]]) + None, [], [None]], test_executemany=True) def test_Array_char_type(self): - self._test_insert_complex_type('ARRAY[CHAR(3)]', [['a', u'\u16b1', None, 'foo'], None, [], [None]], [['a ', u'\u16b1', None, 'foo'], None, [], [None]]) + self._test_insert_complex_type('ARRAY[CHAR(3)]', [['a', u'\u16b1', None, 'foo'], None, [], [None]], [['a ', u'\u16b1', None, 'foo'], None, [], [None]], test_executemany=True) def test_Array_varchar_type(self): - self._test_insert_complex_type('ARRAY[VARCHAR(10)]', [['', u'\u16b1\nb', None, 'foo'], None, [], [None]]) + self._test_insert_complex_type('ARRAY[VARCHAR(10)]', [['', u'\u16b1\nb', None, 'foo'], None, [], [None]], test_executemany=True) + self._test_insert_complex_type('ARRAY[VARCHAR]', [[chr(i)] for i in range(1, 128)], test_executemany=True) def test_Array_date_type(self): - self._test_insert_complex_type('ARRAY[DATE]', [[date(2021, 6, 10),None,date(221, 5, 2)], None, [], [None]]) + self._test_insert_complex_type('ARRAY[DATE]', [[date(2021, 6, 10),None,date(221, 5, 2)], None, [], [None]], test_executemany=True) def test_Array_time_type(self): - self._test_insert_complex_type('ARRAY[TIME(3)]', [[time(0, 0, 0),None,time(22, 36, 33, 124000)], None, [], [None]]) + self._test_insert_complex_type('ARRAY[TIME(3)]', [[time(0, 0, 0),None,time(22, 36, 33, 124000)], None, [], [None]], test_executemany=True) def test_Array_timetz_type(self): self._test_insert_complex_type('ARRAY[TIMETZ(3)]', [[time(22, 36, 33, 123000, tzinfo=tzoffset(None, 23400)),None, - time(22, 36, 33, 123000, tzinfo=tzoffset(None, -10800))], None, [], [None]]) + time(22, 36, 33, 123000, tzinfo=tzoffset(None, -10800))], None, 
[], [None]], test_executemany=True) def test_Array_timestamp_type(self): - self._test_insert_complex_type('ARRAY[TIMESTAMP]', [[datetime(276, 12, 1, 11, 22, 33),None,datetime(2001, 12, 1, 0, 30, 45, 87000)], None, [], [None]]) + self._test_insert_complex_type('ARRAY[TIMESTAMP]', [[datetime(276, 12, 1, 11, 22, 33),None,datetime(2001, 12, 1, 0, 30, 45, 87000)], None, [], [None]], test_executemany=True) def test_Array_timestamptz_type(self): - self._test_insert_complex_type('ARRAY[TIMESTAMPTZ]', [[datetime(276, 11, 30, 23, 32, 57, tzinfo=tzoffset(None, 3600)),None,datetime(2001, 12, 1, 0, 30, 45, 87000, tzinfo=tzoffset(None, -18000))], None, [], [None]]) + self._test_insert_complex_type('ARRAY[TIMESTAMPTZ]', [[datetime(276, 11, 30, 23, 32, 57, tzinfo=tzoffset(None, 3600)),None,datetime(2001, 12, 1, 0, 30, 45, 87000, tzinfo=tzoffset(None, -18000))], None, [], [None]], test_executemany=True) def test_Array_UUID_type(self): - self._test_insert_complex_type('ARRAY[UUID]', [[UUID('00010203-0405-0607-0809-0a0b0c0d0e0f'),None,UUID('123e4567-e89b-12d3-a456-426655440a00')], None, [], [None]]) + self._test_insert_complex_type('ARRAY[UUID]', [[UUID('00010203-0405-0607-0809-0a0b0c0d0e0f'),None,UUID('123e4567-e89b-12d3-a456-426655440a00')], None, [], [None]], test_executemany=True) ##################### # tests for SET type ##################### def test_1DSet_boolean_type(self): - self._test_insert_complex_type('SET[BOOL]', [{True, False, None}, None, set(), {None}]) + self._test_insert_complex_type('SET[BOOL]', [{True, False, None}, None, set(), {None}], test_executemany=True) def test_1DSet_integer_type(self): - self._test_insert_complex_type('SET[INT]', [{0, 1, -2, 3, None}, None, set(), {None}]) + self._test_insert_complex_type('SET[INT]', [{0, 1, -2, 3, None}, None, set(), {None}], test_executemany=True) def test_1DSet_float_type(self): - self._test_insert_complex_type('SET[FLOAT]', [{float('Inf'), float('-Inf'), None, -1.234, 0.0, 1.23456e-18}, None, set(), {None}]) + 
self._test_insert_complex_type('SET[FLOAT]', [{float('Inf'), float('-Inf'), None, -1.234, 0.0, 1.23456e-18}, None, set(), {None}], test_executemany=True) def test_1DSet_numeric_type(self): self._test_insert_complex_type('SET[NUMERIC]', [{Decimal('-1.12'), Decimal('0E-15'), None, Decimal('1234567890123456789.0123456789')}, - None, set(), {None}]) + None, set(), {None}], test_executemany=True) def test_1DSet_char_type(self): - self._test_insert_complex_type('SET[CHAR(3)]', [{'a ', u'\u16b1', None, 'foo'}, None, set(), {None}]) + self._test_insert_complex_type('SET[CHAR(3)]', [{'a ', u'\u16b1', None, 'foo'}, None, set(), {None}], test_executemany=True) def test_1DSet_varchar_type(self): - self._test_insert_complex_type('SET[VARCHAR(10)]', [{'', u'\u16b1\nb', None, 'foo'}, None, set(), {None}]) + self._test_insert_complex_type('SET[VARCHAR(10)]', [{'', u'\u16b1\nb', None, 'foo'}, None, set(), {None}], test_executemany=True) + self._test_insert_complex_type('SET[VARCHAR]', [{chr(i)} for i in range(1, 128)], test_executemany=True) def test_1DSet_date_type(self): - self._test_insert_complex_type('SET[DATE]', [{date(2021, 6, 10), None, date(221, 5, 2)}, None, set(), {None}]) + self._test_insert_complex_type('SET[DATE]', [{date(2021, 6, 10), None, date(221, 5, 2)}, None, set(), {None}], test_executemany=True) def test_1DSet_time_type(self): - self._test_insert_complex_type('SET[TIME(3)]', [{time(0, 0, 0), None, time(22, 36, 33, 124000)}, None, set(), {None}]) + self._test_insert_complex_type('SET[TIME(3)]', [{time(0, 0, 0), None, time(22, 36, 33, 124000)}, None, set(), {None}], test_executemany=True) def test_1DSet_timetz_type(self): self._test_insert_complex_type('SET[TIMETZ(3)]', [{time(22, 36, 33, 123000, tzinfo=tzoffset(None, 23400)),None, - time(22, 36, 33, 123000, tzinfo=tzoffset(None, -10800))}, None, set(), {None}]) + time(22, 36, 33, 123000, tzinfo=tzoffset(None, -10800))}, None, set(), {None}], test_executemany=True) def test_1DSet_timestamp_type(self): - 
self._test_insert_complex_type('SET[TIMESTAMP]', [{datetime(276, 12, 1, 11, 22, 33),None,datetime(2001, 12, 1, 0, 30, 45, 87000)}, None, set(), {None}]) + self._test_insert_complex_type('SET[TIMESTAMP]', [{datetime(276, 12, 1, 11, 22, 33),None,datetime(2001, 12, 1, 0, 30, 45, 87000)}, None, set(), {None}], test_executemany=True) def test_1DSet_timestamptz_type(self): self._test_insert_complex_type('SET[TIMESTAMPTZ]', [{datetime(276, 11, 30, 23, 32, 57, tzinfo=tzoffset(None, 3600)),None, - datetime(2001, 12, 1, 0, 30, 45, 87000, tzinfo=tzoffset(None, -18000))}, None, set(), {None}]) + datetime(2001, 12, 1, 0, 30, 45, 87000, tzinfo=tzoffset(None, -18000))}, None, set(), {None}], test_executemany=True) def test_1DSet_UUID_type(self): - self._test_insert_complex_type('SET[UUID]', [{UUID('00010203-0405-0607-0809-0a0b0c0d0e0f'),None,UUID('123e4567-e89b-12d3-a456-426655440a00')}, None, set(), {None}]) + self._test_insert_complex_type('SET[UUID]', [{UUID('00010203-0405-0607-0809-0a0b0c0d0e0f'),None,UUID('123e4567-e89b-12d3-a456-426655440a00')}, None, set(), {None}], test_executemany=True) ##################### # tests for ROW type diff --git a/vertica_python/tests/integration_tests/test_tls.py b/vertica_python/tests/integration_tests/test_tls.py new file mode 100644 index 00000000..2a0bb4af --- /dev/null +++ b/vertica_python/tests/integration_tests/test_tls.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023 Open Text. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function, division, absolute_import + +import os +import socket +import ssl +from tempfile import NamedTemporaryFile + +from ... import errors +from .base import VerticaPythonIntegrationTestCase + + +class TlsTestCase(VerticaPythonIntegrationTestCase): + def tearDown(self): + if 'ssl' in self._conn_info: + del self._conn_info['ssl'] + with self._connect() as conn: + cur = conn.cursor() + cur.execute("ALTER TLS CONFIGURATION server CERTIFICATE NULL TLSMODE 'DISABLE'") + if hasattr(self, 'client_cert'): + os.remove(self.client_cert.name) + cur.execute("ALTER TLS CONFIGURATION server REMOVE CA CERTIFICATES vp_CA_cert") + if hasattr(self, 'client_key'): + os.remove(self.client_key.name) + cur.execute("DROP KEY IF EXISTS vp_client_key CASCADE") + cur.execute("DROP KEY IF EXISTS vp_server_key CASCADE") + cur.execute("DROP KEY IF EXISTS vp_CA_key CASCADE") + super(TlsTestCase, self).tearDown() + + def _generate_and_set_certificates(self, mutual_mode=False): + with self._connect() as conn: + cur = conn.cursor() + + # Generate a root CA private key + cur.execute("CREATE KEY vp_CA_key TYPE 'RSA' LENGTH 4096") + # Generate a root CA certificate + cur.execute("CREATE CA CERTIFICATE vp_CA_cert " + "SUBJECT '/C=US/ST=Massachusetts/L=Burlington/O=OpenText/OU=Vertica/CN=Vertica Root CA' " + "VALID FOR 3650 EXTENSIONS 'nsComment' = 'Self-signed root CA cert' KEY vp_CA_key") + cur.execute("SELECT certificate_text FROM CERTIFICATES WHERE name='vp_CA_cert'") + vp_CA_cert = cur.fetchone()[0] + + # Generate a server private key + cur.execute("CREATE KEY vp_server_key TYPE 'RSA' LENGTH 4096") + # Generate a server certificate + host = self._conn_info['host'] + hostname_for_verify = ('IP:' if host.count('.') == 3 else 'DNS:') + host + cur.execute("CREATE CERTIFICATE vp_server_cert " + "SUBJECT '/C=US/ST=MA/L=Cambridge/O=Foo/OU=Vertica/CN=Vertica server/emailAddress=abc@example.com' " + "SIGNED BY vp_CA_cert EXTENSIONS 'nsComment' = 'Vertica server cert', 
'extendedKeyUsage' = 'serverAuth', " + f"'subjectAltName' = '{hostname_for_verify}' KEY vp_server_key") + + if mutual_mode: + # Generate a client private key + cur.execute("CREATE KEY vp_client_key TYPE 'RSA' LENGTH 4096") + cur.execute("SELECT key FROM CRYPTOGRAPHIC_KEYS WHERE name='vp_client_key'") + vp_client_key = cur.fetchone()[0] + with NamedTemporaryFile(delete=False) as self.client_key: + self.client_key.write(vp_client_key.encode()) + # Generate a client certificate + cur.execute("CREATE CERTIFICATE vp_client_cert " + "SUBJECT '/C=US/ST=MA/L=Boston/O=Bar/OU=Vertica/CN=Vertica client/emailAddress=def@example.com' " + "SIGNED BY vp_CA_cert EXTENSIONS 'nsComment' = 'Vertica client cert', 'extendedKeyUsage' = 'clientAuth' " + "KEY vp_client_key") + cur.execute("SELECT certificate_text FROM CERTIFICATES WHERE name='vp_client_cert'") + vp_client_cert = cur.fetchone()[0] + with NamedTemporaryFile(delete=False) as self.client_cert: + self.client_cert.write(vp_client_cert.encode()) + + # In order to use Mutual Mode, set a server and CA certificate. + # This CA certificate is used to verify client certificates + cur.execute('ALTER TLS CONFIGURATION server CERTIFICATE vp_server_cert ADD CA CERTIFICATES vp_CA_cert') + # Enable TLS. Connection succeeds if Vertica verifies that the client certificate is from a trusted CA. + # If the client does not present a client certificate, the connection uses plaintext. + cur.execute("ALTER TLS CONFIGURATION server TLSMODE 'VERIFY_CA'") + + else: + # In order to use Server Mode, set the server certificate for the server's TLS Configuration + cur.execute('ALTER TLS CONFIGURATION server CERTIFICATE vp_server_cert') + # Enable TLS. Server does not check client certificates. 
+ cur.execute("ALTER TLS CONFIGURATION server TLSMODE 'ENABLE'") + + # For debug + # SELECT * FROM tls_configurations WHERE name='server'; + # SELECT * FROM CRYPTOGRAPHIC_KEYS; + # SELECT * FROM CERTIFICATES; + + return vp_CA_cert + + + def test_TLSMode_disable(self): + self._conn_info['ssl'] = False + with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'None') + + def test_TLSMode_require_server_disable(self): + # Requires that the server use TLS. If the TLS connection attempt fails, the client rejects the connection. + self._conn_info['ssl'] = True + self.assertConnectionFail(err_type=errors.SSLNotSupported, + err_msg='SSL requested but not supported by server') + + def test_TLSMode_require(self): + # Setting certificates with TLS configuration + self._generate_and_set_certificates() + + # Option 1 + self._conn_info['ssl'] = True + with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'Server') + + # Option 2 + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + self._conn_info['ssl'] = ssl_context + with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'Server') + + def test_TLSMode_verify_ca(self): + # Setting certificates with TLS configuration + CA_cert = self._generate_and_set_certificates() + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = False + ssl_context.load_verify_locations(cadata=CA_cert) # CA certificate used to verify server certificate + self._conn_info['ssl'] = ssl_context + + 
with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'Server') + + def test_TLSMode_verify_full(self): + # Setting certificates with TLS configuration + CA_cert = self._generate_and_set_certificates() + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True # hostname in server cert's subjectAltName + ssl_context.load_verify_locations(cadata=CA_cert) # CA certificate used to verify server certificate + + self._conn_info['ssl'] = ssl_context + with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'Server') + + def test_mutual_TLS(self): + # Setting certificates with TLS configuration + CA_cert = self._generate_and_set_certificates(mutual_mode=True) + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True # hostname in server cert's subjectAltName + ssl_context.load_verify_locations(cadata=CA_cert) # CA certificate used to verify server certificate + ssl_context.load_cert_chain(certfile=self.client_cert.name, keyfile=self.client_key.name) + + self._conn_info['ssl'] = ssl_context + with self._connect() as conn: + cur = conn.cursor() + res = self._query_and_fetchone('SELECT ssl_state FROM sessions WHERE session_id=(SELECT current_session())') + self.assertEqual(res[0], 'Mutual') + + +exec(TlsTestCase.createPrepStmtClass()) diff --git a/vertica_python/tests/unit_tests/test_parsedsn.py b/vertica_python/tests/unit_tests/test_parsedsn.py index b344a994..28e84ae3 100644 --- a/vertica_python/tests/unit_tests/test_parsedsn.py +++ b/vertica_python/tests/unit_tests/test_parsedsn.py @@ -41,11 +41,15 @@ def test_str_arguments(self): dsn = 
('vertica://john:pwd@localhost:5433/db1?' 'session_label=vpclient&unicode_error=strict&' 'log_path=/home/admin/vClient.log&log_level=DEBUG&' + 'oauth_access_token=GciOiJSUzI1NiI&' + 'workload=python_test_workload&' 'kerberos_service_name=krb_service&kerberos_host_name=krb_host') expected = {'database': 'db1', 'host': 'localhost', 'user': 'john', 'password': 'pwd', 'port': 5433, 'log_level': 'DEBUG', 'session_label': 'vpclient', 'unicode_error': 'strict', - 'log_path': '/home/admin/vClient.log', + 'log_path': '/home/admin/vClient.log', + 'oauth_access_token': 'GciOiJSUzI1NiI', + 'workload': 'python_test_workload', 'kerberos_service_name': 'krb_service', 'kerberos_host_name': 'krb_host'} parsed = parse_dsn(dsn) diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 9b047a53..ed41fb40 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -73,6 +73,8 @@ DEFAULT_LOG_PATH = 'vertica_python.log' DEFAULT_BINARY_TRANSFER = False DEFAULT_REQUEST_COMPLEX_TYPES = True +DEFAULT_OAUTH_ACCESS_TOKEN = '' +DEFAULT_WORKLOAD = '' try: DEFAULT_USER = getpass.getuser() except Exception as e: @@ -288,9 +290,11 @@ def __init__(self, options=None): raise KeyError(msg) self.options.setdefault('database', DEFAULT_DATABASE) self.options.setdefault('password', DEFAULT_PASSWORD) + self.options.setdefault('oauth_access_token', DEFAULT_OAUTH_ACCESS_TOKEN) self.options.setdefault('autocommit', DEFAULT_AUTOCOMMIT) self.options.setdefault('session_label', _generate_session_label()) self.options.setdefault('backup_server_node', DEFAULT_BACKUP_SERVER_NODE) + self.options.setdefault('workload', DEFAULT_WORKLOAD) self.options.setdefault('kerberos_service_name', DEFAULT_KRB_SERVICE_NAME) # Kerberos authentication hostname defaults to the host value here so # the correct value cannot be overwritten by load balancing or failover @@ -330,8 +334,9 @@ def __init__(self, options=None): self.startup_connection() # Complex types 
metadata is returned since protocol version 3.12 - self.complex_types_enabled = self.parameters['protocol_version'] >= (3 << 16 | 12) and \ + self.complex_types_enabled = self.parameters.get('protocol_version', 0) >= (3 << 16 | 12) and \ self.parameters.get('request_complex_types', 'off') == 'on' + self._logger.info('Connection is ready') ############################################# @@ -823,8 +828,11 @@ def startup_connection(self): autocommit = self.options['autocommit'] binary_transfer = self.options['binary_transfer'] request_complex_types = self.options['request_complex_types'] + oauth_access_token = self.options['oauth_access_token'] + workload = self.options['workload'] - self.write(messages.Startup(user, database, session_label, os_user_name, autocommit, binary_transfer, request_complex_types)) + self.write(messages.Startup(user, database, session_label, os_user_name, autocommit, binary_transfer, + request_complex_types, oauth_access_token, workload)) while True: message = self.read_message() @@ -843,6 +851,8 @@ def startup_connection(self): self._logger.warning(password_grace) elif message.code == messages.Authentication.GSS: self.make_GSS_authentication() + elif message.code == messages.Authentication.OAUTH: + self.write(messages.Password(oauth_access_token, message.code)) else: self.write(messages.Password(password, message.code, {'user': user, diff --git a/vertica_python/vertica/cursor.py b/vertica_python/vertica/cursor.py index 870c0f15..eccef2f7 100644 --- a/vertica_python/vertica/cursor.py +++ b/vertica_python/vertica/cursor.py @@ -303,13 +303,14 @@ def executemany(self, operation, seq_of_parameters, use_prepared_statements=None variables = ",".join([variable.strip().strip('"') for variable in variables.split(",")]) values = as_text(m.group('values')) - values = ",".join([value.strip().strip('"') for value in values.split(",")]) + values = "|".join([value.strip().strip('"') for value in values.split(",")]) seq_of_values = 
[self.format_operation_with_parameters(values, parameters, is_copy_data=True) for parameters in seq_of_parameters] data = "\n".join(seq_of_values) copy_statement = ( - u"COPY {0} ({1}) FROM STDIN DELIMITER ',' ENCLOSED BY '\"' " + u"COPY {0} ({1}) FROM STDIN " + u"ENCLOSED BY '''' " # '/r' will have trouble if ENCLOSED BY is not set u"ENFORCELENGTH ABORT ON ERROR{2}").format(target, variables, " NO COMMIT" if not self.connection.autocommit else '') @@ -590,7 +591,7 @@ def format_row_as_array(self, row_data): return [convert(value) for convert, value in zip(self._deserializers, row_data.values)] - def object_to_string(self, py_obj, is_copy_data): + def object_to_string(self, py_obj, is_copy_data, is_collection=False): """Return the SQL representation of the object as a string""" if type(py_obj) in self._sql_literal_adapters and not is_copy_data: adapter = self._sql_literal_adapters[type(py_obj)] @@ -601,15 +602,15 @@ def object_to_string(self, py_obj, is_copy_data): return as_text(result) if isinstance(py_obj, type(None)): - return '' if is_copy_data else 'NULL' + return '' if is_copy_data and not is_collection else 'NULL' elif isinstance(py_obj, bool): return str(py_obj) elif isinstance(py_obj, (str, bytes)): - return self.format_quote(as_text(py_obj), is_copy_data) + return self.format_quote(as_text(py_obj), is_copy_data, is_collection) elif isinstance(py_obj, (int, Decimal)): return str(py_obj) elif isinstance(py_obj, float): - if py_obj in (float('Inf'), float('-Inf')) or isnan(py_obj): + if not is_copy_data and py_obj in (float('Inf'), float('-Inf')) or isnan(py_obj): return f"'{str(py_obj)}'::FLOAT" return str(py_obj) elif isinstance(py_obj, tuple): # tuple and namedtuple @@ -617,20 +618,31 @@ def object_to_string(self, py_obj, is_copy_data): for i in range(len(py_obj)): elements[i] = self.object_to_string(py_obj[i], is_copy_data) return "(" + ",".join(elements) + ")" - elif isinstance(py_obj, list) and not is_copy_data: + elif isinstance(py_obj, list): 
elements = [None] * len(py_obj) - for i in range(len(py_obj)): - elements[i] = self.object_to_string(py_obj[i], False) - # Use the ARRAY keyword to construct an array value - return f'ARRAY[{",".join(elements)}]' - elif isinstance(py_obj, set) and not is_copy_data: + if is_copy_data: + for i in range(len(py_obj)): + elements[i] = self.object_to_string(py_obj[i], True, True) + return f'[{",".join(elements)}]' + else: + for i in range(len(py_obj)): + elements[i] = self.object_to_string(py_obj[i], False) + # Use the ARRAY keyword to construct an array value + return f'ARRAY[{",".join(elements)}]' + elif isinstance(py_obj, set): elements = [None] * len(py_obj) i = 0 - for o in py_obj: - elements[i] = self.object_to_string(o, False) - i += 1 - # Use the SET keyword to construct a set value - return f'SET[{",".join(elements)}]' + if is_copy_data: + for o in py_obj: + elements[i] = self.object_to_string(o, True, True) + i += 1 + return f'[{",".join(elements)}]' + else: + for o in py_obj: + elements[i] = self.object_to_string(o, False) + i += 1 + # Use the SET keyword to construct a set value + return f'SET[{",".join(elements)}]' elif isinstance(py_obj, dict) and not is_copy_data: elements = [None] * len(py_obj) i = 0 @@ -640,7 +652,7 @@ def object_to_string(self, py_obj, is_copy_data): # Use the ROW keyword to construct a row value return f'ROW({",".join(elements)})' elif isinstance(py_obj, (datetime.datetime, datetime.date, datetime.time, UUID)): - return self.format_quote(as_text(str(py_obj)), is_copy_data) + return self.format_quote(as_text(str(py_obj)), is_copy_data, is_collection) else: if is_copy_data: return str(py_obj) @@ -686,13 +698,19 @@ def format_operation_with_parameters(self, operation, parameters, is_copy_data=F return operation - def format_quote(self, param, is_copy_data): - if is_copy_data: + def format_quote(self, param, is_copy_data, is_collection): + if is_collection: # COPY COLLECTIONENCLOSE s = list(param) for i, c in enumerate(param): - if c in 
u'()[]{}?"*+-|^$\\.&~# \t\n\r\v\f': + if c in '\\\n\"': s[i] = "\\" + c return u'"{0}"'.format(u"".join(s)) + elif is_copy_data: # COPY ENCLOSED BY + s = list(param) + for i, c in enumerate(param): + if c in '\\|\n\'': + s[i] = "\\" + c + return u"'{0}'".format(u"".join(s)) else: return u"'{0}'".format(param.replace(u"'", u"''")) diff --git a/vertica_python/vertica/messages/backend_messages/command_complete.py b/vertica_python/vertica/messages/backend_messages/command_complete.py index 5127a234..405462a1 100644 --- a/vertica_python/vertica/messages/backend_messages/command_complete.py +++ b/vertica_python/vertica/messages/backend_messages/command_complete.py @@ -59,24 +59,12 @@ def __init__(self, data): try: self.command_tag = data.decode('utf-8') except Exception as e: - # (workaround for #493) something wrong in the server, hide the problem for now - warnings.warn("Hit a known server bug\n" - f"{'='*80}\n" - "We'd like to gather client-side information to help with the bug investigation.\n" - "Please leave a comment under https://github.com/vertica/vertica-python/issues/493" - " with the following info:\n" - f"{'-'*80}\n" - f"command tag length: {len(data)}\n" - f"command tag content: {data}\n" - f"{type(e).__name__}: {str(e)}\n" - "Server version: xxx\n" - "Query executed (if possible): xxx\n" - "The OS of each server node (if possible): xxx\n" - "The locale of each server node (if possible): xxx\n" - f"{'-'*80}\n" - f"We appreciate your help!\n" - f"{'='*80}\n" - ) + # VER-86494 + warnings.warn( + f"\n{'-'*70}\n" + "Hit a known server bug (#493). 
To fix it,\n" + "please upgrade your server to 12.0.4-3 or higher version.\n" + f"{'-'*70}\n") self.command_tag = 'x' def __str__(self): diff --git a/vertica_python/vertica/messages/frontend_messages/password.py b/vertica_python/vertica/messages/frontend_messages/password.py index 6b9b6504..ee248b24 100644 --- a/vertica_python/vertica/messages/frontend_messages/password.py +++ b/vertica_python/vertica/messages/frontend_messages/password.py @@ -60,7 +60,6 @@ def __init__(self, password, auth_method=None, options=None): self._auth_method = Authentication.CLEARTEXT_PASSWORD def encoded_password(self): - if self._auth_method == Authentication.CLEARTEXT_PASSWORD: return self._password elif self._auth_method == Authentication.CRYPT_PASSWORD: @@ -83,8 +82,10 @@ def encoded_password(self): return prefix + self._password elif self._auth_method == Authentication.GSS: return self._password + elif self._auth_method == Authentication.OAUTH: + return self._password else: - raise ValueError("unsupported authentication method: {0}".format(self._auth_method)) + raise ValueError(f"unsupported authentication method: {self._auth_method}") def read_bytes(self): encoded_pw = self.encoded_password() diff --git a/vertica_python/vertica/messages/frontend_messages/startup.py b/vertica_python/vertica/messages/frontend_messages/startup.py index 8cf2ea88..68f74fd1 100644 --- a/vertica_python/vertica/messages/frontend_messages/startup.py +++ b/vertica_python/vertica/messages/frontend_messages/startup.py @@ -58,7 +58,8 @@ class Startup(BulkFrontendMessage): message_id = None def __init__(self, user, database, session_label, os_user_name, autocommit, - binary_transfer, request_complex_types): + binary_transfer, request_complex_types, oauth_access_token, + workload): BulkFrontendMessage.__init__(self) try: @@ -95,8 +96,13 @@ def __init__(self, user, database, session_label, os_user_name, autocommit, b'binary_data_protocol': '1' if binary_transfer else '0', # Defaults to text format '0' 
b'protocol_features': '{"request_complex_types":' + request_complex_types + '}', b'protocol_compat': 'VER', + b'workload': workload, } + if len(oauth_access_token) > 0: + self.parameters[b'oauth_access_token'] = oauth_access_token # protocol version 3.11 + self.parameters[b'auth_category'] = 'OAuth' # protocol version 3.12+ + def read_bytes(self): # The fixed protocol version is followed by pairs of parameter name and value strings. # A zero byte is required as a terminator after the last name/value pair. diff --git a/vertica_python/vertica/messages/frontend_messages/verified_files.py b/vertica_python/vertica/messages/frontend_messages/verified_files.py index 516be4bc..486ad3fa 100644 --- a/vertica_python/vertica/messages/frontend_messages/verified_files.py +++ b/vertica_python/vertica/messages/frontend_messages/verified_files.py @@ -29,7 +29,7 @@ def __init__(self, file_list): self.filenames = file_list def read_bytes(self): - bytes_ = pack('!H', len(self.filenames)) + bytes_ = pack('!I', len(self.filenames)) for filename in self.filenames: utf_filename = filename.encode('utf-8') bytes_ += pack('!{0}sx'.format(len(utf_filename)), utf_filename)