diff --git a/.asf.yaml b/.asf.yaml new file mode 100644 index 0000000000..0bacf232d1 --- /dev/null +++ b/.asf.yaml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +notifications: + commits: commits@cassandra.apache.org + issues: commits@cassandra.apache.org + pullrequests: pr@cassandra.apache.org + jira_options: link worklog + +github: + description: "Python Driver for Apache Cassandra®" + homepage: https://docs.datastax.com/en/developer/python-driver/3.29/index.html + enabled_merge_buttons: + squash: false + merge: false + rebase: true + features: + wiki: false + issues: false + projects: false + discussions: false + autolink_jira: + - CASSANDRA + - CASSPYTHON + protected_branches: + trunk: + required_linear_history: true diff --git a/.gitignore b/.gitignore index c4fca1c5f2..7983f44b87 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,12 @@ -*.pyc +*.py[co] *.swp *.swo *.so +*.egg *.egg-info +*.attr +.tox +.python-version build MANIFEST dist @@ -10,3 +14,33 @@ dist cover/ docs/_build/ tests/integration/ccm +setuptools*.tar.gz +setuptools*.egg + +cassandra/*.c +!cassandra/cmurmur3.c +cassandra/*.html +tests/unit/cython/bytesio_testhelper.c + +# OSX +.DS_Store + +# IDE +.project +.pydevproject +.settings/ +.idea/ +*.iml + +.DS_Store + +# Unit test / coverage reports +.coverage +.tox + +#iPython +*.ipynb + +venv +docs/venv +.eggs \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c429a0811b..fbc7c07cba 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,1794 @@ +3.30.0 +====== +March 23, 2026 + +Features +-------- +* Introduce pyproject.toml to explicitly declare build dependencies (CASSPYTHON-7) +* Add Python 3.14 to CI, remove Python 3.9 (CASSPYTHON-4) +* Mark eventlet, gevent and Twisted event loops as deprecated (CASSPYTHON-12) + +Bug Fixes +--------- +* Do not set timeout to None when calling execute_async in execute_concurrent (PYTHON-1354) +* No C extension .so files in published binary Python whl packages of 3.29.3 (CASSPYTHON-3) +* Win32 wheels do not include compiled libev modules (CASSPYTHON-5) + +Others +------ +* Remove obsolete __future__ import absolute_import (PR 1263) +* Remove ez_setup for compatibility with setuptools v82 (PR 1268) +* Replace usage of with await lock (PR 1270) +* Update cassandra.util.Version to better support Cassandra version strings (CASSPYTHON-10) +* Update DRIVER_NAME after donation to ASF (CASSPYTHON-17) + +3.29.3 +====== +October 20, 2025 + +Features +-------- +* Upgraded cython to 3.0.x (PR 1221 & PYTHON-1390) +* Add support for DSE 6.9.x and HCD releases to CI (PYTHON-1402) +* Add execute_concurrent_async and expose execute_concurrent_* in Session (PR 1229) + +Bug Fixes +--------- +* Update geomet to align with requirements.txt (PR 1236) +* Connection failure to SNI endpoint when first host is unavailable (PYTHON-1419) +* Maintain compatibility with CPython 3.13 (PR 1242) + +Others +------ +* Remove duplicated condition in primary key check (PR 1240) +* Remove Python 3.8 which reached EOL on Oct 2024, update cryptography lib to 42 (PR 1247) +* Remove obsolete urllib2 from ez_setup.py (PR 1248) +* Remove stale dependency on sure (PR 1227) +* Removed 2.7 Cpython defines (PR 1252) + +3.29.2 +====== +September 9, 2024 + +Features +-------- +* Convert to pytest for running unit and integration tests (PYTHON-1297) +* Add support for Cassandra 4.1.x and 5.0 releases to CI (PYTHON-1393) +* Extend driver vector support to arbitrary subtypes and fix handling of variable length types (PYTHON-1369) + +Bug Fixes +--------- +* Python NumpyProtocolHandler does not work with NumPy 1.24.0 or greater (PYTHON-1359) +* cibuildwheel appears to not be stripping Cython-generated shared objects (PYTHON-1387) +* Windows build for Python 3.12 compiled without libev support (PYTHON-1386) + +Others +------ +* Update README.rst with badges for version and license (PR 1210) +* Remove dependency on old mock external module (PR 1201) +* Removed future print_function, division, and with and some pre 3.7 handling (PR 1208) +* Update geomet dependency (PR 1207) +* Remove problematic escape sequences in some docstrings to avoid SyntaxWarning in Python 3.12 (PR 1205) +* Use timezone-aware API to avoid deprecated warning (PR 1213) + +3.29.1 +====== +March 19, 2024 + +Bug Fixes +--------- +* cassandra-driver for Python 3.12 Linux is compiled without libev support (PYTHON-1378) +* Consider moving to native wheel builds for OS X and removing universal2 wheels (PYTHON-1379) + +3.29.0 +====== +December 19, 2023 + +Features +-------- +* Add support for Python 3.9 through 3.12, drop support for 3.7 (PYTHON-1283) +* Removal of dependency on six module (PR 1172) +* Raise explicit exception when deserializing a vector with a subtype that isn’t a constant size (PYTHON-1371) + +Others +------ +* Remove outdated Python pre-3.7 references (PR 1186) +* Remove backup(.bak) files (PR 1185) +* Fix doc typo in add_callbacks (PR 1177) + +3.28.0 +====== +June 5, 2023 + +Features +-------- +* Add support for vector type (PYTHON-1352) +* Cryptography module is now an optional dependency (PYTHON-1351) + +Bug Fixes +--------- +* Store IV along with encrypted text when using column-level encryption (PYTHON-1350) +* Create session-specific protocol handlers to contain session-specific CLE policies (PYTHON-1356) + +Others +------ +* Use Cython for smoke builds (PYTHON-1343) +* Don't fail when inserting UDTs with prepared queries with some missing fields (PR 1151) +* Convert print statement to function in docs (PR 1157) +* Update comment for retry policy (DOC-3278) +* Added error handling blog reference (DOC-2813) + +3.27.0 +====== +May 1, 2023 + +Features +-------- +* Add support for client-side encryption (PYTHON-1341) + +3.26.0 +====== +March 13, 2023 + +Features +-------- +* Add support for execution profiles in execute_concurrent (PR 1122) + +Bug Fixes +--------- +* Handle empty non-final result pages (PR 1110) +* Do not re-use stream IDs for in-flight requests (PR 1114) +* Asyncore race condition cause logging exception on shutdown (PYTHON-1266) + +Others +------ +* Fix deprecation warning in query tracing (PR 1103) +* Remove mutable default values from some tests (PR 1116) +* Remove dependency on unittest2 (PYTHON-1289) +* Fix deprecation warnings for asyncio.coroutine annotation in asyncioreactor (PYTHON-1290) +* Fix typos in source files (PR 1126) +* HostFilterPolicyInitTest fix for Python 3.11 (PR 1131) +* Fix for DontPrepareOnIgnoredHostsTest (PYTHON-1287) +* tests.integration.simulacron.test_connection failures (PYTHON-1304) +* tests.integration.standard.test_single_interface.py appears to be failing for C* 4.0 (PYTHON-1329) +* Authentication tests appear to be failing fraudulently (PYTHON-1328) +* PreparedStatementTests.test_fail_if_different_query_id_on_reprepare() failing unexpectedly (PTYHON-1327) +* Refactor deprecated unittest aliases for Python 3.11 compatibility (PR 1112) + +Deprecations +------------ +* This release removes support for Python 2.7.x as well as Python 3.5.x and 3.6.x + +3.25.0 +====== +March 18, 2021 + +Features +-------- +* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260) +* Implement protocol v5 checksumming (PYTHON-1258) +* Fix the default cqlengine connection mechanism to work with Astra (PYTHON-1265) + +Bug Fixes +--------- +* Asyncore race condition cause logging exception on shutdown (PYTHON-1266) +* Update list of reserved keywords (PYTHON-1269) + +Others +------ +* Drop Python 3.4 support (PYTHON-1220) +* Update security documentation and examples to use PROTOCOL_TLS (PYTHON-1264) + +3.24.0 +====== +June 18, 2020 + +Features +-------- +* Make geomet an optional dependency at runtime (PYTHON-1237) +* Add use_default_tempdir cloud config options (PYTHON-1245) +* Tcp flow control for libevreactor (PYTHON-1248) + +Bug Fixes +--------- +* Unable to connect to a cloud cluster using Ubuntu 20.04 (PYTHON-1238) +* PlainTextAuthProvider fails with unicode chars and Python3 (PYTHON-1241) +* [GRAPH] Graph execution profiles consistency level are not set to LOCAL_QUORUM with a cloud cluster (PYTHON-1240) +* [GRAPH] Can't write data in a Boolean field using the Fluent API (PYTHON-1239) +* [GRAPH] Fix elementMap() result deserialization (PYTHON-1233) + +Others +------ +* Bump geomet dependency version to 0.2 (PYTHON-1243) +* Bump gremlinpython dependency version to 3.4.6 (PYTHON-1212) +* Improve fluent graph documentation for core graphs (PYTHON-1244) + +3.23.0 +====== +April 6, 2020 + +Features +-------- +* Transient Replication Support (PYTHON-1207) +* Support system.peers_v2 and port discovery for C* 4.0 (PYTHON-700) + +Bug Fixes +--------- +* Asyncore logging exception on shutdown (PYTHON-1228) + +3.22.0 +====== +February 26, 2020 + +Features +-------- + +* Add all() function to the ResultSet API (PYTHON-1203) +* Parse new schema metadata in NGDG and generate table edges CQL syntax (PYTHON-996) +* Add GraphSON3 support (PYTHON-788) +* Use GraphSON3 as default for Native graphs (PYTHON-1004) +* Add Tuple and UDT types for native graph (PYTHON-1005) +* Add Duration type for native graph (PYTHON-1000) +* Add gx:ByteBuffer graphson type support for Blob field (PYTHON-1027) +* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045) +* Provide numerical wrappers to ensure proper graphson schema definition (PYTHON-1051) +* Resolve the row_factory automatically for native graphs (PYTHON-1056) +* Add g:TraversalMetrics/g:Metrics graph deserializers (PYTHON-1057) +* Add g:BulkSet graph deserializers (PYTHON-1060) +* Update Graph Engine names and the way to create a Classic/Native Graph (PYTHON-1090) +* Update Native to Core Graph Engine +* Add graphson3 and native graph support (PYTHON-1039) +* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045) +* Expose filter predicates for cql collections (PYTHON-1019) +* Add g:TraversalMetrics/Metrics deserializers (PYTHON-1057) +* Make graph metadata handling more robust (PYTHON-1204) + +Bug Fixes +--------- +* Make sure to only query the native_transport_address column with DSE (PYTHON-1205) + +3.21.0 +====== +January 15, 2020 + +Features +-------- +* Unified driver: merge core and DSE drivers into a single package (PYTHON-1130) +* Add Python 3.8 support (PYTHON-1189) +* Allow passing ssl context for Twisted (PYTHON-1161) +* Ssl context and cloud support for Eventlet (PYTHON-1162) +* Cloud Twisted support (PYTHON-1163) +* Add additional_write_policy and read_repair to system schema parsing (PYTHON-1048) +* Flexible version parsing (PYTHON-1174) +* Support NULL in collection deserializer (PYTHON-1123) +* [GRAPH] Ability to execute Fluent Graph queries asynchronously (PYTHON-1129) + +Bug Fixes +--------- +* Handle prepared id mismatch when repreparing on the fly (PYTHON-1124) +* re-raising the CQLEngineException will fail on Python 3 (PYTHON-1166) +* asyncio message chunks can be processed discontinuously (PYTHON-1185) +* Reconnect attempts persist after downed node removed from peers (PYTHON-1181) +* Connection fails to validate ssl certificate hostname when SSLContext.check_hostname is set (PYTHON-1186) +* ResponseFuture._set_result crashes on connection error when used with PrepareMessage (PYTHON-1187) +* Insights fail to serialize the startup message when the SSL Context is from PyOpenSSL (PYTHON-1192) + +Others +------ +* The driver has a new dependency: geomet. It comes from the dse-driver unification and + is used to support DSE geo types. +* Remove *read_repair_chance table options (PYTHON-1140) +* Avoid warnings about unspecified load balancing policy when connecting to a cloud cluster (PYTHON-1177) +* Add new DSE CQL keywords (PYTHON-1122) +* Publish binary wheel distributions (PYTHON-1013) + +Deprecations +------------ + +* DSELoadBalancingPolicy will be removed in the next major, consider using + the DefaultLoadBalancingPolicy. + +Merged from dse-driver: + +Features +-------- + +* Insights integration (PYTHON-1047) +* Graph execution profiles should preserve their graph_source when graph_options is overridden (PYTHON-1021) +* Add NodeSync metadata (PYTHON-799) +* Add new NodeSync failure values (PYTHON-934) +* DETERMINISTIC and MONOTONIC Clauses for Functions and Aggregates (PYTHON-955) +* GraphOptions should show a warning for unknown parameters (PYTHON-819) +* DSE protocol version 2 and continous paging backpressure (PYTHON-798) +* GraphSON2 Serialization/Deserialization Support (PYTHON-775) +* Add graph-results payload option for GraphSON format (PYTHON-773) +* Create an AuthProvider for the DSE transitional mode (PYTHON-831) +* Implement serializers for the Graph String API (PYTHON-778) +* Provide deserializers for GraphSON types (PYTHON-782) +* Add Graph DurationType support (PYTHON-607) +* Support DSE DateRange type (PYTHON-668) +* RLAC CQL output for materialized views (PYTHON-682) +* Add Geom Types wkt deserializer +* DSE Graph Client timeouts in custom payload (PYTHON-589) +* Make DSEGSSAPIAuthProvider accept principal name (PYTHON-574) +* Add config profiles to DSE graph execution (PYTHON-570) +* DSE Driver version checking (PYTHON-568) +* Distinct default timeout for graph queries (PYTHON-477) +* Graph result parsing for known types (PYTHON-479,487) +* Distinct read/write CL for graph execution (PYTHON-509) +* Target graph analytics query to spark master when available (PYTHON-510) + +Bug Fixes +--------- + +* Continuous paging sessions raise RuntimeError when results are not entirely consumed (PYTHON-1054) +* GraphSON Property deserializer should return a dict instead of a set (PYTHON-1033) +* ResponseFuture.has_more_pages may hold the wrong value (PYTHON-946) +* DETERMINISTIC clause in AGGREGATE misplaced in CQL generation (PYTHON-963) +* graph module import cause a DLL issue on Windows due to its cythonizing failure (PYTHON-900) +* Update date serialization to isoformat in graph (PYTHON-805) +* DateRange Parse Error (PYTHON-729) +* MontonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728) +* metadata.get_host returning None unexpectedly (PYTHON-709) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* Resolve FQDN from ip address and use that as host passed to SASLClient (PYTHON-566) +* Geospatial type implementations don't handle 'EMPTY' values. (PYTHON-481) +* Correctly handle other types in geo type equality (PYTHON-508) + +Other +----- +* Add tests around cqlengine and continuous paging (PYTHON-872) +* Add an abstract GraphStatement to handle different graph statements (PYTHON-789) +* Write documentation examples for DSE 2.0 features (PYTHON-732) +* DSE_V1 protocol should not include all of protocol v5 (PYTHON-694) + +3.20.2 +====== +November 19, 2019 + +Bug Fixes +--------- +* Fix import error for old python installation without SSLContext (PYTHON-1183) + +3.20.1 +====== +November 6, 2019 + +Bug Fixes +--------- +* ValueError: too many values to unpack (expected 2)" when there are two dashes in server version number (PYTHON-1172) + +3.20.0 +====== +October 28, 2019 + +Features +-------- +* DataStax Astra Support (PYTHON-1074) +* Use 4.0 schema parser in 4 alpha and snapshot builds (PYTHON-1158) + +Bug Fixes +--------- +* Connection setup methods prevent using ExecutionProfile in cqlengine (PYTHON-1009) +* Driver deadlock if all connections dropped by heartbeat whilst request in flight and request times out (PYTHON-1044) +* Exception when use pk__token__gt filter In python 3.7 (PYTHON-1121) + +3.19.0 +====== +August 26, 2019 + +Features +-------- +* Add Python 3.7 support (PYTHON-1016) +* Future-proof Mapping imports (PYTHON-1023) +* Include param values in cqlengine logging (PYTHON-1105) +* NTS Token Replica Map Generation is slow (PYTHON-622) + +Bug Fixes +--------- +* as_cql_query UDF/UDA parameters incorrectly includes "frozen" if arguments are collections (PYTHON-1031) +* cqlengine does not currently support combining TTL and TIMESTAMP on INSERT (PYTHON-1093) +* Fix incorrect metadata for compact counter tables (PYTHON-1100) +* Call ConnectionException with correct kwargs (PYTHON-1117) +* Can't connect to clusters built from source because version parsing doesn't handle 'x.y-SNAPSHOT' (PYTHON-1118) +* Discovered node doesn´t honor the configured Cluster port on connection (PYTHON-1127) +* Exception when use pk__token__gt filter In python 3.7 (PYTHON-1121) + +Other +----- +* Remove invalid warning in set_session when we initialize a default connection (PYTHON-1104) +* Set the proper default ExecutionProfile.row_factory value (PYTHON-1119) + +3.18.0 +====== +May 27, 2019 + +Features +-------- + +* Abstract Host Connection information (PYTHON-1079) +* Improve version parsing to support a non-integer 4th component (PYTHON-1091) +* Expose on_request_error method in the RetryPolicy (PYTHON-1064) +* Add jitter to ExponentialReconnectionPolicy (PYTHON-1065) + +Bug Fixes +--------- + +* Fix error when preparing queries with beta protocol v5 (PYTHON-1081) +* Accept legacy empty strings as column names (PYTHON-1082) +* Let util.SortedSet handle uncomparable elements (PYTHON-1087) + +3.17.1 +====== +May 2, 2019 + +Bug Fixes +--------- +* Socket errors EAGAIN/EWOULDBLOCK are not handled properly and cause timeouts (PYTHON-1089) + +3.17.0 +====== +February 19, 2019 + +Features +-------- +* Send driver name and version in startup message (PYTHON-1068) +* Add Cluster ssl_context option to enable SSL (PYTHON-995) +* Allow encrypted private keys for 2-way SSL cluster connections (PYTHON-995) +* Introduce new method ConsistencyLevel.is_serial (PYTHON-1067) +* Add Session.get_execution_profile (PYTHON-932) +* Add host kwarg to Session.execute/execute_async APIs to send a query to a specific node (PYTHON-993) + +Bug Fixes +--------- +* NoHostAvailable when all hosts are up and connectable (PYTHON-891) +* Serial consistency level is not used (PYTHON-1007) + +Other +----- +* Fail faster on incorrect lz4 import (PYTHON-1042) +* Bump Cython dependency version to 0.29 (PYTHON-1036) +* Expand Driver SSL Documentation (PYTHON-740) + +Deprecations +------------ + +* Using Cluster.ssl_options to enable SSL is deprecated and will be removed in + the next major release, use ssl_context. +* DowngradingConsistencyRetryPolicy is deprecated and will be + removed in the next major release. (PYTHON-937) + +3.16.0 +====== +November 12, 2018 + +Bug Fixes +--------- +* Improve and fix socket error-catching code in nonblocking-socket reactors (PYTHON-1024) +* Non-ASCII characters in schema break CQL string generation (PYTHON-1008) +* Fix OSS driver's virtual table support against DSE 6.0.X and future server releases (PYTHON-1020) +* ResultSet.one() fails if the row_factory is using a generator (PYTHON-1026) +* Log profile name on attempt to create existing profile (PYTHON-944) +* Cluster instantiation fails if any contact points' hostname resolution fails (PYTHON-895) + +Other +----- +* Fix tests when RF is not maintained if we decomission a node (PYTHON-1017) +* Fix wrong use of ResultSet indexing (PYTHON-1015) + +3.15.1 +====== +September 6, 2018 + +Bug Fixes +--------- +* C* 4.0 schema-parsing logic breaks running against DSE 6.0.X (PYTHON-1018) + +3.15.0 +====== +August 30, 2018 + +Features +-------- +* Parse Virtual Keyspace Metadata (PYTHON-992) + +Bug Fixes +--------- +* Tokenmap.get_replicas returns the wrong value if token coincides with the end of the range (PYTHON-978) +* Python Driver fails with "more than 255 arguments" python exception when > 255 columns specified in query response (PYTHON-893) +* Hang in integration.standard.test_cluster.ClusterTests.test_set_keyspace_twice (PYTHON-998) +* Asyncore reactors should use a global variable instead of a class variable for the event loop (PYTHON-697) + +Other +----- +* Use global variable for libev loops so it can be subclassed (PYTHON-973) +* Update SchemaParser for V4 (PYTHON-1006) +* Bump Cython dependency version to 0.28 (PYTHON-1012) + +3.14.0 +====== +April 17, 2018 + +Features +-------- +* Add one() function to the ResultSet API (PYTHON-947) +* Create an utility function to fetch concurrently many keys from the same replica (PYTHON-647) +* Allow filter queries with fields that have an index managed outside of cqlengine (PYTHON-966) +* Twisted SSL Support (PYTHON-343) +* Support IS NOT NULL operator in cqlengine (PYTHON-968) + +Other +----- +* Fix Broken Links in Docs (PYTHON-916) +* Reevaluate MONKEY_PATCH_LOOP in test codebase (PYTHON-903) +* Remove CASS_SERVER_VERSION and replace it for CASSANDRA_VERSION in tests (PYTHON-910) +* Refactor CASSANDRA_VERSION to a some kind of version object (PYTHON-915) +* Log warning when driver configures an authenticator, but server does not request authentication (PYTHON-940) +* Warn users when using the deprecated Session.default_consistency_level (PYTHON-953) +* Add DSE smoke test to OSS driver tests (PYTHON-894) +* Document long compilation times and workarounds (PYTHON-868) +* Improve error for batch WriteTimeouts (PYTHON-941) +* Deprecate ResultSet indexing (PYTHON-945) + +3.13.0 +====== +January 30, 2018 + +Features +-------- +* cqlengine: LIKE filter operator (PYTHON-512) +* Support cassandra.query.BatchType with cqlengine BatchQuery (PYTHON-888) + +Bug Fixes +--------- +* AttributeError: 'NoneType' object has no attribute 'add_timer' (PYTHON-862) +* Support retry_policy in PreparedStatement (PYTHON-861) +* __del__ method in Session is throwing an exception (PYTHON-813) +* LZ4 import issue with recent versions (PYTHON-897) +* ResponseFuture._connection can be None when returning request_id (PYTHON-853) +* ResultSet.was_applied doesn't support batch with LWT statements (PYTHON-848) + +Other +----- +* cqlengine: avoid warning when unregistering connection on shutdown (PYTHON-865) +* Fix DeprecationWarning of log.warn (PYTHON-846) +* Fix example_mapper.py for python3 (PYTHON-860) +* Possible deadlock on cassandra.concurrent.execute_concurrent (PYTHON-768) +* Add some known deprecated warnings for 4.x (PYTHON-877) +* Remove copyright dates from copyright notices (PYTHON-863) +* Remove "Experimental" tag from execution profiles documentation (PYTHON-840) +* request_timer metrics descriptions are slightly incorrect (PYTHON-885) +* Remove "Experimental" tag from cqlengine connections documentation (PYTHON-892) +* Set in documentation default consistency for operations is LOCAL_ONE (PYTHON-901) + +3.12.0 +====== +November 6, 2017 + +Features +-------- +* Send keyspace in QUERY, PREPARE, and BATCH messages (PYTHON-678) +* Add IPv4Address/IPv6Address support for inet types (PYTHON-751) +* WriteType.CDC and VIEW missing (PYTHON-794) +* Warn on Cluster init if contact points are specified but LBP isn't (legacy mode) (PYTHON-812) +* Warn on Cluster init if contact points are specified but LBP isn't (exection profile mode) (PYTHON-838) +* Include hash of result set metadata in prepared stmt id (PYTHON-808) +* Add NO_COMPACT startup option (PYTHON-839) +* Add new exception type for CDC (PYTHON-837) +* Allow 0ms in ConstantSpeculativeExecutionPolicy (PYTHON-836) +* Add asyncio reactor (PYTHON-507) + +Bug Fixes +--------- +* Both _set_final_exception/result called for the same ResponseFuture (PYTHON-630) +* Use of DCAwareRoundRobinPolicy raises NoHostAvailable exception (PYTHON-781) +* Not create two sessions by default in CQLEngine (PYTHON-814) +* Bug when subclassing AyncoreConnection (PYTHON-827) +* Error at cleanup when closing the asyncore connections (PYTHON-829) +* Fix sites where `sessions` can change during iteration (PYTHON-793) +* cqlengine: allow min_length=0 for Ascii and Text column types (PYTHON-735) +* Rare exception when "sys.exit(0)" after query timeouts (PYTHON-752) +* Dont set the session keyspace when preparing statements (PYTHON-843) +* Use of DCAwareRoundRobinPolicy raises NoHostAvailable exception (PYTHON-781) + +Other +------ +* Remove DeprecationWarning when using WhiteListRoundRobinPolicy (PYTHON-810) +* Bump Cython dependency version to 0.27 (PYTHON-833) + +3.11.0 +====== +July 24, 2017 + + +Features +-------- +* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762) +* Add HostFilterPolicy (PYTHON-761) + +Bug Fixes +--------- +* is_idempotent flag is not propagated from PreparedStatement to BoundStatement (PYTHON-736) +* Fix asyncore hang on exit (PYTHON-767) +* Driver takes several minutes to remove a bad host from session (PYTHON-762) +* Installation doesn't always fall back to no cython in Windows (PYTHON-763) +* Avoid to replace a connection that is supposed to shutdown (PYTHON-772) +* request_ids may not be returned to the pool (PYTHON-739) +* Fix murmur3 on big-endian systems (PYTHON-653) +* Ensure unused connections are closed if a Session is deleted by the GC (PYTHON-774) +* Fix .values_list by using db names internally (cqlengine) (PYTHON-785) + + +Other +----- +* Bump Cython dependency version to 0.25.2 (PYTHON-754) +* Fix DeprecationWarning when using lz4 (PYTHON-769) +* Deprecate WhiteListRoundRobinPolicy (PYTHON-759) +* Improve upgrade guide for materializing pages (PYTHON-464) +* Documentation for time/date specifies timestamp inupt as microseconds (PYTHON-717) +* Point to DSA Slack, not IRC, in docs index + +3.10.0 +====== +May 24, 2017 + +Features +-------- +* Add Duration type to cqlengine (PYTHON-750) +* Community PR review: Raise error on primary key update only if its value changed (PYTHON-705) +* get_query_trace() contract is ambiguous (PYTHON-196) + +Bug Fixes +--------- +* Queries using speculative execution policy timeout prematurely (PYTHON-755) +* Fix `map` where results are not consumed (PYTHON-749) +* Driver fails to encode Duration's with large values (PYTHON-747) +* UDT values are not updated correctly in CQLEngine (PYTHON-743) +* UDT types are not validated in CQLEngine (PYTHON-742) +* to_python is not implemented for types columns.Type and columns.Date in CQLEngine (PYTHON-741) +* Clients spin infinitely trying to connect to a host that is drained (PYTHON-734) +* Resulset.get_query_trace returns empty trace sometimes (PYTHON-730) +* Memory grows and doesn't get removed (PYTHON-720) +* Fix RuntimeError caused by change dict size during iteration (PYTHON-708) +* fix ExponentialReconnectionPolicy may throw OverflowError problem (PYTHON-707) +* Avoid using nonexistent prepared statement in ResponseFuture (PYTHON-706) + +Other +----- +* Update README (PYTHON-746) +* Test python versions 3.5 and 3.6 (PYTHON-737) +* Docs Warning About Prepare "select *" (PYTHON-626) +* Increase Coverage in CqlEngine Test Suite (PYTHON-505) +* Example SSL connection code does not verify server certificates (PYTHON-469) + +3.9.0 +===== + +Features +-------- +* cqlengine: remove elements by key from a map (PYTHON-688) + +Bug Fixes +--------- +* improve error handling when connecting to non-existent keyspace (PYTHON-665) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* rare flake on integration.standard.test_cluster.ClusterTests.test_clone_shared_lbp (PYTHON-727) +* MontonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728) +* race where callback or errback for request may not be called (PYTHON-733) +* cqlengine: model.update() should not update columns with a default value that hasn't changed (PYTHON-657) +* cqlengine: field value manager's explicit flag is True when queried back from cassandra (PYTHON-719) + +Other +----- +* Connection not closed in example_mapper (PYTHON-723) +* Remove mention of pre-2.0 C* versions from OSS 3.0+ docs (PYTHON-710) + +3.8.1 +===== +March 16, 2017 + +Bug Fixes +--------- + +* implement __le__/__ge__/__ne__ on some custom types (PYTHON-714) +* Fix bug in eventlet and gevent reactors that could cause hangs (PYTHON-721) +* Fix DecimalType regression (PYTHON-724) + +3.8.0 +===== + +Features +-------- + +* Quote index names in metadata CQL generation (PYTHON-616) +* On column deserialization failure, keep error message consistent between python and cython (PYTHON-631) +* TokenAwarePolicy always sends requests to the same replica for a given key (PYTHON-643) +* Added cql types to result set (PYTHON-648) +* Add __len__ to BatchStatement (PYTHON-650) +* Duration Type for Cassandra (PYTHON-655) +* Send flags with PREPARE message in v5 (PYTHON-684) + +Bug Fixes +--------- + +* Potential Timing issue if application exits prior to session pool initialization (PYTHON-636) +* "Host X.X.X.X has been marked down" without any exceptions (PYTHON-640) +* NoHostAvailable or OperationTimedOut when using execute_concurrent with a generator that inserts into more than one table (PYTHON-642) +* ResponseFuture creates Timers and don't cancel them even when result is received which leads to memory leaks (PYTHON-644) +* Driver cannot connect to Cassandra version > 3 (PYTHON-646) +* Unable to import model using UserType without setuping connection since 3.7 (PYTHON-649) +* Don't prepare queries on ignored hosts on_up (PYTHON-669) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* Make client timestamps strictly monotonic (PYTHON-676) +* cassandra.cqlengine.connection.register_connection broken when hosts=None (PYTHON-692) + +Other +----- + +* Create a cqlengine doc section explaining None semantics (PYTHON-623) +* Resolve warnings in documentation generation (PYTHON-645) +* Cython dependency (PYTHON-686) +* Drop Support for Python 2.6 (PYTHON-690) + +3.7.1 +===== +October 26, 2016 + +Bug Fixes +--------- +* Cython upgrade has broken stable version of cassandra-driver (PYTHON-656) + +3.7.0 +===== +September 13, 2016 + +Features +-------- +* Add v5 protocol failure map (PYTHON-619) +* Don't return from initial connect on first error (PYTHON-617) +* Indicate failed column when deserialization fails (PYTHON-361) +* Let Cluster.refresh_nodes force a token map rebuild (PYTHON-349) +* Refresh UDTs after "keyspace updated" event with v1/v2 protocol (PYTHON-106) +* EC2 Address Resolver (PYTHON-198) +* Speculative query retries (PYTHON-218) +* Expose paging state in API (PYTHON-200) +* Don't mark host down while one connection is active (PYTHON-498) +* Query request size information (PYTHON-284) +* Avoid quadratic ring processing with invalid replication factors (PYTHON-379) +* Improve Connection/Pool creation concurrency on startup (PYTHON-82) +* Add beta version native protocol flag (PYTHON-614) +* cqlengine: Connections: support of multiple keyspaces and sessions (PYTHON-613) + +Bug Fixes +--------- +* Race when adding a pool while setting keyspace (PYTHON-628) +* Update results_metadata when prepared statement is reprepared (PYTHON-621) +* CQL Export for Thrift Tables (PYTHON-213) +* cqlengine: default value not applied to UserDefinedType (PYTHON-606) +* cqlengine: columns are no longer hashable (PYTHON-618) +* cqlengine: remove clustering keys from where clause when deleting only static columns (PYTHON-608) + +3.6.0 +===== +August 1, 2016 + +Features +-------- +* Handle null values in NumpyProtocolHandler (PYTHON-553) +* Collect greplin scales stats per cluster (PYTHON-561) +* Update mock unit test dependency requirement (PYTHON-591) +* Handle Missing CompositeType metadata following C* upgrade (PYTHON-562) +* Improve Host.is_up state for HostDistance.IGNORED hosts (PYTHON-551) +* Utilize v2 protocol's ability to skip result set metadata for prepared statement execution (PYTHON-71) +* Return from Cluster.connect() when first contact point connection(pool) is opened (PYTHON-105) +* cqlengine: Add ContextQuery to allow cqlengine models to switch the keyspace context easily (PYTHON-598) +* Standardize Validation between Ascii and Text types in Cqlengine (PYTHON-609) + +Bug Fixes +--------- +* Fix geventreactor with SSL support (PYTHON-600) +* Don't downgrade protocol version if explicitly set (PYTHON-537) +* Nonexistent contact point tries to connect indefinitely (PYTHON-549) +* Execute_concurrent can exceed max recursion depth in failure mode (PYTHON-585) +* Libev loop shutdown race (PYTHON-578) +* Include aliases in DCT type string (PYTHON-579) +* cqlengine: Comparison operators for Columns (PYTHON-595) +* cqlengine: disentangle default_time_to_live table option from model query default TTL (PYTHON-538) +* cqlengine: pk__token column name issue with the equality operator (PYTHON-584) +* cqlengine: Fix "__in" filtering operator converts True to string "True" automatically (PYTHON-596) +* cqlengine: Avoid LWTExceptions when updating columns that are part of the condition (PYTHON-580) +* cqlengine: Cannot execute a query when the filter contains all columns (PYTHON-599) +* cqlengine: routing key computation issue when a primary key column is overriden by model inheritance (PYTHON-576) + +3.5.0 +===== +June 27, 2016 + +Features +-------- +* Optional Execution Profiles for the core driver (PYTHON-569) +* API to get the host metadata associated with the control connection node (PYTHON-583) +* Expose CDC option in table metadata CQL (PYTHON-593) + +Bug Fixes +--------- +* Clean up Asyncore socket map when fork is detected (PYTHON-577) +* cqlengine: QuerySet only() is not respected when there are deferred fields (PYTHON-560) + +3.4.1 +===== +May 26, 2016 + +Bug Fixes +--------- +* Gevent connection closes on IO timeout (PYTHON-573) +* "dictionary changed size during iteration" with Python 3 (PYTHON-572) + +3.4.0 +===== +May 24, 2016 + +Features +-------- +* Include DSE version and workload in Host data (PYTHON-555) +* Add a context manager to Cluster and Session (PYTHON-521) +* Better Error Message for Unsupported Protocol Version (PYTHON-157) +* Make the error message explicitly state when an error comes from the server (PYTHON-412) +* Short Circuit meta refresh on topo change if NEW_NODE already exists (PYTHON-557) +* Show warning when the wrong config is passed to SimpleStatement (PYTHON-219) +* Return namedtuple result pairs from execute_concurrent (PYTHON-362) +* BatchStatement should enforce batch size limit in a better way (PYTHON-151) +* Validate min/max request thresholds for connection pool scaling (PYTHON-220) +* Handle or warn about multiple hosts with the same rpc_address (PYTHON-365) +* Write docs around working with datetime and timezones (PYTHON-394) + +Bug Fixes +--------- +* High CPU utilization when using asyncore event loop (PYTHON-239) +* Fix CQL Export for non-ASCII Identifiers (PYTHON-447) +* Make stress scripts Python 2.6 compatible (PYTHON-434) +* UnicodeDecodeError when unicode characters in key in BOP (PYTHON-559) +* WhiteListRoundRobinPolicy should resolve hosts (PYTHON-565) +* Cluster and Session do not GC after leaving scope (PYTHON-135) +* Don't wait for schema agreement on ignored nodes (PYTHON-531) +* Reprepare on_up with many clients causes node overload (PYTHON-556) +* None inserted into host map when control connection node is decommissioned (PYTHON-548) +* weakref.ref does not accept keyword arguments (github #585) + +3.3.0 +===== +May 2, 2016 + +Features +-------- +* Add an AddressTranslator interface (PYTHON-69) +* New Retry Policy Decision - try next host (PYTHON-285) +* Don't mark host down on timeout (PYTHON-286) +* SSL hostname verification (PYTHON-296) +* Add C* version to metadata or cluster objects (PYTHON-301) +* Options to Disable Schema, Token Metadata Processing (PYTHON-327) +* Expose listen_address of node we get ring information from (PYTHON-332) +* Use A-record with multiple IPs for contact points (PYTHON-415) +* Custom consistency level for populating query traces (PYTHON-435) +* Normalize Server Exception Types (PYTHON-443) +* Propagate exception message when DDL schema agreement fails (PYTHON-444) +* Specialized exceptions for metadata refresh methods failure (PYTHON-527) + +Bug Fixes +--------- +* Resolve contact point hostnames to avoid duplicate hosts (PYTHON-103) +* GeventConnection stalls requests when read is a multiple of the input buffer size (PYTHON-429) +* named_tuple_factory breaks with duplicate "cleaned" col names (PYTHON-467) +* Connection leak if Cluster.shutdown() happens during reconnection (PYTHON-482) +* HostConnection.borrow_connection does not block when all request ids are used (PYTHON-514) +* Empty field not being handled by the NumpyProtocolHandler (PYTHON-550) + +3.2.2 +===== +April 19, 2016 + +* Fix counter save-after-no-update (PYTHON-547) + +3.2.1 +===== +April 13, 2016 + +* Introduced an update to allow deserializer compilation with recently released Cython 0.24 (PYTHON-542) + +3.2.0 +===== +April 12, 2016 + +Features +-------- +* cqlengine: Warn on sync_schema type mismatch (PYTHON-260) +* cqlengine: Automatically defer fields with the '=' operator (and immutable values) in select queries (PYTHON-520) +* cqlengine: support non-equal conditions for LWT (PYTHON-528) +* cqlengine: sync_table should validate the primary key composition (PYTHON-532) +* cqlengine: token-aware routing for mapper statements (PYTHON-535) + +Bug Fixes +--------- +* Deleting a column in a lightweight transaction raises a SyntaxException #325 (PYTHON-249) +* cqlengine: make Token function works with named tables/columns #86 (PYTHON-272) +* comparing models with datetime fields fail #79 (PYTHON-273) +* cython date deserializer integer math should be aligned with CPython (PYTHON-480) +* db_field is not always respected with UpdateStatement (PYTHON-530) +* Sync_table fails on column.Set with secondary index (PYTHON-533) + +3.1.1 +===== +March 14, 2016 + +Bug Fixes +--------- +* cqlengine: Fix performance issue related to additional "COUNT" queries (PYTHON-522) + +3.1.0 +===== +March 10, 2016 + +Features +-------- +* Pass name of server auth class to AuthProvider (PYTHON-454) +* Surface schema agreed flag for DDL statements (PYTHON-458) +* Automatically convert float and int to Decimal on serialization (PYTHON-468) +* Eventlet Reactor IO improvement (PYTHON-495) +* Make pure Python ProtocolHandler available even when Cython is present (PYTHON-501) +* Optional Cython deserializer for bytes as bytearray (PYTHON-503) +* Add Session.default_serial_consistency_level (github #510) +* cqlengine: Expose prior state information via cqlengine LWTException (github #343, PYTHON-336) +* cqlengine: Collection datatype "contains" operators support (Cassandra 2.1) #278 (PYTHON-258) +* cqlengine: Add DISTINCT query operator (PYTHON-266) +* cqlengine: Tuple cqlengine api (PYTHON-306) +* cqlengine: Add support for UPDATE/DELETE ... IF EXISTS statements (PYTHON-432) +* cqlengine: Allow nested container types (PYTHON-478) +* cqlengine: Add ability to set query's fetch_size and limit (PYTHON-323) +* cqlengine: Internalize default keyspace from successive set_session (PYTHON-486) +* cqlengine: Warn when Model.create() on Counters (to be deprecated) (PYTHON-333) + +Bug Fixes +--------- +* Bus error (alignment issues) when running cython on some ARM platforms (PYTHON-450) +* Overflow when decoding large collections (cython) (PYTHON-459) +* Timer heap comparison issue with Python 3 (github #466) +* Cython deserializer date overflow at 2^31 - 1 (PYTHON-452) +* Decode error encountered when cython deserializing large map results (PYTHON-459) +* Don't require Cython for build if compiler or Python header not present (PYTHON-471) +* Unorderable types in task scheduling with Python 3 (h(PYTHON-473) +* cqlengine: Fix crash when updating a UDT column with a None value (github #467) +* cqlengine: Race condition in ..connection.execute with lazy_connect (PYTHON-310) +* cqlengine: doesn't support case sensitive column family names (PYTHON-337) +* cqlengine: UserDefinedType mandatory in create or update (PYTHON-344) +* cqlengine: db_field breaks UserType (PYTHON-346) +* cqlengine: UDT badly quoted (PYTHON-347) +* cqlengine: Use of db_field on primary key prevents querying except while tracing. (PYTHON-351) +* cqlengine: DateType.deserialize being called with one argument vs two (PYTHON-354) +* cqlengine: Querying without setting up connection now throws AttributeError and not CQLEngineException (PYTHON-395) +* cqlengine: BatchQuery multiple time executing execute statements. (PYTHON-445) +* cqlengine: Better error for management functions when no connection set (PYTHON-451) +* cqlengine: Handle None values for UDT attributes in cqlengine (PYTHON-470) +* cqlengine: Fix inserting None for model save (PYTHON-475) +* cqlengine: EQ doesn't map to a QueryOperator (setup race condition) (PYTHON-476) +* cqlengine: class.MultipleObjectsReturned has DoesNotExist as base class (PYTHON-489) +* cqlengine: Typo in cqlengine UserType __len__ breaks attribute assignment (PYTHON-502) + + +Other +----- + +* cqlengine: a major improvement on queryset has been introduced. It + is a lot more efficient to iterate large datasets: the rows are + now fetched on demand using the driver pagination. + +* cqlengine: the queryset len() and count() behaviors have changed. It + now executes a "SELECT COUNT(*)" of the query rather than returning + the size of the internal result_cache (loaded rows). On large + queryset, you might want to avoid using them due to the performance + cost. Note that trying to access objects using list index/slicing + with negative indices also requires a count to be + executed. + + + +3.0.0 +===== +November 24, 2015 + +Features +-------- +* Support datetime.date objects as a DateType (PYTHON-212) +* Add Cluster.update_view_metadata (PYTHON-407) +* QueryTrace option to populate partial trace sessions (PYTHON-438) +* Attach column names to ResultSet (PYTHON-439) +* Change default consistency level to LOCAL_ONE + +Bug Fixes +--------- +* Properly SerDes nested collections when protocol_version < 3 (PYTHON-215) +* Evict UDTs from UserType cache on change (PYTHON-226) +* Make sure query strings are always encoded UTF-8 (PYTHON-334) +* Track previous value of columns at instantiation in CQLengine (PYTHON-348) +* UDT CQL encoding does not work for unicode values (PYTHON-353) +* NetworkTopologyStrategy#make_token_replica_map does not account for multiple racks in a DC (PYTHON-378) +* Cython integer overflow on decimal type deserialization (PYTHON-433) +* Query trace: if session hasn't been logged, query trace can throw exception (PYTHON-442) + +3.0.0rc1 +======== +November 9, 2015 + +Features +-------- +* Process Modernized Schema Tables for Cassandra 3.0 (PYTHON-276, PYTHON-408, PYTHON-400, PYTHON-422) +* Remove deprecated features (PYTHON-292) +* Don't assign trace data to Statements (PYTHON-318) +* Normalize results return (PYTHON-368) +* Process Materialized View Metadata/Events (PYTHON-371) +* Remove blist as soft dependency (PYTHON-385) +* Change default consistency level to LOCAL_QUORUM (PYTHON-416) +* Normalize CQL query/export in metadata model (PYTHON-405) + +Bug Fixes +--------- +* Implementation of named arguments bind is non-pythonic (PYTHON-178) +* CQL encoding is incorrect for NaN and Infinity floats (PYTHON-282) +* Protocol downgrade issue with C* 2.0.x, 2.1.x, and python3, with non-default logging (PYTHON-409) +* ValueError when accessing usertype with non-alphanumeric field names (PYTHON-413) +* NumpyProtocolHandler does not play well with PagedResult (PYTHON-430) + +2.7.2 +===== +September 14, 2015 + +Bug Fixes +--------- +* Resolve CQL export error for UDF with zero parameters (PYTHON-392) +* Remove futures dep. for Python 3 (PYTHON-393) +* Avoid Python closure in cdef (supports earlier Cython compiler) (PYTHON-396) +* Unit test runtime issues (PYTHON-397,398) + +2.7.1 +===== +August 25, 2015 + +Bug Fixes +--------- +* Explicitly include extension source files in Manifest + +2.7.0 +===== +August 25, 2015 + +Cython is introduced, providing compiled extensions for core modules, and +extensions for optimized results deserialization. + +Features +-------- +* General Performance Improvements for Throughput (PYTHON-283) +* Improve synchronous request performance with Timers (PYTHON-108) +* Enable C Extensions for PyPy Runtime (PYTHON-357) +* Refactor SerDes functionality for pluggable interface (PYTHON-313) +* Cython SerDes Extension (PYTHON-377) +* Accept iterators/generators for execute_concurrent() (PYTHON-123) +* cythonize existing modules (PYTHON-342) +* Pure Python murmur3 implementation (PYTHON-363) +* Make driver tolerant of inconsistent metadata (PYTHON-370) + +Bug Fixes +--------- +* Drop Events out-of-order Cause KeyError on Processing (PYTHON-358) +* DowngradingConsistencyRetryPolicy doesn't check response count on write timeouts (PYTHON-338) +* Blocking connect does not use connect_timeout (PYTHON-381) +* Properly protect partition key in CQL export (PYTHON-375) +* Trigger error callbacks on timeout (PYTHON-294) + +2.6.0 +===== +July 20, 2015 + +Bug Fixes +--------- +* Output proper CQL for compact tables with no clustering columns (PYTHON-360) + +2.6.0c2 +======= +June 24, 2015 + +Features +-------- +* Automatic Protocol Version Downgrade (PYTHON-240) +* cqlengine Python 2.6 compatibility (PYTHON-288) +* Double-dollar string quote UDF body (PYTHON-345) +* Set models.DEFAULT_KEYSPACE when calling set_session (github #352) + +Bug Fixes +--------- +* Avoid stall while connecting to mixed version cluster (PYTHON-303) +* Make SSL work with AsyncoreConnection in python 2.6.9 (PYTHON-322) +* Fix Murmur3Token.from_key() on Windows (PYTHON-331) +* Fix cqlengine TimeUUID rounding error for Windows (PYTHON-341) +* Avoid invalid compaction options in CQL export for non-SizeTiered (PYTHON-352) + +2.6.0c1 +======= +June 4, 2015 + +This release adds support for Cassandra 2.2 features, including version +4 of the native protocol. + +Features +-------- +* Default load balancing policy to TokenAware(DCAware) (PYTHON-160) +* Configuration option for connection timeout (PYTHON-206) +* Support User Defined Function and Aggregate metadata in C* 2.2 (PYTHON-211) +* Surface request client in QueryTrace for C* 2.2+ (PYTHON-235) +* Implement new request failure messages in protocol v4+ (PYTHON-238) +* Metadata model now maps index meta by index name (PYTHON-241) +* Support new types in C* 2.2: date, time, smallint, tinyint (PYTHON-245, 295) +* cqle: add Double column type and remove Float overload (PYTHON-246) +* Use partition key column information in prepared response for protocol v4+ (PYTHON-277) +* Support message custom payloads in protocol v4+ (PYTHON-280, PYTHON-329) +* Deprecate refresh_schema and replace with functions for specific entities (PYTHON-291) +* Save trace id even when trace complete times out (PYTHON-302) +* Warn when registering client UDT class for protocol < v3 (PYTHON-305) +* Support client warnings returned with messages in protocol v4+ (PYTHON-315) +* Ability to distinguish between NULL and UNSET values in protocol v4+ (PYTHON-317) +* Expose CQL keywords in API (PYTHON-324) + +Bug Fixes +--------- +* IPv6 address support on Windows (PYTHON-20) +* Convert exceptions during automatic re-preparation to nice exceptions (PYTHON-207) +* cqle: Quote keywords properly in table management functions (PYTHON-244) +* Don't default to GeventConnection when gevent is loaded, but not monkey-patched (PYTHON-289) +* Pass dynamic host from SaslAuthProvider to SaslAuthenticator (PYTHON-300) +* Make protocol read_inet work for Windows (PYTHON-309) +* cqle: Correct encoding for nested types (PYTHON-311) +* Update list of CQL keywords used quoting identifiers (PYTHON-319) +* Make ConstantReconnectionPolicy work with infinite retries (github #327, PYTHON-325) +* Accept UUIDs with uppercase hex as valid in cqlengine (github #335) + +2.5.1 +===== +April 23, 2015 + +Bug Fixes +--------- +* Fix thread safety in DC-aware load balancing policy (PYTHON-297) +* Fix race condition in node/token rebuild (PYTHON-298) +* Set and send serial consistency parameter (PYTHON-299) + +2.5.0 +===== +March 30, 2015 + +Features +-------- +* Integrated cqlengine object mapping package +* Utility functions for converting timeuuids and datetime (PYTHON-99) +* Schema metadata fetch window randomized, config options added (PYTHON-202) +* Support for new Date and Time Cassandra types (PYTHON-190) + +Bug Fixes +--------- +* Fix index target for collection indexes (full(), keys()) (PYTHON-222) +* Thread exception during GIL cleanup (PYTHON-229) +* Workaround for rounding anomaly in datetime.utcfromtime (Python 3.4) (PYTHON-230) +* Normalize text serialization for lookup in OrderedMap (PYTHON-231) +* Support reading CompositeType data (PYTHON-234) +* Preserve float precision in CQL encoding (PYTHON-243) + +2.1.4 +===== +January 26, 2015 + +Features +-------- +* SaslAuthenticator for Kerberos support (PYTHON-109) +* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197) +* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186) +* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205) +* Add eventlet connection support (PYTHON-194) + +Bug Fixes +--------- +* Schema meta fix for complex thrift tables (PYTHON-191) +* Support for 'unknown' replica placement strategies in schema meta (PYTHON-192) +* Resolve stream ID leak on set_keyspace (PYTHON-195) +* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204) +* Resolve stream id collision when using SASL auth (PYTHON-210) +* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208) + +2.1.3 +===== +December 16, 2014 + +Features +-------- +* INFO-level log confirmation that a connection was opened to a node that was marked up (PYTHON-116) +* Avoid connecting to peer with incomplete metadata (PYTHON-163) +* Add SSL support to gevent reactor (PYTHON-174) +* Use control connection timeout in wait for schema agreement (PYTHON-175) +* Better consistency level representation in unavailable+timeout exceptions (PYTHON-180) +* Update schema metadata processing to accommodate coming schema modernization (PYTHON-185) + +Bug Fixes +--------- +* Support large negative timestamps on Windows (PYTHON-119) +* Fix schema agreement for clusters with peer rpc_addres 0.0.0.0 (PYTHON-166) +* Retain table metadata following keyspace meta refresh (PYTHON-173) +* Use a timeout when preparing a statement for all nodes (PYTHON-179) +* Make TokenAware routing tolerant of statements with no keyspace (PYTHON-181) +* Update add_collback to store/invoke multiple callbacks (PYTHON-182) +* Correct routing key encoding for composite keys (PYTHON-184) +* Include compression option in schema export string when disabled (PYTHON-187) + +2.1.2 +===== +October 16, 2014 + +Features +-------- +* Allow DCAwareRoundRobinPolicy to be constructed without a local_dc, defaulting + instead to the DC of a contact_point (PYTHON-126) +* Set routing key in BatchStatement.add() if none specified in batch (PYTHON-148) +* Improved feedback on ValueError using named_tuple_factory with invalid column names (PYTHON-122) + +Bug Fixes +--------- +* Make execute_concurrent compatible with Python 2.6 (PYTHON-159) +* Handle Unauthorized message on schema_triggers query (PYTHON-155) +* Pure Python sorted set in support of UDTs nested in collections (PYTON-167) +* Support CUSTOM index metadata and string export (PYTHON-165) + +2.1.1 +===== +September 11, 2014 + +Features +-------- +* Detect triggers and include them in CQL queries generated to recreate + the schema (github-189) +* Support IPv6 addresses (PYTHON-144) (note: basic functionality added; Windows + platform not addressed (PYTHON-20)) + +Bug Fixes +--------- +* Fix NetworkTopologyStrategy.export_for_schema (PYTHON-120) +* Keep timeout for paged results (PYTHON-150) + +Other +----- +* Add frozen<> type modifier to UDTs and tuples to handle CASSANDRA-7857 + +2.1.0 +===== +August 7, 2014 + +Bug Fixes +--------- +* Correctly serialize and deserialize null values in tuples and + user-defined types (PYTHON-110) +* Include additional header and lib dirs, allowing libevwrapper to build + against Homebrew and Mac Ports installs of libev (PYTHON-112 and 804dea3) + +2.1.0c1 +======= +July 25, 2014 + +Bug Fixes +--------- +* Properly specify UDTs for columns in CREATE TABLE statements +* Avoid moving retries to a new host when using request ID zero (PYTHON-88) +* Don't ignore fetch_size arguments to Statement constructors (github-151) +* Allow disabling automatic paging on a per-statement basis when it's + enabled by default for the session (PYTHON-93) +* Raise ValueError when tuple query parameters for prepared statements + have extra items (PYTHON-98) +* Correctly encode nested tuples and UDTs for non-prepared statements (PYTHON-100) +* Raise TypeError when a string is used for contact_points (github #164) +* Include User Defined Types in KeyspaceMetadata.export_as_string() (PYTHON-96) + +Other +----- +* Return list collection columns as python lists instead of tuples + now that tuples are a specific Cassandra type + +2.1.0b1 +======= +July 11, 2014 + +This release adds support for Cassandra 2.1 features, including version +3 of the native protocol. + +Features +-------- +* When using the v3 protocol, only one connection is opened per-host, and + throughput is improved due to reduced pooling overhead and lock contention. +* Support for user-defined types (Cassandra 2.1+) +* Support for tuple type in (limited usage Cassandra 2.0.9, full usage + in Cassandra 2.1) +* Protocol-level client-side timestamps (see Session.use_client_timestamp) +* Overridable type encoding for non-prepared statements (see Session.encoders) +* Configurable serial consistency levels for batch statements +* Use io.BytesIO for reduced CPU consumption (github #143) +* Support Twisted as a reactor. Note that a Twisted-compatible + API is not exposed (so no Deferreds), this is just a reactor + implementation. (github #135, PYTHON-8) + +Bug Fixes +--------- +* Fix references to xrange that do not go through "six" in libevreactor and + geventreactor (github #138) +* Make BoundStatements inherit fetch_size from their parent + PreparedStatement (PYTHON-80) +* Clear reactor state in child process after forking to prevent errors with + multiprocessing when the parent process has connected a Cluster before + forking (github #141) +* Don't share prepared statement lock across Cluster instances +* Format CompositeType and DynamicCompositeType columns correctly in + CREATE TABLE statements. +* Fix cassandra.concurrent behavior when dealing with automatic paging + (PYTHON-81) +* Properly defunct connections after protocol errors +* Avoid UnicodeDecodeError when query string is unicode (PYTHON-76) +* Correctly capture dclocal_read_repair_chance for tables and + use it when generating CREATE TABLE statements (PYTHON-84) +* Avoid race condition with AsyncoreConnection that may cause messages + to fail to be written until a new message is pushed +* Make sure cluster.metadata.partitioner and cluster.metadata.token_map + are populated when all nodes in the cluster are included in the + contact points (PYTHON-90) +* Make Murmur3 hash match Cassandra's hash for all values (PYTHON-89, + github #147) +* Don't attempt to reconnect to hosts that should be ignored (according + to the load balancing policy) when a notification is received that the + host is down. +* Add CAS WriteType, avoiding KeyError on CAS write timeout (PYTHON-91) + +2.0.2 +===== +June 10, 2014 + +Bug Fixes +--------- +* Add six to requirements.txt +* Avoid KeyError during schema refresh when a keyspace is dropped + and TokenAwarePolicy is not in use +* Avoid registering multiple atexit cleanup functions when the + asyncore event loop is restarted multiple times +* Delay initialization of reactors in order to avoid problems + with shared state when using multiprocessing (PYTHON-60) +* Add python-six to debian dependencies, move python-blist to recommends +* Fix memory leak when libev connections are created and + destroyed (github #93) +* Ensure token map is rebuilt when hosts are removed from the cluster + +2.0.1 +===== +May 28, 2014 + +Bug Fixes +--------- +* Fix check for Cluster.is_shutdown in in @run_in_executor + decorator + +2.0.0 +===== +May 28, 2014 + +Features +-------- +* Make libev C extension Python3-compatible (PYTHON-70) +* Support v2 protocol authentication (PYTHON-73, github #125) + +Bug Fixes +--------- +* Fix murmur3 C extension compilation under Python3.4 (github #124) + +Merged From 1.x +--------------- + +Features +^^^^^^^^ +* Add Session.default_consistency_level (PYTHON-14) + +Bug Fixes +^^^^^^^^^ +* Don't strip trailing underscores from column names when using the + named_tuple_factory (PYTHON-56) +* Ensure replication factors are ints for NetworkTopologyStrategy + to avoid TypeErrors (github #120) +* Pass WriteType instance to RetryPolicy.on_write_timeout() instead + of the string name of the write type. This caused write timeout + errors to always be rethrown instead of retrying. (github #123) +* Avoid submitting tasks to the ThreadPoolExecutor after shutdown. With + retries enabled, this could cause Cluster.shutdown() to hang under + some circumstances. +* Fix unintended rebuild of token replica map when keyspaces are + discovered (on startup), added, or updated and TokenAwarePolicy is not + in use. +* Avoid rebuilding token metadata when cluster topology has not + actually changed +* Avoid preparing queries for hosts that should be ignored (such as + remote hosts when using the DCAwareRoundRobinPolicy) (PYTHON-75) + +Other +^^^^^ +* Add 1 second timeout to join() call on event loop thread during + interpreter shutdown. This can help to prevent the process from + hanging during shutdown. + +2.0.0b1 +======= +May 6, 2014 + +Upgrading from 1.x +------------------ +Cluster.shutdown() should always be called when you are done with a +Cluster instance. If it is not called, there are no guarantees that the +driver will not hang. However, if you *do* have a reproduceable case +where Cluster.shutdown() is not called and the driver hangs, please +report it so that we can attempt to fix it. + +If you're using the 2.0 driver against Cassandra 1.2, you will need +to set your protocol version to 1. For example: + + cluster = Cluster(..., protocol_version=1) + +Features +-------- +* Support v2 of Cassandra's native protocol, which includes the following + new features: automatic query paging support, protocol-level batch statements, + and lightweight transactions +* Support for Python 3.3 and 3.4 +* Allow a default query timeout to be set per-Session + +Bug Fixes +--------- +* Avoid errors during interpreter shutdown (the driver attempts to cleanup + daemonized worker threads before interpreter shutdown) + +Deprecations +------------ +The following functions have moved from cassandra.decoder to cassandra.query. +The original functions have been left in place with a DeprecationWarning for +now: + +* cassandra.decoder.tuple_factory has moved to cassandra.query.tuple_factory +* cassandra.decoder.named_tuple_factory has moved to cassandra.query.named_tuple_factory +* cassandra.decoder.dict_factory has moved to cassandra.query.dict_factory +* cassandra.decoder.ordered_dict_factory has moved to cassandra.query.ordered_dict_factory + +Exceptions that were in cassandra.decoder have been moved to cassandra.protocol. If +you handle any of these exceptions, you must adjust the code accordingly. + +1.1.2 +===== +May 8, 2014 + +Features +-------- +* Allow a specific compression type to be requested for communications with + Cassandra and prefer lz4 if available + +Bug Fixes +--------- +* Update token metadata (for TokenAware calculations) when a node is removed + from the ring +* Fix file handle leak with gevent reactor due to blocking Greenlet kills when + closing excess connections +* Avoid handling a node coming up multiple times due to a reconnection attempt + succeeding close to the same time that an UP notification is pushed +* Fix duplicate node-up handling, which could result in multiple reconnectors + being started as well as the executor threads becoming deadlocked, preventing + future node up or node down handling from being executed. +* Handle exhausted ReconnectionPolicy schedule correctly + +Other +----- +* Don't log at ERROR when a connection is closed during the startup + communications +* Mke scales, blist optional dependencies + +1.1.1 +===== +April 16, 2014 + +Bug Fixes +--------- +* Fix unconditional import of nose in setup.py (github #111) + +1.1.0 +===== +April 16, 2014 + +Features +-------- +* Gevent is now supported through monkey-patching the stdlib (PYTHON-7, + github issue #46) +* Support static columns in schemas, which are available starting in + Cassandra 2.1. (github issue #91) +* Add debian packaging (github issue #101) +* Add utility methods for easy concurrent execution of statements. See + the new cassandra.concurrent module. (github issue #7) + +Bug Fixes +--------- +* Correctly supply compaction and compression parameters in CREATE statements + for tables when working with Cassandra 2.0+ +* Lowercase boolean literals when generating schemas +* Ignore SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE socket errors. Previously, + these resulted in the connection being defuncted, but they can safely be + ignored by the driver. +* Don't reconnect the control connection every time Cluster.connect() is + called +* Avoid race condition that could leave ResponseFuture callbacks uncalled + if the callback was added outside of the event loop thread (github issue #95) +* Properly escape keyspace name in Session.set_keyspace(). Previously, the + keyspace name was quoted, but any quotes in the string were not escaped. +* Avoid adding hosts to the load balancing policy before their datacenter + and rack information has been set, if possible. +* Avoid KeyError when updating metadata after droping a table (github issues + #97, #98) +* Use tuples instead of sets for DCAwareLoadBalancingPolicy to ensure equal + distribution of requests + +Other +----- +* Don't ignore column names when parsing typestrings. This is needed for + user-defined type support. (github issue #90) +* Better error message when libevwrapper is not found +* Only try to import scales when metrics are enabled (github issue #92) +* Cut down on the number of queries executing when a new Cluster + connects and when the control connection has to reconnect (github issue #104, + PYTHON-59) +* Issue warning log when schema versions do not match + +1.0.2 +===== +March 4, 2014 + +Bug Fixes +--------- +* With asyncorereactor, correctly handle EAGAIN/EWOULDBLOCK when the message from + Cassandra is a multiple of the read buffer size. Previously, if no more data + became available to read on the socket, the message would never be processed, + resulting in an OperationTimedOut error. +* Double quote keyspace, table and column names that require them (those using + uppercase characters or keywords) when generating CREATE statements through + KeyspaceMetadata and TableMetadata. +* Decode TimestampType as DateType. (Cassandra replaced DateType with + TimestampType to fix sorting of pre-unix epoch dates in CASSANDRA-5723.) +* Handle latest table options when parsing the schema and generating + CREATE statements. +* Avoid 'Set changed size during iteration' during query plan generation + when hosts go up or down + +Other +----- +* Remove ignored ``tracing_enabled`` parameter for ``SimpleStatement``. The + correct way to trace a query is by setting the ``trace`` argument to ``True`` + in ``Session.execute()`` and ``Session.execute_async()``. +* Raise TypeError instead of cassandra.query.InvalidParameterTypeError when + a parameter for a prepared statement has the wrong type; remove + cassandra.query.InvalidParameterTypeError. +* More consistent type checking for query parameters +* Add option to a return special object for empty string values for non-string + columns + +1.0.1 +===== +Feb 19, 2014 + +Bug Fixes +--------- +* Include table indexes in ``KeyspaceMetadata.export_as_string()`` +* Fix broken token awareness on ByteOrderedPartitioner +* Always close socket when defuncting error'ed connections to avoid a potential + file descriptor leak +* Handle "custom" types (such as the replaced DateType) correctly +* With libevreactor, correctly handle EAGAIN/EWOULDBLOCK when the message from + Cassandra is a multiple of the read buffer size. Previously, if no more data + became available to read on the socket, the message would never be processed, + resulting in an OperationTimedOut error. +* Don't break tracing when a Session's row_factory is not the default + namedtuple_factory. +* Handle data that is already utf8-encoded for UTF8Type values +* Fix token-aware routing for tokens that fall before the first node token in + the ring and tokens that exactly match a node's token +* Tolerate null source_elapsed values for Trace events. These may not be + set when events complete after the main operation has already completed. + +Other +----- +* Skip sending OPTIONS message on connection creation if compression is + disabled or not available and a CQL version has not been explicitly + set +* Add details about errors and the last queried host to ``OperationTimedOut`` + +1.0.0 Final +=========== +Jan 29, 2014 + +Bug Fixes +--------- +* Prevent leak of Scheduler thread (even with proper shutdown) +* Correctly handle ignored hosts, which are common with the + DCAwareRoundRobinPolicy +* Hold strong reference to prepared statement while executing it to avoid + garbage collection +* Add NullHandler logging handler to the cassandra package to avoid + warnings about there being no configured logger +* Fix bad handling of nodes that have been removed from the cluster +* Properly escape string types within cql collections +* Handle setting the same keyspace twice in a row +* Avoid race condition during schema agreement checks that could result + in schema update queries returning before all nodes had seen the change +* Preserve millisecond-level precision in datetimes when performing inserts + with simple (non-prepared) statements +* Properly defunct connections when libev reports an error by setting + errno instead of simply logging the error +* Fix endless hanging of some requests when using the libev reactor +* Always start a reconnection process when we fail to connect to + a newly bootstrapped node +* Generators map to CQL lists, not key sequences +* Always defunct connections when an internal operation fails +* Correctly break from handle_write() if nothing was sent (asyncore + reactor only) +* Avoid potential double-erroring of callbacks when a connection + becomes defunct + +Features +-------- +* Add default query timeout to ``Session`` +* Add timeout parameter to ``Session.execute()`` +* Add ``WhiteListRoundRobinPolicy`` as a load balancing policy option +* Support for consistency level ``LOCAL_ONE`` +* Make the backoff for fetching traces exponentially increasing and + configurable + +Other +----- +* Raise Exception if ``TokenAwarePolicy`` is used against a cluster using the + ``Murmur3Partitioner`` if the murmur3 C extension has not been compiled +* Add encoder mapping for ``OrderedDict`` +* Use timeouts on all control connection queries +* Benchmark improvements, including command line options and eay + multithreading support +* Reduced lock contention when using the asyncore reactor +* Warn when non-datetimes are used for 'timestamp' column values in + prepared statements +* Add requirements.txt and test-requirements.txt +* TravisCI integration for running unit tests against Python 2.6, + Python 2.7, and PyPy + +1.0.0b7 +======= +Nov 12, 2013 + +This release makes many stability improvements, especially around +prepared statements and node failure handling. In particular, +several cases where a request would never be completed (and as a +result, leave the application hanging) have been resolved. + +Features +-------- +* Add `timeout` kwarg to ``ResponseFuture.result()`` +* Create connection pools to all hosts in parallel when initializing + new Sesssions. + +Bug Fixes +--------- +* Properly set exception on ResponseFuture when a query fails + against all hosts +* Improved cleanup and reconnection efforts when reconnection fails + on a node that has recently come up +* Use correct consistency level when retrying failed operations + against a different host. (An invalid consistency level was being + used, causing the retry to fail.) +* Better error messages for failed ``Session.prepare()`` opertaions +* Prepare new statements against all hosts in parallel (formerly + sequential) +* Fix failure to save the new current keyspace on connections. (This + could cause problems for prepared statements and lead to extra + operations to continuously re-set the keyspace.) +* Avoid sharing ``LoadBalancingPolicies`` across ``Cluster`` instances. (When + a second ``Cluster`` was connected, it effectively mark nodes down for the + first ``Cluster``.) +* Better handling of failures during the re-preparation sequence for + unrecognized prepared statements +* Throttle trashing of underutilized connections to avoid trashing newly + created connections +* Fix race condition which could result in trashed connections being closed + before the last operations had completed +* Avoid preparing statements on the event loop thread (which could lead to + deadlock) +* Correctly mark up non-contact point nodes discovered by the control + connection. (This lead to prepared statements not being prepared + against those hosts, generating extra traffic later when the + statements were executed and unrecognized.) +* Correctly handle large messages through libev +* Add timeout to schema agreement check queries +* More complete (and less contended) locking around manipulation of the + pending message deque for libev connections + +Other +----- +* Prepare statements in batches of 10. (When many prepared statements + are in use, this allows the driver to start utilizing nodes that + were restarted more quickly.) +* Better debug logging around connection management +* Don't retain unreferenced prepared statements in the local cache. + (If many different prepared statements were created, this would + increase memory usage and greatly increase the amount of time + required to begin utilizing a node that was added or marked + up.) + +1.0.0b6 +======= +Oct 22, 2013 + +Bug Fixes +--------- +* Use lazy string formatting when logging +* Avoid several deadlock scenarios, especially when nodes go down +* Avoid trashing newly created connections due to insufficient traffic +* Gracefully handle un-handled Exceptions when erroring callbacks + +Other +----- +* Node state listeners (which are called when a node is added, removed, + goes down, or comes up) should now be registered through + Cluster.register_listener() instead of through a host's HealthMonitor + (which has been removed) + + +1.0.0b5 +======== +Oct 10, 2013 + +Features +-------- +* SSL support + +Bug Fixes +--------- +* Avoid KeyError when building replica map for NetworkTopologyStrategy +* Work around python bug which causes deadlock when a thread imports + the utf8 module +* Handle no blist library, which is not compatible with pypy +* Avoid deadlock triggered by a keyspace being set on a connection (which + may happen automatically for new connections) + +Other +----- +* Switch packaging from Distribute to setuptools, improved C extension + support +* Use PEP 386 compliant beta and post-release versions + +1.0.0-beta4 +=========== +Sep 24, 2013 + +Features +-------- +* Handle new blob syntax in Cassandra 2.0 by accepting bytearray + objects for blob values +* Add cql_version kwarg to Cluster.__init__ + +Bug Fixes +--------- +* Fix KeyError when building token map with NetworkTopologyStrategy + keyspaces (this prevented a Cluster from successfully connecting + at all). +* Don't lose default consitency level from parent PreparedStatement + when creating BoundStatements + 1.0.0-beta3 =========== -(In progress) +Sep 20, 2013 Features -------- @@ -25,6 +1813,7 @@ Bug Fixes * Avoid a potential loss of precision on float constants due to string formatting * Actually utilize non-standard ports set on Cluster objects +* Fix export of schema as a set of CQL queries Other ----- @@ -33,6 +1822,7 @@ Other * Raise InvalidTypeParameterError when parameters of the wrong type are used with statements * Make all tests compatible with Python 2.6 +* Add 1s timeout for opening new connections 1.0.0-beta2 =========== diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst new file mode 100644 index 0000000000..f71ebabdbb --- /dev/null +++ b/CONTRIBUTING.rst @@ -0,0 +1,29 @@ +Contributing +============ + +Contributions are welcome in the form of bug reports or pull requests. + +Bug Reports +----------- +Quality bug reports are welcome at the `CASSPYTHON project `_ +of the ASF JIRA. + +There are plenty of `good resources `_ describing how to create +good bug reports. They will not be repeated in detail here, but in general, the bug report include where appropriate: + +* relevant software versions (Python runtime, driver version, cython version, server version) +* details for how to produce (e.g. a test script or written procedure) + * any effort to isolate the issue in reproduction is much-appreciated +* stack trace from a crashed runtime + +Pull Requests +------------- +If you're able to fix a bug yourself, you can `fork the repository `_ and submit a `Pull Request `_ with the fix. +Please include tests demonstrating the issue and fix. For examples of how to run the tests, consult the `dev README `_. + +Design and Implementation Guidelines +------------------------------------ +- We have integrations (notably Cassandra cqlsh) that require pure Python and minimal external dependencies. We try to avoid new external dependencies. Where compiled extensions are concerned, there should always be a pure Python fallback implementation. +- This project follows `semantic versioning `_, so breaking API changes will only be introduced in major versions. +- Legacy ``cqlengine`` has varying degrees of overreaching client-side validation. Going forward, we will avoid client validation where server feedback is adequate and not overly expensive. +- When writing tests, try to achieve maximal coverage in unit tests (where it is faster to run across many runtimes). Integration tests are good for things where we need to test server interaction, or where it is important to test across different server versions (emulating in unit tests would not be effective). diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 0000000000..46624d57d2 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,779 @@ +#!groovy +/* + +There are multiple combinations to test the python driver. + +Test Profiles: + + Full: Execute all unit and integration tests, including long tests. + Standard: Execute unit and integration tests. + Smoke Tests: Execute a small subset of tests. + EVENT_LOOP: Execute a small subset of tests selected to test EVENT_LOOPs. + +Matrix Types: + + Full: All server versions, python runtimes tested with and without Cython. + Cassandra: All cassandra server versions. + Dse: All dse server versions. + Hcd: All hcd server versions. + Smoke: CI-friendly configurations. Currently-supported Python version + modern Cassandra/DSE instances. + We also avoid cython since it's tested as part of the nightlies + +Parameters: + + EVENT_LOOP: 'LIBEV' (Default), 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED' + CYTHON: Default, 'True', 'False' + +*/ + +@Library('dsdrivers-pipeline-lib@develop') +import com.datastax.jenkins.drivers.python.Slack + +slack = new Slack() + +DEFAULT_CASSANDRA = ['3.11', '4.0', '4.1', '5.0'] +DEFAULT_DSE = ['dse-5.1.35', 'dse-6.8.30', 'dse-6.9.0'] +DEFAULT_HCD = ['hcd-1.0.0'] +DEFAULT_RUNTIME = ['3.10.19', '3.11.14', '3.12.12', '3.13.12', '3.14.3'] +DEFAULT_CYTHON = ["True", "False"] +matrices = [ + "FULL": [ + "SERVER": DEFAULT_CASSANDRA + DEFAULT_DSE, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "CASSANDRA": [ + "SERVER": DEFAULT_CASSANDRA, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "DSE": [ + "SERVER": DEFAULT_DSE, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "SMOKE": [ + "SERVER": DEFAULT_CASSANDRA.takeRight(2) + DEFAULT_DSE.takeRight(2) + DEFAULT_HCD.takeRight(1), + "RUNTIME": DEFAULT_RUNTIME.take(1) + DEFAULT_RUNTIME.takeRight(1), + "CYTHON": ["True"] + ] +] + +def initializeSlackContext() { + /* + Based on git branch/commit, configure the build context and env vars. + */ + + def driver_display_name = 'Cassandra Python Driver' + if (env.GIT_URL.contains('riptano/python-driver')) { + driver_display_name = 'private ' + driver_display_name + } else if (env.GIT_URL.contains('python-dse-driver')) { + driver_display_name = 'DSE Python Driver' + } + env.DRIVER_DISPLAY_NAME = driver_display_name + env.GIT_SHA = "${env.GIT_COMMIT.take(7)}" + env.GITHUB_PROJECT_URL = "https://${GIT_URL.replaceFirst(/(git@|http:\/\/|https:\/\/)/, '').replace(':', '/').replace('.git', '')}" + env.GITHUB_BRANCH_URL = "${env.GITHUB_PROJECT_URL}/tree/${env.BRANCH_NAME}" + env.GITHUB_COMMIT_URL = "${env.GITHUB_PROJECT_URL}/commit/${env.GIT_COMMIT}" +} + +def getBuildContext() { + /* + Based on schedule and parameters, configure the build context and env vars. + */ + + def PROFILE = "${params.PROFILE}" + def EVENT_LOOP = "${params.EVENT_LOOP.toLowerCase()}" + + matrixType = params.MATRIX != "DEFAULT" ? params.MATRIX : "SMOKE" + matrix = matrices[matrixType].clone() + + // Check if parameters were set explicitly + if (params.CYTHON != "DEFAULT") { + matrix["CYTHON"] = [params.CYTHON] + } + + if (params.SERVER_VERSION != "DEFAULT") { + matrix["SERVER"] = [params.SERVER_VERSION] + } + + if (params.PYTHON_VERSION != "DEFAULT") { + matrix["RUNTIME"] = [params.PYTHON_VERSION] + } + + if (params.CI_SCHEDULE == "WEEKNIGHTS") { + matrix["SERVER"] = params.CI_SCHEDULE_SERVER_VERSION.split(' ') + matrix["RUNTIME"] = params.CI_SCHEDULE_PYTHON_VERSION.split(' ') + } + + context = [ + vars: [ + "PROFILE=${PROFILE}", + "EVENT_LOOP=${EVENT_LOOP}" + ], + matrix: matrix + ] + + return context +} + +def buildAndTest(context) { + initializeEnvironment() + installDriver() + + try { + executeTests() + } finally { + junit testResults: '*_results.xml' + } +} + +def getMatrixBuilds(buildContext) { + def tasks = [:] + matrix = buildContext.matrix + + matrix["SERVER"].each { serverVersion -> + matrix["RUNTIME"].each { runtimeVersion -> + matrix["CYTHON"].each { cythonFlag -> + def taskVars = [ + "CASSANDRA_VERSION=${serverVersion}", + "PYTHON_VERSION=${runtimeVersion}", + "CYTHON_ENABLED=${cythonFlag}" + ] + def cythonDesc = cythonFlag == "True" ? ", Cython": "" + tasks["${serverVersion}, py${runtimeVersion}${cythonDesc}"] = { + node("${OS_VERSION}") { + scm_variables = checkout scm + env.GIT_COMMIT = scm_variables.get('GIT_COMMIT') + env.GIT_URL = scm_variables.get('GIT_URL') + initializeSlackContext() + + if (env.BUILD_STATED_SLACK_NOTIFIED != 'true') { + slack.notifyChannel() + } + + withEnv(taskVars) { + buildAndTest(context) + } + } + } + } + } + } + return tasks +} + +def initializeEnvironment() { + sh label: 'Initialize the environment', script: '''#!/bin/bash -lex + + # One of the integration tests relies on socat so let's install that here + sudo apt-get install -y socat moreutils + + pyenv shell ${PYTHON_VERSION} + python -m venv jenkins-venv + . ./jenkins-venv/bin/activate + pip install --upgrade pip setuptools wheel + + # Install a version of pyyaml<6.0 compatible with ccm-3.1.5 as of Aug 2023 + # this works around the python-3.10+ compatibility problem as described in DSP-23524 + pip install "pyyaml<6.0" --no-build-isolation + pip install ${HOME}/ccm + ''' + + // Determine if server version is Apache CassandraⓇ or DataStax Enterprise + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') { + if (env.PYTHON_VERSION =~ /3\.12\.\d+/) { + echo "Cannot install DSE dependencies for Python 3.12.x; installing Apache CassandraⓇ requirements only. See PYTHON-1368 for more detail." + sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + pip install -r test-requirements.txt + ''' + } + else { + sh label: 'Install DataStax Enterprise requirements', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + pip install -r test-datastax-requirements.txt + ''' + } + } else { + sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + pip install -r test-requirements.txt + ''' + + sh label: 'Uninstall the geomet dependency since it is not required for Cassandra', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + pip uninstall -y geomet + ''' + } + + sh label: 'Download Apache CassandraⓇ or DataStax Enterprise', script: '''#!/bin/bash -lex + . ${CCM_ENVIRONMENT_SHELL} ${CASSANDRA_VERSION} + ''' + + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') { + env.DSE_FIXED_VERSION = env.CASSANDRA_VERSION.split('-')[1] + sh label: 'Update environment for DataStax Enterprise', script: '''#!/bin/bash -le + cat >> ${HOME}/environment.txt << ENVIRONMENT_EOF +CCM_CASSANDRA_VERSION=${DSE_FIXED_VERSION} # maintain for backwards compatibility +CCM_VERSION=${DSE_FIXED_VERSION} +CCM_SERVER_TYPE=dse +DSE_VERSION=${DSE_FIXED_VERSION} +CCM_IS_DSE=true +CCM_BRANCH=${DSE_FIXED_VERSION} +DSE_BRANCH=${DSE_FIXED_VERSION} +ENVIRONMENT_EOF + ''' + } else if (env.CASSANDRA_VERSION.split('-')[0] == 'hcd') { + env.HCD_FIXED_VERSION = env.CASSANDRA_VERSION.split('-')[1] + sh label: 'Update environment for DataStax Enterprise', script: '''#!/bin/bash -le + cat >> ${HOME}/environment.txt << ENVIRONMENT_EOF +CCM_CASSANDRA_VERSION=${HCD_FIXED_VERSION} # maintain for backwards compatibility +CCM_VERSION=${HCD_FIXED_VERSION} +CCM_SERVER_TYPE=hcd +HCD_VERSION=${HCD_FIXED_VERSION} +CCM_IS_HCD=true +CCM_BRANCH=${HCD_FIXED_VERSION} +HCD_BRANCH=${HCD_FIXED_VERSION} +ENVIRONMENT_EOF + ''' + } + + sh label: 'Display Python and environment information', script: '''#!/bin/bash -le + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + python --version + pip --version + printenv | sort + ''' +} + +def installDriver() { + sh label: 'Install the driver and compile with C extensions with Cython', script: '''#!/bin/bash -lex + # Update libev includes and libs to point to the right spot for this install + pyenv shell ${PYTHON_VERSION} + python -m venv libev-venv + . ./libev-venv/bin/activate + pip install toml + python fix-jenkinsfile-libev.py ./pyproject.toml "/usr/include" "/usr/lib/x86_64-linux-gnu" | sponge ./pyproject.toml + deactivate + + ls /usr/include/ev.h + ls /usr/lib/x86_64-linux-gnu/libev* + + # Now that we've made relevant mods to our local pyproject.toml we're ready to build the driver + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + cat ./pyproject.toml + pip install --verbose --editable . + + # After install display a list of packages in the venv for auditing + pip list + ''' +} + + +def executeStandardTests() { + + try { + sh label: 'Execute unit tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + failure=0 + EVENT_LOOP=${EVENT_LOOP} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_results.xml tests/unit/ || failure=1 + EVENT_LOOP_MANAGER=eventlet VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || failure=1 + EVENT_LOOP_MANAGER=gevent VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || failure=1 + exit $failure + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute Simulacron integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + failure=0 + SIMULACRON_JAR="${HOME}/simulacron.jar" + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_results.xml tests/integration/simulacron/ || true + + # Run backpressure tests separately to avoid memory issue + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_1_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_paused_connections || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_2_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_queued_requests_timeout || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_3_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_cluster_busy || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_4_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_node_busy || failure=1 + exit $failure + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute CQL engine integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=cqle_results.xml tests/integration/cqlengine/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute Apache CassandraⓇ integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml tests/integration/standard/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse' && env.CASSANDRA_VERSION.split('-')[1] != '4.8') { + if (env.PYTHON_VERSION =~ /3\.12\.\d+/) { + echo "Cannot install DSE dependencies for Python 3.12.x. See PYTHON-1368 for more detail." + } + else { + try { + sh label: 'Execute DataStax Enterprise integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} ADS_HOME="${HOME}/" VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=dse_results.xml tests/integration/advanced/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + } + } + + try { + sh label: 'Execute DataStax Astra integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CLOUD_PROXY_PATH="${HOME}/proxy/" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=advanced_results.xml tests/integration/cloud/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + if (env.PROFILE == 'FULL') { + try { + sh label: 'Execute long running integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=tests/integration/long/upgrade --junit-xml=long_results.xml tests/integration/long/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + } +} + +def executeDseSmokeTests() { + sh label: 'Execute profile DataStax Enterprise smoke test integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} DSE_VERSION=${DSE_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml tests/integration/standard/test_dse.py + ''' +} + +def executeEventLoopTests() { + sh label: 'Execute profile event loop manager integration tests', script: '''#!/bin/bash -lex + . ./jenkins-venv/bin/activate + + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP_TESTS=( + "tests/integration/standard/test_cluster.py" + "tests/integration/standard/test_concurrent.py" + "tests/integration/standard/test_connection.py" + "tests/integration/standard/test_control_connection.py" + "tests/integration/standard/test_metrics.py" + "tests/integration/standard/test_query.py" + "tests/integration/simulacron/test_endpoint.py" + "tests/integration/long/test_ssl.py" + ) + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml ${EVENT_LOOP_TESTS[@]} + ''' +} + +def executeTests() { + switch(env.PROFILE) { + case 'DSE-SMOKE-TEST': + executeDseSmokeTests() + break + case 'EVENT_LOOP': + executeEventLoopTests() + break + default: + executeStandardTests() + break + } +} + + +// TODO move this in the shared lib +def getDriverMetricType() { + metric_type = 'oss' + if (env.GIT_URL.contains('riptano/python-driver')) { + metric_type = 'oss-private' + } else if (env.GIT_URL.contains('python-dse-driver')) { + metric_type = 'dse' + } + return metric_type +} + +def describeBuild(buildContext) { + script { + def runtimes = buildContext.matrix["RUNTIME"] + def serverVersions = buildContext.matrix["SERVER"] + def numBuilds = runtimes.size() * serverVersions.size() * buildContext.matrix["CYTHON"].size() + currentBuild.displayName = "${env.PROFILE} (${env.EVENT_LOOP} | ${numBuilds} builds)" + currentBuild.description = "${env.PROFILE} build testing servers (${serverVersions.join(', ')}) against Python (${runtimes.join(', ')}) using ${env.EVENT_LOOP} event loop manager" + } +} + +// branch pattern for cron +def branchPatternCron() { + ~"(master)" +} + +pipeline { + agent none + + // Global pipeline timeout + options { + disableConcurrentBuilds() + timeout(time: 10, unit: 'HOURS') // TODO timeout should be per build + buildDiscarder(logRotator(artifactNumToKeepStr: '10', // Keep only the last 10 artifacts + numToKeepStr: '50')) // Keep only the last 50 build records + } + + parameters { + choice( + name: 'ADHOC_BUILD_TYPE', + choices: ['BUILD', 'BUILD-AND-EXECUTE-TESTS'], + description: '''

Perform a adhoc build operation

+ + + + + + + + + + + + + + + +
ChoiceDescription
BUILDPerforms a Per-Commit build
BUILD-AND-EXECUTE-TESTSPerforms a build and executes the integration and unit tests
''') + choice( + name: 'PROFILE', + choices: ['STANDARD', 'FULL', 'DSE-SMOKE-TEST', 'EVENT_LOOP'], + description: '''

Profile to utilize for scheduled or adhoc builds

+ + + + + + + + + + + + + + + + + + + + + + + +
ChoiceDescription
STANDARDExecute the standard tests for the driver
FULLExecute all tests for the driver, including long tests.
DSE-SMOKE-TESTExecute only the DataStax Enterprise smoke tests
EVENT_LOOPExecute only the event loop tests for the specified event loop manager (see: EVENT_LOOP)
''') + choice( + name: 'MATRIX', + choices: ['DEFAULT', 'SMOKE', 'FULL', 'CASSANDRA', 'DSE', 'HCD'], + description: '''

The matrix for the build.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ChoiceDescription
DEFAULTDefault to the build context.
SMOKEBasic smoke tests for current Python runtimes + C*/DSE versions, no Cython
FULLAll server versions, python runtimes tested with and without Cython.
CASSANDRAAll cassandra server versions.
DSEAll dse server versions.
HCDAll hcd server versions.
''') + choice( + name: 'PYTHON_VERSION', + choices: ['DEFAULT'] + DEFAULT_RUNTIME, + description: 'Python runtime version. Default to the build context.') + choice( + name: 'SERVER_VERSION', + choices: ['DEFAULT'] + DEFAULT_CASSANDRA + DEFAULT_DSE + DEFAULT_HCD, + description: '''Apache CassandraⓇ and DataStax Enterprise server version to use for adhoc BUILD-AND-EXECUTE-TESTS ONLY! + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ChoiceDescription
DEFAULTDefault to the build context.
3.0Apache CassandraⓇ v3.0.x
3.11Apache CassandraⓇ v3.11.x
4.0Apache CassandraⓇ v4.0.x
5.0Apache CassandraⓇ v5.0.x
dse-5.1.35DataStax Enterprise v5.1.x
dse-6.8.30DataStax Enterprise v6.8.x
dse-6.9.0DataStax Enterprise v6.9.x (CURRENTLY UNDER DEVELOPMENT)
hcd-1.0.0DataStax HCD v1.0.x (CURRENTLY UNDER DEVELOPMENT)
''') + choice( + name: 'CYTHON', + choices: ['DEFAULT'] + DEFAULT_CYTHON, + description: '''

Flag to determine if Cython should be enabled

+ + + + + + + + + + + + + + + + + + + +
ChoiceDescription
DefaultDefault to the build context.
TrueEnable Cython
FalseDisable Cython
''') + choice( + name: 'EVENT_LOOP', + choices: ['LIBEV', 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'], + description: '''

Event loop manager to utilize for scheduled or adhoc builds

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ChoiceDescription
LIBEVA full-featured and high-performance event loop that is loosely modeled after libevent, but without its limitations and bugs
GEVENTA co-routine -based Python networking library that uses greenlet to provide a high-level synchronous API on top of the libev or libuv event loop
EVENTLETA concurrent networking library for Python that allows you to change how you run your code, not how you write it
ASYNCIOA library to write concurrent code using the async/await syntax
ASYNCOREA module provides the basic infrastructure for writing asynchronous socket service clients and servers
TWISTEDAn event-driven networking engine written in Python and licensed under the open source MIT license
''') + choice( + name: 'CI_SCHEDULE', + choices: ['DO-NOT-CHANGE-THIS-SELECTION', 'WEEKNIGHTS', 'WEEKENDS'], + description: 'CI testing schedule to execute periodically scheduled builds and tests of the driver (DO NOT CHANGE THIS SELECTION)') + string( + name: 'CI_SCHEDULE_PYTHON_VERSION', + defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION', + description: 'CI testing python version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)') + string( + name: 'CI_SCHEDULE_SERVER_VERSION', + defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION', + description: 'CI testing server version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)') + } + + triggers { + parameterizedCron(branchPatternCron().matcher(env.BRANCH_NAME).matches() ? """ + # Every weeknight (Monday - Friday) around 4:00 AM + # These schedules will run with and without Cython enabled for Python 3.9.23 and 3.13.5 + H 4 * * 1-5 %CI_SCHEDULE=WEEKNIGHTS;EVENT_LOOP=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.9.23 3.13.5;CI_SCHEDULE_SERVER_VERSION=3.11 4.0 5.0 dse-5.1.35 dse-6.8.30 dse-6.9.0 hcd-1.0.0 + """ : "") + } + + environment { + OS_VERSION = 'ubuntu/jammy64/python-driver' + CCM_ENVIRONMENT_SHELL = '/usr/local/bin/ccm_environment.sh' + CCM_MAX_HEAP_SIZE = '1536M' + JABBA_SHELL = '/usr/lib/jabba/jabba.sh' + } + + stages { + stage ('Build and Test') { + when { + beforeAgent true + allOf { + not { buildingTag() } + } + } + + steps { + script { + context = getBuildContext() + withEnv(context.vars) { + describeBuild(context) + + // build and test all builds + parallel getMatrixBuilds(context) + + slack.notifyChannel(currentBuild.currentResult) + } + } + } + } + + } +} diff --git a/MANIFEST.in b/MANIFEST.in index daa6931fbf..cf828fe4da 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,7 @@ -include setup.py README.rst MANIFEST.in LICENSE distribute_setup.py -global-exclude *~ +global-exclude *.c +include setup.py README.rst MANIFEST.in LICENSE +include cassandra/cmurmur3.c +include cassandra/io/libevwrapper.c +include cassandra/*.pyx +include cassandra/*.pxd +include cassandra/*.h diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000000..58250f616b --- /dev/null +++ b/NOTICE @@ -0,0 +1,115 @@ +Apache Cassandra Python Driver +Copyright 2013 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +This product originates, before git sha +5d4fd2349119a3237ad351a96e7f2b3317159305, from software from DataStax and other +individual contributors. All work was previously copyrighted to DataStax. + +Non-DataStax contributors are listed below. Those marked with asterisk have +explicitly consented to their contributions being donated to the ASF. + +a-detiste Alexandre Detiste * +a-lst Andrey Istochkin * +aboudreault Alan Boudreault alan@alanb.ca +advance512 Alon Diamant diamant.alon@gmail.com * +alanjds Alan Justino da Silva alan.justino@yahoo.com.br * +alistair-broomhead Alistair Broomhead * +amygdalama Amy Hanlon * +andy-slac Andy Salnikov * +andy8zhao Andy Zhao +anthony-cervantes Anthony Cervantes anthony@cervantes.io * +BackEndTea Gert de Pagter * +barvinograd Bar Vinograd +bbirand Berk Birand +bergundy Roey Berman roey.berman@gmail.com * +bohdantan * +codesnik Alexey Trofimenko aronaxis@gmail.com * +coldeasy coldeasy +DanieleSalatti Daniele Salatti me@danielesalatti.com * +daniloarodrigues Danilo de Araújo Rodrigues * +daubman Aaron Daubman github@ajd.us * +dcosson Danny Cosson dcosson@gmail.com * +detzgk Eli Green eli@zigr.org * +devdazed Russ Bradberry * +dizpers Dmitry Belaventsev dizpers@gmail.com +dkropachev Dmitry Kropachev dmitry.kropachev@gmail.com * +dmglab Daniel dmg.lab@outlook.com +dokai Kai Lautaportti * +eamanu Emmanuel Arias eamanu@yaerobi.com +figpope Andrew FigPope andrew.figpope@gmail.com * +flupke Luper Rouch * +frensjan Frens Jan Rumph * +frew Fred Wulff frew@cs.stanford.edu * +gdoermann Greg Doermann +haaawk Piotr Jastrzębski +ikapl Irina Kaplounova +ittus Thang Minh Vu * +JeremyOT Jeremy Olmsted-Thompson * +jeremyschlatter Jeremy Schlatter * +jpuerta Ernesto Puerta * +julien-duponchelle Julien Duponchelle julien@duponchelle.info * +justinsb Justin Santa Barbara justinsb@google.com * +Kami Tomaz Muraus tomaz@tomaz.me +kandul Michał Kandulski michal.kandulski@gmail.com +kdeldycke Kevin Deldycke * +kishkaru Kishan Karunaratne kishan@karu.io * +kracekumar Kracekumar kracethekingmaker@gmail.com +lenards Andrew Lenards andrew.lenards@gmail.com * +lenolib +Lifto Ellis Low +Lorak-mmk Karol Baryła git@baryla.org * +lukaselmer Lukas Elmer lukas.elmer@gmail.com * +mahall Michael Hall +markflorisson Mark Florisson * +mattrobenolt Matt Robenolt m@robenolt.com * +mattstibbs Matt Stibbs * +Mhs-220 Mo Shahmohammadi hos1377@gmail.com * +mikeokner Mike Okner * +Mishail Mikhail Stepura mstepura@apple.com * +mission-liao mission.liao missionaryliao@gmail.com * +mkocikowski Mik Kocikowski +Mokto Théo Mathieu * +mrk-its Mariusz Kryński * +multani Jonathan Ballet jon@multani.info * +niklaskorz Niklas Korz * +nisanharamati nisanharamati +nschrader Nick Schrader nick.schrader@mailbox.org * +Orenef11 Oren Efraimov * +oz123 Oz Tiram * +pistolero Sergii Kyryllov * +pmcnett Paul McNett p@ulmcnett.com * +psarna Piotr Sarna * +r4fek Rafał Furmański * +raopm +rbranson Rick Branson * +rqx Roman Khanenko * +rtb-zla-karma xyz * +sigmunau +silviot Silvio Tomatis +sontek John Anderson sontek@gmail.com * +stanhu Stan Hu +stefanor Stefano Rivera stefanor@debian.org * +strixcuriosus Ash Hoover strixcuriosus@gmail.com +tarzanjw Học Đỗ hoc3010@gmail.com +tbarbugli Tommaso Barbugli +tchaikov Kefu Chai tchaikov@gmail.com * +tglines Travis Glines +thoslin Tom Lin +tigrus Nikolay Fominykh nikolayfn@gmail.com +timgates42 Tim Gates +timsavage Tim Savage * +tirkarthi Karthikeyan Singaravelan tir.karthi@gmail.com * +Trundle Andreas Stührk andy@hammerhartes.de +ubombi Vitalii Kozlovskyi vitalii@kozlovskyi.dev * +ultrabug Ultrabug * +vetal4444 Shevchenko Vitaliy * +victorpoluceno Victor Godoy Poluceno victorpoluceno@gmail.com +weisslj Johannes Weißl * +wenheping wenheping wenheping2000@hotmail.com +yi719 +yinyin Yinyin * +yriveiro Yago Riveiro * \ No newline at end of file diff --git a/README-dev.rst b/README-dev.rst index 4a1a3ce14c..3f8789cb81 100644 --- a/README-dev.rst +++ b/README-dev.rst @@ -1,51 +1,155 @@ Releasing ========= +Note: the precise details of some of these steps have changed. Leaving this here as a guide only. + * Run the tests and ensure they all pass * Update CHANGELOG.rst + * Check for any missing entries + * Add today's date to the release section * Update the version in ``cassandra/__init__.py`` -* Commit the changelog and version changes + * For beta releases, use a version like ``(2, 1, '0b1')`` + * For release candidates, use a version like ``(2, 1, '0rc1')`` + * When in doubt, follow PEP 440 versioning +* Add the new version in ``docs.yaml`` +* Commit the changelog and version changes, e.g. ``git commit -m'version 1.0.0'`` * Tag the release. For example: ``git tag -a 1.0.0 -m 'version 1.0.0'`` -* Push the commit and tag: ``git push --tags origin master`` -* Upload the package to pypi: +* Push the tag and new ``master``: ``git push origin 1.0.0 ; git push origin master`` +* Update the `python-driver` submodule of `python-driver-wheels`, + commit then push. +* Trigger the Github Actions necessary to build wheels for the various platforms +* For a GA release, upload the package to pypi:: + + # Clean the working directory + python setup.py clean + rm dist/* + + # Build the source distribution + python setup.py sdist - python setup.py register + # Download all wheels from the jfrog repository and copy them in + # the dist/ directory + cp /path/to/wheels/*.whl dist/ - python setup.py sdist upload + # Upload all files + twine upload dist/* +* On pypi, make the latest GA the only visible version * Update the docs (see below) -* Add a '+' to the version in ``cassandra/__init__.py`` so that it looks - like ``x.y.z+`` +* Append a 'postN' string to the version tuple in ``cassandra/__init__.py`` + so that it looks like ``(x, y, z, 'postN')`` + + * After a beta or rc release, this should look like ``(2, 1, '0b1', 'post0')`` + +* After the release has been tagged, add a section to docs.yaml with the new tag ref:: + + versions: + - name: + ref: + * Commit and push +* Update 'cassandra-test' branch to reflect new release -Running the Tests -================= -In order for the extensions to be built and used in the test, run: + * this is typically a matter of merging or rebasing onto master + * test and push updated branch to origin - python setup.py nosetests +* Update the JIRA releases: https://issues.apache.org/jira/projects/CASSPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin:release-page -Building the Docs -================= -Sphinx is required to build the docs. You probably want to install through apt, -if possible: + * add release dates and set version as "released" - $ sudo apt-get install python-sphinx +* Make an announcement on the mailing list -pip may also work: +Tests +===== - $ sudo pip install -U Sphinx +Running Unit Tests +------------------ +Unit tests can be run like so:: -To build the docs, run: + pytest tests/unit/ - python setup.py doc +You can run a specific test method like so:: + + pytest tests/unit/test_connection.py::ConnectionTest::test_bad_protocol_version + +Running Integration Tests +------------------------- +In order to run integration tests, you must specify a version to run using the ``CASSANDRA_VERSION`` or ``DSE_VERSION`` environment variable:: + + CASSANDRA_VERSION=4.0.1 pytest tests/integration/standard + +Or you can specify a cassandra directory (to test unreleased versions):: + + CASSANDRA_DIR=/path/to/cassandra pytest tests/integration/standard/ + +Specifying the usage of an already running Cassandra cluster +------------------------------------------------------------ +The test will start the appropriate Cassandra clusters when necessary but if you don't want this to happen because a Cassandra cluster is already running the flag ``USE_CASS_EXTERNAL`` can be used, for example:: + + USE_CASS_EXTERNAL=1 CASSANDRA_VERSION=4.0.1 pytest tests/integration/standard + +Specify a Protocol Version for Tests +------------------------------------ +The protocol version defaults to 1 for cassandra 1.2 and 2 otherwise. You can explicitly set +it with the ``PROTOCOL_VERSION`` environment variable:: + + PROTOCOL_VERSION=3 pytest tests/integration/standard + +Testing Multiple Python Versions +-------------------------------- +Use tox to test all of Python 3.10 through 3.14 and pypy:: + + tox + +By default, tox only runs the unit tests. -To upload the docs, checkout the ``gh-pages`` branch (it's usually easier to -clone a second copy of this repo and leave it on that branch) and copy the entire -contents all of ``docs/_build/X.Y.Z/*`` into the root of the ``gh-pages`` branch -and then push that branch to github. +Running the Benchmarks +====================== +There needs to be a version of cassandra running locally so before running the benchmarks, if ccm is installed: + + ccm create benchmark_cluster -v 4.0.1 -n 1 -s -For example: +To run the benchmarks, pick one of the files under the ``benchmarks/`` dir and run it:: + + python benchmarks/future_batches.py + +There are a few options. Use ``--help`` to see them all:: + + python benchmarks/future_batches.py --help + +Packaging for Cassandra +======================= +A source distribution is included in Cassandra, which uses the driver internally for ``cqlsh``. +To package a released version, checkout the tag and build a source zip archive:: + + python setup.py sdist --formats=zip + +If packaging a pre-release (untagged) version, it is useful to include a commit hash in the archive +name to specify the built version:: + + python setup.py egg_info -b-`git rev-parse --short HEAD` sdist --formats=zip + +The file (``dist/cassandra-driver-.zip``) is packaged with Cassandra in ``cassandra/lib/cassandra-driver-internal-only*zip``. + +Releasing an EAP +================ + +An EAP release is only uploaded on a private server and it is not published on pypi. + +* Clean the environment:: + + python setup.py clean + +* Package the source distribution:: + + python setup.py sdist + +* Test the source distribution:: + + pip install dist/cassandra-driver-.tar.gz + +* Upload the package on the EAP download server. +* Build the documentation:: + + python setup.py doc - $ python setup.py doc - $ cp -R docs/_build/1.0.0-beta1/* ~/python-driver-docs/ - $ cd ~/python-driver-docs - $ git push origin gh-pages +* Upload the docs on the EAP download server. diff --git a/README.rst b/README.rst index 91b2aca516..da998967ea 100644 --- a/README.rst +++ b/README.rst @@ -1,120 +1,86 @@ -DataStax Python Driver for Apache Cassandra (Beta) -================================================== -A Python client driver for Apache Cassandra. This driver works exclusively -with the Cassandra Query Language v3 (CQL3) and Cassandra's native -protocol. As such, only Cassandra 1.2+ is supported. - -**Warning** - -This driver is currently under heavy development, so the API and layout of -packages, modules, classes, and functions are subject to change. There may -also be serious bugs, so usage in a production environment is *not* -recommended at this time. - -* `JIRA `_ -* `Mailing List `_ -* IRC: #datastax-drivers on irc.freenode.net (you can use `freenode's web-based client `_) -* `API Documentation `_ - -Features to be Added --------------------- -* C extension for encoding/decoding messages -* Authentication/security feature support -* Twisted, gevent support -* Python 3 support -* IPv6 Support + +.. |license| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg + :target: https://opensource.org/licenses/Apache-2.0 +.. |version| image:: https://badge.fury.io/py/cassandra-driver.svg + :target: https://badge.fury.io/py/cassandra-driver +.. |pyversion| image:: https://img.shields.io/pypi/pyversions/cassandra-driver.svg +.. |travis| image:: https://api.travis-ci.com/datastax/python-driver.svg?branch=master + :target: https://travis-ci.com/github/datastax/python-driver + +|license| |version| |pyversion| |travis| + +Apache Cassandra Python Driver +============================== + +A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) and +DataStax Enterprise (4.7+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. + +The driver supports Python 3.10 through 3.14. + +**Note:** DataStax products do not support big-endian systems. + +Features +-------- +* `Synchronous `_ and `Asynchronous `_ APIs +* `Simple, Prepared, and Batch statements `_ +* Asynchronous IO, parallel execution, request pipelining +* `Connection pooling `_ +* Automatic node discovery +* `Automatic reconnection `_ +* Configurable `load balancing `_ and `retry policies `_ +* `Concurrent execution utilities `_ +* `Object mapper `_ +* `Connecting to DataStax Astra database (cloud) `_ +* DSE Graph execution API +* DSE Geometric type serialization +* DSE PlainText and GSSAPI authentication Installation ------------ -If you would like to use the optional C extensions, please follow -the instructions in the section below before installing the driver. - Installation through pip is recommended:: - $ sudo pip install cassandra-driver - -If you want to install manually, you can instead do:: - - $ sudo pip install futures scales blist # install dependencies - $ sudo python setup.py install - -C Extensions -^^^^^^^^^^^^ -By default, two C extensions are compiled: one that adds support -for token-aware routing with the Murmur3Partitioner, and one that -allows you to use libev for the event loop, which improves performance. - -When running setup.py, you can disable both with the ``--no-extensions`` -option, or selectively disable one or the other with ``--no-murmur3`` and -``--no-libev``. - -To compile the extenions, ensure that GCC and the Python headers are available. - -On Ubuntu and Debian, this can be accomplished by running:: - - $ sudo apt-get install build-essential python-dev - -On RedHat and RedHat-based systems like CentOS and Fedora:: - - $ sudo yum install gcc python-devel + $ pip install cassandra-driver -On OS X, homebrew installations of Python should provide the necessary headers. +For more complete installation instructions, see the +`installation guide `_. -libev support -^^^^^^^^^^^^^ -The driver currently uses Python's ``asyncore`` module for its default -event loop. For better performance, ``libev`` is also supported through -a C extension. +Documentation +------------- +The documentation can be found online `here `_. -If you're on Linux, you should be able to install libev -through a package manager. For example, on Debian/Ubuntu:: +A couple of links for getting up to speed: - $ sudo apt-get install libev4 libev-dev +* `Installation `_ +* `Getting started guide `_ +* `API docs `_ +* `Performance tips `_ -On RHEL/CentOS/Fedora:: +Object Mapper +------------- +cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the +community) is now maintained as an integral part of this package. Refer to +`documentation here `_. - $ sudo yum install libev libev-devel - -If you're on Mac OS X, you should be able to install libev -through `Homebrew `_. For example, on Mac OS X:: - - $ brew install libev - -If successful, you should be able to build and install the extension -(just using ``setup.py build`` or ``setup.py install``) and then use -the libev event loop by doing the following: - -.. code-block:: python - - >>> from cassandra.io.libevreactor import LibevConnection - >>> from cassandra.cluster import Cluster - - >>> cluster = Cluster() - >>> cluster.connection_class = LibevConnection - >>> session = cluster.connect() - -Compression Support -^^^^^^^^^^^^^^^^^^^ -Compression can optionally be used for communication between the driver and -Cassandra. There are currently two supported compression algorithms: -snappy (in Cassandra 1.2+) and LZ4 (only in Cassandra 2.0+). If either is -available for the driver and Cassandra also supports it, it will -be used automatically. - -For lz4 support:: - - sudo pip install lz4 +Contributing +------------ +See `CONTRIBUTING.rst `_. -For snappy support:: +Reporting Problems +------------------ +Please report any bugs and make any feature requests on the +`CASSPYTHON project `_ +of the ASF JIRA. - sudo pip install python-snappy +If you would like to contribute, please feel free to open a pull request. -(If using a Debian Linux derivative such as Ubuntu, it may be easier to -just run ``apt-get install python-snappy``.) +Getting Help +------------ +You can talk about the driver, ask questions and get help in the #cassandra-drivers channel on +`ASF Slack `_. License ------- -Copyright 2013, DataStax +Copyright 2013 The Apache Software Foundation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/benchmarks/base.py b/benchmarks/base.py index 832c6b65c2..290ba28788 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -1,93 +1,309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cProfile import Profile import logging import os.path import sys +from threading import Thread import time +from optparse import OptionParser +import uuid + +from greplin import scales + dirname = os.path.dirname(os.path.abspath(__file__)) sys.path.append(dirname) sys.path.append(os.path.join(dirname, '..')) +import cassandra from cassandra.cluster import Cluster from cassandra.io.asyncorereactor import AsyncoreConnection -from cassandra.query import SimpleStatement log = logging.getLogger() -log.setLevel('INFO') handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) log.addHandler(handler) +logging.getLogger('cassandra').setLevel(logging.WARN) + +_log_levels = { + 'CRITICAL': logging.CRITICAL, + 'ERROR': logging.ERROR, + 'WARN': logging.WARNING, + 'WARNING': logging.WARNING, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG, + 'NOTSET': logging.NOTSET, +} + +have_libev = False supported_reactors = [AsyncoreConnection] try: from cassandra.io.libevreactor import LibevConnection + have_libev = True supported_reactors.append(LibevConnection) -except ImportError, exc: - log.warning("Not benchmarking libev reactor: %s" % (exc,)) +except ImportError as exc: + pass -KEYSPACE = "testkeyspace" +have_asyncio = False +try: + from cassandra.io.asyncioreactor import AsyncioConnection + have_asyncio = True + supported_reactors.append(AsyncioConnection) +except (ImportError, SyntaxError): + pass + +have_twisted = False +try: + from cassandra.io.twistedreactor import TwistedConnection + have_twisted = True + supported_reactors.append(TwistedConnection) +except ImportError as exc: + log.exception("Error importing twisted") + pass + +KEYSPACE = "testkeyspace" + str(int(time.time())) TABLE = "testtable" -NUM_QUERIES = 10000 -def setup(): +COLUMN_VALUES = { + 'int': 42, + 'text': "'42'", + 'float': 42.0, + 'uuid': uuid.uuid4(), + 'timestamp': "'2016-02-03 04:05+0000'" +} + + +def setup(options): + log.info("Using 'cassandra' package from %s", cassandra.__path__) + + cluster = Cluster(options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False) + try: + session = cluster.connect() + + log.debug("Creating keyspace...") + try: + session.execute(""" + CREATE KEYSPACE %s + WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + """ % options.keyspace) + + log.debug("Setting keyspace...") + except cassandra.AlreadyExists: + log.debug("Keyspace already exists") + + session.set_keyspace(options.keyspace) + + log.debug("Creating table...") + create_table_query = """ + CREATE TABLE {0} ( + thekey text, + """ + for i in range(options.num_columns): + create_table_query += "col{0} {1},\n".format(i, options.column_type) + create_table_query += "PRIMARY KEY (thekey))" + + try: + session.execute(create_table_query.format(TABLE)) + except cassandra.AlreadyExists: + log.debug("Table already exists.") + + finally: + cluster.shutdown() - cluster = Cluster(['127.0.0.1']) - session = cluster.connect() - rows = session.execute("SELECT keyspace_name FROM system.schema_keyspaces") - if KEYSPACE in [row[0] for row in rows]: - log.debug("dropping existing keyspace...") - session.execute("DROP KEYSPACE " + KEYSPACE) - - log.debug("Creating keyspace...") - session.execute(""" - CREATE KEYSPACE %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } - """ % KEYSPACE) - - log.debug("Setting keyspace...") - session.set_keyspace(KEYSPACE) - - log.debug("Creating table...") - session.execute(""" - CREATE TABLE %s ( - thekey text, - col1 text, - col2 text, - PRIMARY KEY (thekey, col1) - ) - """ % TABLE) - -def teardown(): - cluster = Cluster(['127.0.0.1']) +def teardown(options): + cluster = Cluster(options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False) session = cluster.connect() - session.execute("DROP KEYSPACE " + KEYSPACE) + if not options.keep_data: + session.execute("DROP KEYSPACE " + options.keyspace) + cluster.shutdown() -def benchmark(run_fn): - for conn_class in supported_reactors: - setup() +def benchmark(thread_class): + options, args = parse_options() + for conn_class in options.supported_reactors: + setup(options) log.info("==== %s ====" % (conn_class.__name__,)) - cluster = Cluster(['127.0.0.1']) - cluster.connection_class = conn_class - session = cluster.connect(KEYSPACE) + kwargs = {'metrics_enabled': options.enable_metrics, + 'connection_class': conn_class} + if options.protocol_version: + kwargs['protocol_version'] = options.protocol_version + cluster = Cluster(options.hosts, **kwargs) + session = cluster.connect(options.keyspace) log.debug("Sleeping for two seconds...") time.sleep(2.0) - query = SimpleStatement(""" - INSERT INTO {table} (thekey, col1, col2) - VALUES (%(key)s, %(a)s, %(b)s) - """.format(table=TABLE)) - values = {'key': 'key', 'a': 'a', 'b': 'b'} - log.debug("Beginning inserts...") + # Generate the query + if options.read: + query = "SELECT * FROM {0} WHERE thekey = '{{key}}'".format(TABLE) + else: + query = "INSERT INTO {0} (thekey".format(TABLE) + for i in range(options.num_columns): + query += ", col{0}".format(i) + + query += ") VALUES ('{key}'" + for i in range(options.num_columns): + query += ", {0}".format(COLUMN_VALUES[options.column_type]) + query += ")" + + values = None # we don't use that anymore. Keeping it in case we go back to prepared statements. + per_thread = options.num_ops // options.threads + threads = [] + + log.debug("Beginning {0}...".format('reads' if options.read else 'inserts')) start = time.time() try: - run_fn(session, query, values, NUM_QUERIES) + for i in range(options.threads): + thread = thread_class( + i, session, query, values, per_thread, + cluster.protocol_version, options.profile) + thread.daemon = True + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + while thread.is_alive(): + thread.join(timeout=0.5) + end = time.time() finally: - teardown() + cluster.shutdown() + teardown(options) total = end - start log.info("Total time: %0.2fs" % total) - log.info("Average throughput: %0.2f/sec" % (NUM_QUERIES / total)) + log.info("Average throughput: %0.2f/sec" % (options.num_ops / total)) + if options.enable_metrics: + stats = scales.getStats()['cassandra'] + log.info("Connection errors: %d", stats['connection_errors']) + log.info("Write timeouts: %d", stats['write_timeouts']) + log.info("Read timeouts: %d", stats['read_timeouts']) + log.info("Unavailables: %d", stats['unavailables']) + log.info("Other errors: %d", stats['other_errors']) + log.info("Retries: %d", stats['retries']) + + request_timer = stats['request_timer'] + log.info("Request latencies:") + log.info(" min: %0.4fs", request_timer['min']) + log.info(" max: %0.4fs", request_timer['max']) + log.info(" mean: %0.4fs", request_timer['mean']) + log.info(" stddev: %0.4fs", request_timer['stddev']) + log.info(" median: %0.4fs", request_timer['median']) + log.info(" 75th: %0.4fs", request_timer['75percentile']) + log.info(" 95th: %0.4fs", request_timer['95percentile']) + log.info(" 98th: %0.4fs", request_timer['98percentile']) + log.info(" 99th: %0.4fs", request_timer['99percentile']) + log.info(" 99.9th: %0.4fs", request_timer['999percentile']) + + +def parse_options(): + parser = OptionParser() + parser.add_option('-H', '--hosts', default='127.0.0.1', + help='cassandra hosts to connect to (comma-separated list) [default: %default]') + parser.add_option('-t', '--threads', type='int', default=1, + help='number of threads [default: %default]') + parser.add_option('-n', '--num-ops', type='int', default=10000, + help='number of operations [default: %default]') + parser.add_option('--asyncore-only', action='store_true', dest='asyncore_only', + help='only benchmark with asyncore connections') + parser.add_option('--asyncio-only', action='store_true', dest='asyncio_only', + help='only benchmark with asyncio connections') + parser.add_option('--libev-only', action='store_true', dest='libev_only', + help='only benchmark with libev connections') + parser.add_option('--twisted-only', action='store_true', dest='twisted_only', + help='only benchmark with Twisted connections') + parser.add_option('-m', '--metrics', action='store_true', dest='enable_metrics', + help='enable and print metrics for operations') + parser.add_option('-l', '--log-level', default='info', + help='logging level: debug, info, warning, or error') + parser.add_option('-p', '--profile', action='store_true', dest='profile', + help='Profile the run') + parser.add_option('--protocol-version', type='int', dest='protocol_version', default=4, + help='Native protocol version to use') + parser.add_option('-c', '--num-columns', type='int', dest='num_columns', default=2, + help='Specify the number of columns for the schema') + parser.add_option('-k', '--keyspace', type='str', dest='keyspace', default=KEYSPACE, + help='Specify the keyspace name for the schema') + parser.add_option('--keep-data', action='store_true', dest='keep_data', default=False, + help='Keep the data after the benchmark') + parser.add_option('--column-type', type='str', dest='column_type', default='text', + help='Specify the column type for the schema (supported: int, text, float, uuid, timestamp)') + parser.add_option('--read', action='store_true', dest='read', default=False, + help='Read mode') + + + options, args = parser.parse_args() + + options.hosts = options.hosts.split(',') + + level = options.log_level.upper() + try: + log.setLevel(_log_levels[level]) + except KeyError: + log.warning("Unknown log level specified: %s; specify one of %s", options.log_level, _log_levels.keys()) + + if options.asyncore_only: + options.supported_reactors = [AsyncoreConnection] + elif options.asyncio_only: + options.supported_reactors = [AsyncioConnection] + elif options.libev_only: + if not have_libev: + log.error("libev is not available") + sys.exit(1) + options.supported_reactors = [LibevConnection] + elif options.twisted_only: + if not have_twisted: + log.error("Twisted is not available") + sys.exit(1) + options.supported_reactors = [TwistedConnection] + else: + options.supported_reactors = supported_reactors + if not have_libev: + log.warning("Not benchmarking libev reactor because libev is not available") + + return options, args + + +class BenchmarkThread(Thread): + + def __init__(self, thread_num, session, query, values, num_queries, protocol_version, profile): + Thread.__init__(self) + self.thread_num = thread_num + self.session = session + self.query = query + self.values = values + self.num_queries = num_queries + self.protocol_version = protocol_version + self.profiler = Profile() if profile else None + + def start_profile(self): + if self.profiler: + self.profiler.enable() + + def run_query(self, key, **kwargs): + return self.session.execute_async(self.query.format(key=key), **kwargs) + + def finish_profile(self): + if self.profiler: + self.profiler.disable() + self.profiler.dump_stats('profile-%d' % self.thread_num) diff --git a/benchmarks/callback_full_pipeline.py b/benchmarks/callback_full_pipeline.py new file mode 100644 index 0000000000..5eafa5df8b --- /dev/null +++ b/benchmarks/callback_full_pipeline.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from itertools import count +from threading import Event + +from base import benchmark, BenchmarkThread + +log = logging.getLogger(__name__) + + +sentinel = object() + + +class Runner(BenchmarkThread): + + def __init__(self, *args, **kwargs): + BenchmarkThread.__init__(self, *args, **kwargs) + self.num_started = count() + self.num_finished = count() + self.event = Event() + + def insert_next(self, previous_result=sentinel): + if previous_result is not sentinel: + if isinstance(previous_result, BaseException): + log.error("Error on insert: %r", previous_result) + if next(self.num_finished) >= self.num_queries: + self.event.set() + + i = next(self.num_started) + if i <= self.num_queries: + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key, timeout=None) + future.add_callbacks(self.insert_next, self.insert_next) + + def run(self): + self.start_profile() + + if self.protocol_version >= 3: + concurrency = 1000 + else: + concurrency = 100 + + for _ in range(min(concurrency, self.num_queries)): + self.insert_next() + + self.event.wait() + + self.finish_profile() + + +if __name__ == "__main__": + benchmark(Runner) diff --git a/benchmarks/future_batches.py b/benchmarks/future_batches.py new file mode 100644 index 0000000000..112cc24981 --- /dev/null +++ b/benchmarks/future_batches.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from base import benchmark, BenchmarkThread +import queue + +log = logging.getLogger(__name__) + + +class Runner(BenchmarkThread): + + def run(self): + futures = queue.Queue(maxsize=121) + + self.start_profile() + + for i in range(self.num_queries): + if i > 0 and i % 120 == 0: + # clear the existing queue + while True: + try: + futures.get_nowait().result() + except queue.Empty: + break + + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key) + futures.put_nowait(future) + + while True: + try: + futures.get_nowait().result() + except queue.Empty: + break + + self.finish_profile() + + +if __name__ == "__main__": + benchmark(Runner) diff --git a/benchmarks/future_full_pipeline.py b/benchmarks/future_full_pipeline.py new file mode 100644 index 0000000000..ca95b742d2 --- /dev/null +++ b/benchmarks/future_full_pipeline.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from base import benchmark, BenchmarkThread +import queue + +log = logging.getLogger(__name__) + + +class Runner(BenchmarkThread): + + def run(self): + futures = queue.Queue(maxsize=121) + + self.start_profile() + + for i in range(self.num_queries): + if i >= 120: + old_future = futures.get_nowait() + old_future.result() + + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key) + futures.put_nowait(future) + + while True: + try: + futures.get_nowait().result() + except queue.Empty: + break + + self.finish_profile + + +if __name__ == "__main__": + benchmark(Runner) diff --git a/benchmarks/future_full_throttle.py b/benchmarks/future_full_throttle.py new file mode 100644 index 0000000000..f85eb99b0d --- /dev/null +++ b/benchmarks/future_full_throttle.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from base import benchmark, BenchmarkThread + +log = logging.getLogger(__name__) + +class Runner(BenchmarkThread): + + def run(self): + futures = [] + + self.start_profile() + + for i in range(self.num_queries): + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key) + futures.append(future) + + for future in futures: + future.result() + + self.finish_profile() + + +if __name__ == "__main__": + benchmark(Runner) diff --git a/benchmarks/single_thread_callback_full_pipeline.py b/benchmarks/single_thread_callback_full_pipeline.py deleted file mode 100644 index 248a90e93c..0000000000 --- a/benchmarks/single_thread_callback_full_pipeline.py +++ /dev/null @@ -1,38 +0,0 @@ -from base import benchmark - -import logging -from itertools import count -from threading import Event - -log = logging.getLogger(__name__) - -initial = object() - -def execute(session, query, values, num_queries): - - num_started = count() - num_finished = count() - event = Event() - - def handle_error(exc): - log.error("Error on insert: %r", exc) - - def insert_next(previous_result): - current_num = num_started.next() - - if previous_result is not initial: - num = next(num_finished) - if num >= num_queries: - event.set() - - if current_num <= num_queries: - future = session.execute_async(query, values) - future.add_callbacks(insert_next, handle_error) - - for i in range(120): - insert_next(initial) - - event.wait() - -if __name__ == "__main__": - benchmark(execute) diff --git a/benchmarks/single_thread_future_batches.py b/benchmarks/single_thread_future_batches.py deleted file mode 100644 index 8396916ad0..0000000000 --- a/benchmarks/single_thread_future_batches.py +++ /dev/null @@ -1,32 +0,0 @@ -from base import benchmark - -import logging -import Queue - -log = logging.getLogger(__name__) - -def execute(session, query, values, num_queries): - - futures = Queue.Queue(maxsize=121) - - for i in range(num_queries): - if i > 0 and i % 120 == 0: - # clear the existing queue - while True: - try: - futures.get_nowait().result() - except Queue.Empty: - break - - future = session.execute_async(query, values) - futures.put_nowait(future) - - while True: - try: - futures.get_nowait().result() - except Queue.Empty: - break - - -if __name__ == "__main__": - benchmark(execute) diff --git a/benchmarks/single_thread_future_full_pipeline.py b/benchmarks/single_thread_future_full_pipeline.py deleted file mode 100644 index a4b11ebcf8..0000000000 --- a/benchmarks/single_thread_future_full_pipeline.py +++ /dev/null @@ -1,28 +0,0 @@ -from base import benchmark - -import logging -import Queue - -log = logging.getLogger(__name__) - -def execute(session, query, values, num_queries): - - futures = Queue.Queue(maxsize=121) - - for i in range(num_queries): - if i >= 120: - old_future = futures.get_nowait() - old_future.result() - - future = session.execute_async(query, values) - futures.put_nowait(future) - - while True: - try: - futures.get_nowait().result() - except Queue.Empty: - break - - -if __name__ == "__main__": - benchmark(execute) diff --git a/benchmarks/single_thread_future_full_throttle.py b/benchmarks/single_thread_future_full_throttle.py deleted file mode 100644 index 18484dac6b..0000000000 --- a/benchmarks/single_thread_future_full_throttle.py +++ /dev/null @@ -1,20 +0,0 @@ -from base import benchmark - -import logging - -log = logging.getLogger(__name__) - -def execute(session, query, values, num_queries): - - futures = [] - - for i in range(num_queries): - future = session.execute_async(query, values) - futures.append(future) - - for future in futures: - future.result() - - -if __name__ == "__main__": - benchmark(execute) diff --git a/benchmarks/single_thread_sync.py b/benchmarks/single_thread_sync.py deleted file mode 100644 index 3fd0cf4894..0000000000 --- a/benchmarks/single_thread_sync.py +++ /dev/null @@ -1,8 +0,0 @@ -from base import benchmark - -def execute(session, query, values, num_queries): - for i in xrange(num_queries): - session.execute(query, values) - -if __name__ == "__main__": - benchmark(execute) diff --git a/benchmarks/sync.py b/benchmarks/sync.py new file mode 100644 index 0000000000..090a265579 --- /dev/null +++ b/benchmarks/sync.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from base import benchmark, BenchmarkThread + + +class Runner(BenchmarkThread): + + def run(self): + self.start_profile() + + for _ in range(self.num_queries): + self.session.execute(self.query, self.values) + + self.finish_profile() + + +if __name__ == "__main__": + benchmark(Runner) diff --git a/cassandra/__init__.py b/cassandra/__init__.py index e9c9a46d65..c732708605 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -1,10 +1,34 @@ -__version_info__ = (1, 0, '0-beta2+') -__version__ = '.'.join(map(str, __version_info__)) - +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import importlib.metadata + +class NullHandler(logging.Handler): + + def emit(self, record): + pass + +logging.getLogger('cassandra').addHandler(NullHandler()) + +__version__ = importlib.metadata.version('cassandra-driver') class ConsistencyLevel(object): """ - Spcifies how many replicas must respond for an operation to be considered + Specifies how many replicas must respond for an operation to be considered a success. By default, ``ONE`` is used for all operations. """ @@ -31,7 +55,7 @@ class ConsistencyLevel(object): QUORUM = 4 """ - ``ceil(RF/2)`` replicas must respond to consider the operation a success + ``ceil(RF/2) + 1`` replicas must respond to consider the operation a success """ ALL = 5 @@ -49,6 +73,30 @@ class ConsistencyLevel(object): Requires a quorum of replicas in each datacenter """ + SERIAL = 8 + """ + For conditional inserts/updates that utilize Cassandra's lightweight + transactions, this requires consensus among all replicas for the + modified data. + """ + + LOCAL_SERIAL = 9 + """ + Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus + among replicas in the local datacenter. + """ + + LOCAL_ONE = 10 + """ + Sends a request only to replicas in the local datacenter and waits for + one response. + """ + + @staticmethod + def is_serial(cl): + return cl == ConsistencyLevel.SERIAL or cl == ConsistencyLevel.LOCAL_SERIAL + + ConsistencyLevel.value_to_name = { ConsistencyLevel.ANY: 'ANY', ConsistencyLevel.ONE: 'ONE', @@ -57,7 +105,10 @@ class ConsistencyLevel(object): ConsistencyLevel.QUORUM: 'QUORUM', ConsistencyLevel.ALL: 'ALL', ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM', - ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM' + ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM', + ConsistencyLevel.SERIAL: 'SERIAL', + ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL', + ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE' } ConsistencyLevel.name_to_value = { @@ -68,11 +119,285 @@ class ConsistencyLevel(object): 'QUORUM': ConsistencyLevel.QUORUM, 'ALL': ConsistencyLevel.ALL, 'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM, - 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM + 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM, + 'SERIAL': ConsistencyLevel.SERIAL, + 'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL, + 'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE } -class Unavailable(Exception): +def consistency_value_to_name(value): + return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" + + +class ProtocolVersion(object): + """ + Defines native protocol versions supported by this driver. + """ + V1 = 1 + """ + v1, supported in Cassandra 1.2-->2.2 + """ + + V2 = 2 + """ + v2, supported in Cassandra 2.0-->2.2; + added support for lightweight transactions, batch operations, and automatic query paging. + """ + + V3 = 3 + """ + v3, supported in Cassandra 2.1-->3.x+; + added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), + serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. + """ + + V4 = 4 + """ + v4, supported in Cassandra 2.2-->3.x+; + added a number of new types, server warnings, new failure messages, and custom payloads. Details in the + `project docs `_ + """ + + V5 = 5 + """ + v5, in beta from 3.x+. Finalised in 4.0-beta5 + """ + + V6 = 6 + """ + v6, in beta from 4.0-beta5 + """ + + DSE_V1 = 0x41 + """ + DSE private protocol v1, supported in DSE 5.1+ + """ + + DSE_V2 = 0x42 + """ + DSE private protocol v2, supported in DSE 6.0+ + """ + + SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1) + """ + A tuple of all supported protocol versions + """ + + BETA_VERSIONS = (V6,) + """ + A tuple of all beta protocol versions + """ + + MIN_SUPPORTED = min(SUPPORTED_VERSIONS) + """ + Minimum protocol version supported by this driver. + """ + + MAX_SUPPORTED = max(SUPPORTED_VERSIONS) + """ + Maximum protocol version supported by this driver. + """ + + @classmethod + def get_lower_supported(cls, previous_version): + """ + Return the lower supported protocol version. Beta versions are omitted. + """ + try: + version = next(v for v in sorted(ProtocolVersion.SUPPORTED_VERSIONS, reverse=True) if + v not in ProtocolVersion.BETA_VERSIONS and v < previous_version) + except StopIteration: + version = 0 + + return version + + @classmethod + def uses_int_query_flags(cls, version): + return version >= cls.V5 + + @classmethod + def uses_prepare_flags(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def uses_prepared_metadata(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def uses_error_code_map(cls, version): + return version >= cls.V5 + + @classmethod + def uses_keyspace_flag(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def has_continuous_paging_support(cls, version): + return version >= cls.DSE_V1 + + @classmethod + def has_continuous_paging_next_pages(cls, version): + return version >= cls.DSE_V2 + + @classmethod + def has_checksumming_support(cls, version): + return cls.V5 <= version < cls.DSE_V1 + + +class WriteType(object): + """ + For usage with :class:`.RetryPolicy`, this describes a type + of write operation. + """ + + SIMPLE = 0 + """ + A write to a single partition key. Such writes are guaranteed to be atomic + and isolated. + """ + + BATCH = 1 + """ + A write to multiple partition keys that used the distributed batch log to + ensure atomicity. + """ + + UNLOGGED_BATCH = 2 + """ + A write to multiple partition keys that did not use the distributed batch + log. Atomicity for such writes is not guaranteed. + """ + + COUNTER = 3 + """ + A counter write (for one or multiple partition keys). Such writes should + not be replayed in order to avoid over counting. + """ + + BATCH_LOG = 4 + """ + The initial write to the distributed batch log that Cassandra performs + internally before a BATCH write. + """ + + CAS = 5 + """ + A lightweight-transaction write, such as "DELETE ... IF EXISTS". + """ + + VIEW = 6 + """ + This WriteType is only seen in results for requests that were unable to + complete MV operations. + """ + + CDC = 7 + """ + This WriteType is only seen in results for requests that were unable to + complete CDC operations. + """ + + +WriteType.name_to_value = { + 'SIMPLE': WriteType.SIMPLE, + 'BATCH': WriteType.BATCH, + 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, + 'COUNTER': WriteType.COUNTER, + 'BATCH_LOG': WriteType.BATCH_LOG, + 'CAS': WriteType.CAS, + 'VIEW': WriteType.VIEW, + 'CDC': WriteType.CDC +} + + +WriteType.value_to_name = {v: k for k, v in WriteType.name_to_value.items()} + + +class SchemaChangeType(object): + DROPPED = 'DROPPED' + CREATED = 'CREATED' + UPDATED = 'UPDATED' + + +class SchemaTargetType(object): + KEYSPACE = 'KEYSPACE' + TABLE = 'TABLE' + TYPE = 'TYPE' + FUNCTION = 'FUNCTION' + AGGREGATE = 'AGGREGATE' + + +class SignatureDescriptor(object): + + def __init__(self, name, argument_types): + self.name = name + self.argument_types = argument_types + + @property + def signature(self): + """ + function signature string in the form 'name([type0[,type1[...]]])' + + can be used to uniquely identify overloaded function names within a keyspace + """ + return self.format_signature(self.name, self.argument_types) + + @staticmethod + def format_signature(name, argument_types): + return "%s(%s)" % (name, ','.join(t for t in argument_types)) + + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, self.name, self.argument_types) + + +class UserFunctionDescriptor(SignatureDescriptor): + """ + Describes a User function by name and argument signature + """ + + name = None + """ + name of the function + """ + + argument_types = None + """ + Ordered list of CQL argument type names comprising the type signature + """ + + +class UserAggregateDescriptor(SignatureDescriptor): + """ + Describes a User aggregate function by name and argument signature + """ + + name = None + """ + name of the aggregate + """ + + argument_types = None + """ + Ordered list of CQL argument type names comprising the type signature + """ + + +class DriverException(Exception): + """ + Base for all exceptions explicitly raised by the driver. + """ + pass + + +class RequestExecutionException(DriverException): + """ + Base for request execution exceptions returned from the server. + """ + pass + + +class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without @@ -88,14 +413,17 @@ class Unavailable(Exception): alive_replicas = None """ The number of replicas that were actually alive """ - def __init__(self, message, consistency=None, required_replicas=None, alive_replicas=None): - Exception.__init__(self, message) + def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None): self.consistency = consistency self.required_replicas = required_replicas self.alive_replicas = alive_replicas + Exception.__init__(self, summary_message + ' info=' + + repr({'consistency': consistency_value_to_name(consistency), + 'required_replicas': required_replicas, + 'alive_replicas': alive_replicas})) -class Timeout(Exception): +class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. """ @@ -112,16 +440,31 @@ class Timeout(Exception): the operation """ - def __init__(self, message, consistency=None, required_responses=None, received_responses=None): - Exception.__init__(self, message) + def __init__(self, summary_message, consistency=None, required_responses=None, + received_responses=None, **kwargs): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses + if "write_type" in kwargs: + kwargs["write_type"] = WriteType.value_to_name[kwargs["write_type"]] + + info = {'consistency': consistency_value_to_name(consistency), + 'required_responses': required_responses, + 'received_responses': received_responses} + info.update(kwargs) + + Exception.__init__(self, summary_message + ' info=' + repr(info)) + class ReadTimeout(Timeout): """ A subclass of :exc:`Timeout` for read operations. + + This indicates that the replicas failed to respond to the coordinator + node before the configured timeout. This timeout is configured in + ``cassandra.yaml`` with the ``read_request_timeout_in_ms`` + and ``range_request_timeout_in_ms`` options. """ data_retrieved = None @@ -139,6 +482,11 @@ def __init__(self, message, data_retrieved=None, **kwargs): class WriteTimeout(Timeout): """ A subclass of :exc:`Timeout` for write operations. + + This indicates that the replicas failed to respond to the coordinator + node before the configured timeout. This timeout is configured in + ``cassandra.yaml`` with the ``write_request_timeout_in_ms`` + option. """ write_type = None @@ -147,11 +495,149 @@ class WriteTimeout(Timeout): """ def __init__(self, message, write_type=None, **kwargs): + kwargs["write_type"] = write_type Timeout.__init__(self, message, **kwargs) self.write_type = write_type -class AlreadyExists(Exception): +class CDCWriteFailure(RequestExecutionException): + """ + Hit limit on data in CDC folder, writes are rejected + """ + def __init__(self, message): + Exception.__init__(self, message) + + +class CoordinationFailure(RequestExecutionException): + """ + Replicas sent a failure to the coordinator. + """ + + consistency = None + """ The requested :class:`ConsistencyLevel` """ + + required_responses = None + """ The number of required replica responses """ + + received_responses = None + """ + The number of replicas that responded before the coordinator timed out + the operation + """ + + failures = None + """ + The number of replicas that sent a failure message + """ + + error_code_map = None + """ + A map of inet addresses to error codes representing replicas that sent + a failure message. Only set when `protocol_version` is 5 or higher. + """ + + def __init__(self, summary_message, consistency=None, required_responses=None, + received_responses=None, failures=None, error_code_map=None): + self.consistency = consistency + self.required_responses = required_responses + self.received_responses = received_responses + self.failures = failures + self.error_code_map = error_code_map + + info_dict = { + 'consistency': consistency_value_to_name(consistency), + 'required_responses': required_responses, + 'received_responses': received_responses, + 'failures': failures + } + + if error_code_map is not None: + # make error codes look like "0x002a" + formatted_map = dict((addr, '0x%04x' % err_code) + for (addr, err_code) in error_code_map.items()) + info_dict['error_code_map'] = formatted_map + + Exception.__init__(self, summary_message + ' info=' + repr(info_dict)) + + +class ReadFailure(CoordinationFailure): + """ + A subclass of :exc:`CoordinationFailure` for read operations. + + This indicates that the replicas sent a failure message to the coordinator. + """ + + data_retrieved = None + """ + A boolean indicating whether the requested data was retrieved + by the coordinator from any replicas before it timed out the + operation + """ + + def __init__(self, message, data_retrieved=None, **kwargs): + CoordinationFailure.__init__(self, message, **kwargs) + self.data_retrieved = data_retrieved + + +class WriteFailure(CoordinationFailure): + """ + A subclass of :exc:`CoordinationFailure` for write operations. + + This indicates that the replicas sent a failure message to the coordinator. + """ + + write_type = None + """ + The type of write operation, enum on :class:`~cassandra.policies.WriteType` + """ + + def __init__(self, message, write_type=None, **kwargs): + CoordinationFailure.__init__(self, message, **kwargs) + self.write_type = write_type + + +class FunctionFailure(RequestExecutionException): + """ + User Defined Function failed during execution + """ + + keyspace = None + """ + Keyspace of the function + """ + + function = None + """ + Name of the function + """ + + arg_types = None + """ + List of argument type names of the function + """ + + def __init__(self, summary_message, keyspace, function, arg_types): + self.keyspace = keyspace + self.function = function + self.arg_types = arg_types + Exception.__init__(self, summary_message) + + +class RequestValidationException(DriverException): + """ + Server request validation failed + """ + pass + + +class ConfigurationException(RequestValidationException): + """ + Server indicated request errro due to current configuration + """ + pass + + +class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ @@ -179,7 +665,7 @@ def __init__(self, keyspace=None, table=None): self.table = table -class InvalidRequest(Exception): +class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. @@ -187,15 +673,74 @@ class InvalidRequest(Exception): pass -class Unauthorized(Exception): +class Unauthorized(RequestValidationException): """ - The current user is not authorized to perfom the requested operation. + The current user is not authorized to perform the requested operation. """ pass -class AuthenticationFailed(Exception): +class AuthenticationFailed(DriverException): """ Failed to authenticate. """ pass + + +class OperationTimedOut(DriverException): + """ + The operation took longer than the specified (client-side) timeout + to complete. This is not an error generated by Cassandra, only + the driver. + """ + + errors = None + """ + A dict of errors keyed by the :class:`~.Host` against which they occurred. + """ + + last_host = None + """ + The last :class:`~.Host` this operation was attempted against. + """ + + def __init__(self, errors=None, last_host=None): + self.errors = errors + self.last_host = last_host + message = "errors=%s, last_host=%s" % (self.errors, self.last_host) + Exception.__init__(self, message) + + +class UnsupportedOperation(DriverException): + """ + An attempt was made to use a feature that is not supported by the + selected protocol version. See :attr:`Cluster.protocol_version` + for more details. + """ + pass + + +class UnresolvableContactPoints(DriverException): + """ + The driver was unable to resolve any provided hostnames. + + Note that this is *not* raised when a :class:`.Cluster` is created with no + contact points, only when lookup fails for all hosts + """ + pass + +class DependencyException(Exception): + """ + Specific exception class for handling issues with driver dependencies + """ + + excs = [] + """ + A sequence of child exceptions + """ + + def __init__(self, msg, excs=[]): + complete_msg = msg + if excs: + complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs)) + Exception.__init__(self, complete_msg) diff --git a/cassandra/auth.py b/cassandra/auth.py new file mode 100644 index 0000000000..86759afe4d --- /dev/null +++ b/cassandra/auth.py @@ -0,0 +1,309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging + +try: + import kerberos + _have_kerberos = True +except ImportError: + _have_kerberos = False + +try: + from puresasl.client import SASLClient + _have_puresasl = True +except ImportError: + _have_puresasl = False + +try: + from puresasl.client import SASLClient +except ImportError: + SASLClient = None + +log = logging.getLogger(__name__) + +# Custom payload keys related to DSE Unified Auth +_proxy_execute_key = 'ProxyExecute' + + +class AuthProvider(object): + """ + An abstract class that defines the interface that will be used for + creating :class:`~.Authenticator` instances when opening new + connections to Cassandra. + + .. versionadded:: 2.0.0 + """ + + def new_authenticator(self, host): + """ + Implementations of this class should return a new instance + of :class:`~.Authenticator` or one of its subclasses. + """ + raise NotImplementedError() + + +class Authenticator(object): + """ + An abstract class that handles SASL authentication with Cassandra servers. + + Each time a new connection is created and the server requires authentication, + a new instance of this class will be created by the corresponding + :class:`~.AuthProvider` to handler that authentication. The lifecycle of the + new :class:`~.Authenticator` will the be: + + 1) The :meth:`~.initial_response()` method will be called. The return + value will be sent to the server to initiate the handshake. + + 2) The server will respond to each client response by either issuing a + challenge or indicating that the authentication is complete (successful or not). + If a new challenge is issued, :meth:`~.evaluate_challenge()` + will be called to produce a response that will be sent to the + server. This challenge/response negotiation will continue until the server + responds that authentication is successful (or an :exc:`~.AuthenticationFailed` + is raised). + + 3) When the server indicates that authentication is successful, + :meth:`~.on_authentication_success` will be called a token string that + the server may optionally have sent. + + The exact nature of the negotiation between the client and server is specific + to the authentication mechanism configured server-side. + + .. versionadded:: 2.0.0 + """ + + server_authenticator_class = None + """ Set during the connection AUTHENTICATE phase """ + + def initial_response(self): + """ + Returns a message to send to the server to initiate the SASL handshake. + :const:`None` may be returned to send an empty message. + """ + return None + + def evaluate_challenge(self, challenge): + """ + Called when the server sends a challenge message. Generally, this method + should return :const:`None` when authentication is complete from a + client perspective. Otherwise, a string should be returned. + """ + raise NotImplementedError() + + def on_authentication_success(self, token): + """ + Called when the server indicates that authentication was successful. + Depending on the authentication mechanism, `token` may be :const:`None` + or a string. + """ + pass + + +class PlainTextAuthProvider(AuthProvider): + """ + An :class:`~.AuthProvider` that works with Cassandra's PasswordAuthenticator. + + Example usage:: + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + auth_provider = PlainTextAuthProvider( + username='cassandra', password='cassandra') + cluster = Cluster(auth_provider=auth_provider) + + .. versionadded:: 2.0.0 + """ + + def __init__(self, username, password): + self.username = username + self.password = password + + def new_authenticator(self, host): + return PlainTextAuthenticator(self.username, self.password) + + +class TransitionalModePlainTextAuthProvider(object): + """ + An :class:`~.AuthProvider` that works with DSE TransitionalModePlainTextAuthenticator. + + Example usage:: + + from cassandra.cluster import Cluster + from cassandra.auth import TransitionalModePlainTextAuthProvider + + auth_provider = TransitionalModePlainTextAuthProvider() + cluster = Cluster(auth_provider=auth_provider) + + .. warning:: TransitionalModePlainTextAuthProvider will be removed in cassandra-driver + 4.0. The transitional mode will be handled internally without the need + of any auth provider. + """ + + def __init__(self): + # TODO remove next major + log.warning("TransitionalModePlainTextAuthProvider will be removed in cassandra-driver " + "4.0. The transitional mode will be handled internally without the need " + "of any auth provider.") + + def new_authenticator(self, host): + return TransitionalModePlainTextAuthenticator() + + +class SaslAuthProvider(AuthProvider): + """ + An :class:`~.AuthProvider` supporting general SASL auth mechanisms + + Suitable for GSSAPI or other SASL mechanisms + + Example usage:: + + from cassandra.cluster import Cluster + from cassandra.auth import SaslAuthProvider + + sasl_kwargs = {'service': 'something', + 'mechanism': 'GSSAPI', + 'qops': 'auth'.split(',')} + auth_provider = SaslAuthProvider(**sasl_kwargs) + cluster = Cluster(auth_provider=auth_provider) + + .. versionadded:: 2.1.4 + """ + + def __init__(self, **sasl_kwargs): + if SASLClient is None: + raise ImportError('The puresasl library has not been installed') + if 'host' in sasl_kwargs: + raise ValueError("kwargs should not contain 'host' since it is passed dynamically to new_authenticator") + self.sasl_kwargs = sasl_kwargs + + def new_authenticator(self, host): + return SaslAuthenticator(host, **self.sasl_kwargs) + + +class SaslAuthenticator(Authenticator): + """ + A pass-through :class:`~.Authenticator` using the third party package + 'pure-sasl' for authentication + + .. versionadded:: 2.1.4 + """ + + def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs): + if SASLClient is None: + raise ImportError('The puresasl library has not been installed') + self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) + + def initial_response(self): + return self.sasl.process() + + def evaluate_challenge(self, challenge): + return self.sasl.process(challenge) + +# TODO remove me next major +DSEPlainTextAuthProvider = PlainTextAuthProvider + + +class DSEGSSAPIAuthProvider(AuthProvider): + """ + Auth provider for GSS API authentication. Works with legacy `KerberosAuthenticator` + or `DseAuthenticator` if `kerberos` scheme is enabled. + """ + def __init__(self, service='dse', qops=('auth',), resolve_host_name=True, **properties): + """ + :param service: name of the service + :param qops: iterable of "Quality of Protection" allowed; see ``puresasl.QOP`` + :param resolve_host_name: boolean flag indicating whether the authenticator should reverse-lookup an FQDN when + creating a new authenticator. Default is ``True``, which will resolve, or return the numeric address if there is no PTR + record. Setting ``False`` creates the authenticator with the numeric address known by Cassandra + :param properties: additional keyword properties to pass for the ``puresasl.mechanisms.GSSAPIMechanism`` class. + Presently, 'principal' (user) is the only one referenced in the ``pure-sasl`` implementation + """ + if not _have_puresasl: + raise ImportError('The puresasl library has not been installed') + if not _have_kerberos: + raise ImportError('The kerberos library has not been installed') + self.service = service + self.qops = qops + self.resolve_host_name = resolve_host_name + self.properties = properties + + def new_authenticator(self, host): + if self.resolve_host_name: + host = socket.getnameinfo((host, 0), 0)[0] + return GSSAPIAuthenticator(host, self.service, self.qops, self.properties) + + +class BaseDSEAuthenticator(Authenticator): + def get_mechanism(self): + raise NotImplementedError("get_mechanism not implemented") + + def get_initial_challenge(self): + raise NotImplementedError("get_initial_challenge not implemented") + + def initial_response(self): + if self.server_authenticator_class == "com.datastax.bdp.cassandra.auth.DseAuthenticator": + return self.get_mechanism() + else: + return self.evaluate_challenge(self.get_initial_challenge()) + + +class PlainTextAuthenticator(BaseDSEAuthenticator): + + def __init__(self, username, password): + self.username = username + self.password = password + + def get_mechanism(self): + return b"PLAIN" + + def get_initial_challenge(self): + return b"PLAIN-START" + + def evaluate_challenge(self, challenge): + if challenge == b'PLAIN-START': + data = "\x00%s\x00%s" % (self.username, self.password) + return data.encode() + raise Exception('Did not receive a valid challenge response from server') + + +class TransitionalModePlainTextAuthenticator(PlainTextAuthenticator): + """ + Authenticator that accounts for DSE authentication is configured with transitional mode. + """ + + def __init__(self): + super(TransitionalModePlainTextAuthenticator, self).__init__('', '') + + +class GSSAPIAuthenticator(BaseDSEAuthenticator): + def __init__(self, host, service, qops, properties): + properties = properties or {} + self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties) + + def get_mechanism(self): + return b"GSSAPI" + + def get_initial_challenge(self): + return b"GSSAPI-START" + + def evaluate_challenge(self, challenge): + if challenge == b'GSSAPI-START': + return self.sasl.process() + else: + return self.sasl.process(challenge) diff --git a/cassandra/buffer.pxd b/cassandra/buffer.pxd new file mode 100644 index 0000000000..3383fcd272 --- /dev/null +++ b/cassandra/buffer.pxd @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple buffer data structure that provides a view on existing memory +(e.g. from a bytes object). This memory must stay alive while the +buffer is in use. +""" + +from cpython.bytes cimport PyBytes_AS_STRING + # char* PyBytes_AS_STRING(object string) + # Macro form of PyBytes_AsString() but without error + # checking. Only string objects are supported; no Unicode objects + # should be passed. + + +cdef struct Buffer: + char *ptr + Py_ssize_t size + + +cdef inline bytes to_bytes(Buffer *buf): + return buf.ptr[:buf.size] + +cdef inline char *buf_ptr(Buffer *buf): + return buf.ptr + +cdef inline char *buf_read(Buffer *buf, Py_ssize_t size) except NULL: + if size > buf.size: + raise IndexError("Requested more than length of buffer") + return buf.ptr + +cdef inline int slice_buffer(Buffer *buf, Buffer *out, + Py_ssize_t start, Py_ssize_t size) except -1: + if size < 0: + raise ValueError("Length must be positive") + + if start + size > buf.size: + raise IndexError("Buffer slice out of bounds") + + out.ptr = buf.ptr + start + out.size = size + return 0 + +cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out): + out.ptr = ptr + out.size = size diff --git a/cassandra/bytesio.pxd b/cassandra/bytesio.pxd new file mode 100644 index 0000000000..24320f0ae1 --- /dev/null +++ b/cassandra/bytesio.pxd @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef class BytesIOReader: + cdef bytes buf + cdef char *buf_ptr + cdef Py_ssize_t pos + cdef Py_ssize_t size + cdef char *read(self, Py_ssize_t n = ?) except NULL diff --git a/cassandra/bytesio.pyx b/cassandra/bytesio.pyx new file mode 100644 index 0000000000..d9781035ef --- /dev/null +++ b/cassandra/bytesio.pyx @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef class BytesIOReader: + """ + This class provides efficient support for reading bytes from a 'bytes' buffer, + by returning char * values directly without allocating intermediate objects. + """ + + def __init__(self, bytes buf): + self.buf = buf + self.size = len(buf) + self.buf_ptr = self.buf + + cdef char *read(self, Py_ssize_t n = -1) except NULL: + """Read at most size bytes from the file + (less if the read hits EOF before obtaining size bytes). + + If the size argument is negative or omitted, read all data until EOF + is reached. The bytes are returned as a string object. An empty + string is returned when EOF is encountered immediately. + """ + cdef Py_ssize_t newpos = self.pos + n + if n < 0: + newpos = self.size + elif newpos > self.size: + # Raise an error here, as we do not want the caller to consume past the + # end of the buffer + raise EOFError("Cannot read past the end of the file") + + cdef char *res = self.buf_ptr + self.pos + self.pos = newpos + return res diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 10fd1c4bbb..6b2ab4b288 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1,52 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ This module houses the main classes you will interact with, :class:`.Cluster` and :class:`.Session`. """ -from concurrent.futures import ThreadPoolExecutor +import atexit +from binascii import hexlify +from collections import defaultdict +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures +from copy import copy +from functools import partial, reduce, wraps +from itertools import groupby, count, chain +import json import logging +from warnings import warn +from random import random +import re +import queue +import socket +import sys import time from threading import Lock, RLock, Thread, Event -import Queue +import uuid + import weakref +from weakref import WeakValueDictionary + +from cassandra import (ConsistencyLevel, AuthenticationFailed, + OperationTimedOut, UnsupportedOperation, + SchemaTargetType, DriverException, ProtocolVersion, + UnresolvableContactPoints, DependencyException) +from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider +from cassandra.connection import (ConnectionException, ConnectionShutdown, + ConnectionHeartbeat, ProtocolVersionUnsupported, + EndPoint, DefaultEndPoint, DefaultEndPointFactory, + ContinuousPagingState, SniEndPointFactory, ConnectionBusy) +from cassandra.cqltypes import UserType +from cassandra.encoder import Encoder +from cassandra.protocol import (QueryMessage, ResultMessage, + ErrorMessage, ReadTimeoutErrorMessage, + WriteTimeoutErrorMessage, + UnavailableErrorMessage, + OverloadedErrorMessage, + PrepareMessage, ExecuteMessage, + PreparedQueryNotFound, + IsBootstrappingErrorMessage, + TruncateError, ServerError, + BatchMessage, RESULT_KIND_PREPARED, + RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, + RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler, + RESULT_KIND_VOID, ProtocolException) +from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo +from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, + ExponentialReconnectionPolicy, HostDistance, + RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, + NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy, + NeverRetryPolicy) +from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, + HostConnectionPool, HostConnection, + NoConnectionsAvailable) +from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, + BatchStatement, bind_params, QueryTrace, TraceUnavailable, + named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET, + HostTargetingStatement) +from cassandra.marshal import int64_pack +from cassandra.timestamps import MonotonicTimestampGenerator +from cassandra.util import _resolve_contact_points_to_string_map, Version + +from cassandra.datastax.insights.reporter import MonitorReporter +from cassandra.datastax.insights.util import version_supports_insights + +from cassandra.datastax.graph import (graph_object_row_factory, GraphOptions, GraphSON1Serializer, + GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement, + graph_graphson2_row_factory, graph_graphson3_row_factory, + GraphSON3Serializer) +from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory +from cassandra.datastax import cloud as dscloud + try: - from weakref import WeakSet + from cassandra.io.twistedreactor import TwistedConnection except ImportError: - from cassandra.util import WeakSet # NOQA + TwistedConnection = None -from functools import partial -from itertools import groupby - -from cassandra import ConsistencyLevel, AuthenticationFailed -from cassandra.connection import ConnectionException, ConnectionShutdown -from cassandra.decoder import (QueryMessage, ResultMessage, - ErrorMessage, ReadTimeoutErrorMessage, - WriteTimeoutErrorMessage, - UnavailableErrorMessage, - OverloadedErrorMessage, - PrepareMessage, ExecuteMessage, - PreparedQueryNotFound, - IsBootstrappingErrorMessage, named_tuple_factory, - dict_factory) -from cassandra.metadata import Metadata -from cassandra.metrics import Metrics -from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy, - ExponentialReconnectionPolicy, HostDistance, - RetryPolicy) -from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, - bind_params, QueryTrace, Statement) -from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler, - HostConnectionPool) +try: + from cassandra.io.eventletreactor import EventletConnection +# PYTHON-1364 +# +# At the moment eventlet initialization is chucking AttributeErrors due to its dependence on pyOpenSSL +# and some changes in Python 3.12 which have some knock-on effects there. +except (ImportError, AttributeError): + EventletConnection = None -# libev is all around faster, so we want to try and default to using that when we can try: - from cassandra.io.libevreactor import LibevConnection as DefaultConnection + from weakref import WeakSet except ImportError: - from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA + from cassandra.util import WeakSet # NOQA +def _is_gevent_monkey_patched(): + if 'gevent.monkey' not in sys.modules: + return False + import gevent.socket + return socket.socket is gevent.socket.socket + +def _try_gevent_import(): + if _is_gevent_monkey_patched(): + from cassandra.io.geventreactor import GeventConnection + return (GeventConnection,None) + else: + return (None,None) + +def _is_eventlet_monkey_patched(): + if 'eventlet.patcher' not in sys.modules: + return False + try: + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + # Another case related to PYTHON-1364 + except AttributeError: + return False + +def _try_eventlet_import(): + if _is_eventlet_monkey_patched(): + from cassandra.io.eventletreactor import EventletConnection + return (EventletConnection,None) + else: + return (None,None) + +def _try_libev_import(): + try: + from cassandra.io.libevreactor import LibevConnection + return (LibevConnection,None) + except DependencyException as e: + return (None, e) + +def _try_asyncore_import(): + try: + from cassandra.io.asyncorereactor import AsyncoreConnection + return (AsyncoreConnection,None) + except DependencyException as e: + return (None, e) + +def _connection_reduce_fn(val,import_fn): + (rv, excs) = val + # If we've already found a workable Connection class return immediately + if rv: + return val + (import_result, exc) = import_fn() + if exc: + excs.append(exc) + return (rv or import_result, excs) log = logging.getLogger(__name__) +conn_fns = (_try_gevent_import, _try_eventlet_import, _try_libev_import, _try_asyncore_import) +(conn_class, excs) = reduce(_connection_reduce_fn, conn_fns, (None,[])) +if not conn_class: + raise DependencyException("Unable to load a default connection class", excs) +DefaultConnection = conn_class + +# Forces load of utf8 encoding module to avoid deadlock that occurs +# if code that is being imported tries to import the module in a separate +# thread. +# See http://bugs.python.org/issue10923 +"".encode('utf8') + DEFAULT_MIN_REQUESTS = 5 DEFAULT_MAX_REQUESTS = 100 @@ -57,6 +190,10 @@ DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1 DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2 +_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0') + +_NOT_SET = object() + class NoHostAvailable(Exception): """ @@ -76,457 +213,2142 @@ def __init__(self, message, errors): self.errors = errors -class Cluster(object): +def _future_completed(future): + """ Helper for run_in_executor() """ + exc = future.exception() + if exc: + log.debug("Failed to run task on executor", exc_info=exc) + + +def run_in_executor(f): + """ + A decorator to run the given method in the ThreadPoolExecutor. """ - The main class to use when interacting with a Cassandra cluster. - Typically, one instance of this class will be created for each - separate Cassandra cluster that your application interacts with. - Example usage:: + @wraps(f) + def new_f(self, *args, **kwargs): - >>> from cassandra.cluster import Cluster - >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) - >>> session = cluster.connect() - >>> session.execute("CREATE KEYSPACE ...") - >>> ... - >>> cluster.shutdown() + if self.is_shutdown: + return + try: + future = self.executor.submit(f, self, *args, **kwargs) + future.add_done_callback(_future_completed) + except Exception: + log.exception("Failed to submit task to executor") - """ + return new_f - port = 9042 + +_clusters_for_shutdown = set() + + +def _register_cluster_shutdown(cluster): + _clusters_for_shutdown.add(cluster) + + +def _discard_cluster_shutdown(cluster): + _clusters_for_shutdown.discard(cluster) + + +def _shutdown_clusters(): + clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" + for cluster in clusters: + cluster.shutdown() + + +atexit.register(_shutdown_clusters) + + +def default_lbp_factory(): + if murmur3 is not None: + return TokenAwarePolicy(DCAwareRoundRobinPolicy()) + return DCAwareRoundRobinPolicy() + + +class ContinuousPagingOptions(object): + + class PagingUnit(object): + BYTES = 1 + ROWS = 2 + + page_unit = None """ - The server-side port to open connections to. Defaults to 9042. + Value of PagingUnit. Default is PagingUnit.ROWS. + + Units refer to the :attr:`~.Statement.fetch_size` or :attr:`~.Session.default_fetch_size`. """ - compression = True + max_pages = None """ - Whether or not compression should be enabled when possible. Defaults to - :const:`True` and attempts to use snappy compression. + Max number of pages to send """ - auth_provider = None + max_pages_per_second = None """ - An optional function that accepts one argument, the IP address of a node, - and returns a dict of credentials for that node. + Max rate at which to send pages """ - load_balancing_policy = RoundRobinPolicy() + max_queue_size = None """ - An instance of :class:`.policies.LoadBalancingPolicy` or - one of its subclasses. Defaults to :class:`~.RoundRobinPolicy`. + The maximum queue size for caching pages, only honored for protocol version DSE_V2 and higher, + by default it is 4 and it must be at least 2. """ - reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) + def __init__(self, page_unit=PagingUnit.ROWS, max_pages=0, max_pages_per_second=0, max_queue_size=4): + self.page_unit = page_unit + self.max_pages = max_pages + self.max_pages_per_second = max_pages_per_second + if max_queue_size < 2: + raise ValueError('ContinuousPagingOptions.max_queue_size must be 2 or greater') + self.max_queue_size = max_queue_size + + def page_unit_bytes(self): + return self.page_unit == ContinuousPagingOptions.PagingUnit.BYTES + + +def _addrinfo_or_none(contact_point, port): """ - An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance - of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and - a max delay of ten minutes. + A helper function that wraps socket.getaddrinfo and returns None + when it fails to, e.g. resolve one of the hostnames. Used to address + PYTHON-895. """ + try: + return socket.getaddrinfo(contact_point, port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + log.debug('Could not resolve hostname "{}" ' + 'with port {}'.format(contact_point, port)) + return None + + +def _execution_profile_to_string(name): + default_profiles = { + EXEC_PROFILE_DEFAULT: 'EXEC_PROFILE_DEFAULT', + EXEC_PROFILE_GRAPH_DEFAULT: 'EXEC_PROFILE_GRAPH_DEFAULT', + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT: 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT', + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT', + } + + if name in default_profiles: + return default_profiles[name] - default_retry_policy = RetryPolicy() + return '"%s"' % (name,) + + +class ExecutionProfile(object): + load_balancing_policy = None """ - A default :class:`.policies.RetryPolicy` instance to use for all - :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` - explicitly set. + An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. + + Used in determining host distance for establishing connections, and routing requests. + + Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified """ - conviction_policy_factory = SimpleConvictionPolicy + retry_policy = None """ - A factory function which creates instances of - :class:`.policies.ConvictionPolicy`. Defaults to - :class:`.policies.SimpleConvictionPolicy`. + An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a + :attr:`~.Statement.retry_policy` explicitly set. + + Defaults to :class:`.RetryPolicy` if not specified """ - metrics_enabled = False + consistency_level = ConsistencyLevel.LOCAL_ONE """ - Whether or not metric collection is enabled. + :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`. """ - metrics = None + serial_consistency_level = None """ - An instance of :class:`.metrics.Metrics` if :attr:`.metrics_enabled` is - :const:`True`, else :const:`None`. + Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements). """ - sockopts = None + request_timeout = 10.0 """ - An optional list of tuples which will be used as arguments to - ``socket.setsockopt()`` for all created sockets. + Request timeout used when not overridden in :meth:`.Session.execute` """ - max_schema_agreement_wait = 10 + row_factory = staticmethod(named_tuple_factory) """ - The maximum duration (in seconds) that the driver will wait for schema - agreement across the cluster. Defaults to ten seconds. + A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and + ``rows`` is a list of tuples, with each tuple representing a row of parsed values. + + Some example implementations: + + - :func:`cassandra.query.tuple_factory` - return a result row as a tuple + - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple + - :func:`cassandra.query.dict_factory` - return a result row as a dict + - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ - metadata = None + speculative_execution_policy = None """ - An instance of :class:`cassandra.metadata.Metadata`. + An instance of :class:`.policies.SpeculativeExecutionPolicy` + + Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified """ - connection_class = DefaultConnection + continuous_paging_options = None """ - This determines what event loop system will be used for managing - I/O with Cassandra. These are the current options: + *Note:* This feature is implemented to facilitate server integration testing. It is not intended for general use in the Python driver. + See :attr:`.Statement.fetch_size` or :attr:`Session.default_fetch_size` for configuring normal paging. - * :class:`cassandra.io.asyncorereactor.AsyncoreConnection` - * :class:`cassandra.io.libevreactor.LibevConnection` + When set, requests will use DSE's continuous paging, which streams multiple pages without + intermediate requests. - By default, ``AsyncoreConnection`` will be used, which uses - the ``asyncore`` module in the Python standard library. The - performance is slightly worse than with ``libev``, but it is - supported on a wider range of systems. + This has the potential to materialize all results in memory at once if the consumer cannot keep up. Use options + to constrain page size and rate. - If ``libev`` is installed, ``LibevConnection`` will be used instead. + This is only available for DSE clusters. """ - sessions = None - control_connection = None - scheduler = None - executor = None - _is_shutdown = False - _is_setup = False - _prepared_statements = None - - def __init__(self, - contact_points=("127.0.0.1",), - port=9042, - compression=True, - auth_provider=None, - load_balancing_policy=None, - reconnection_policy=None, - default_retry_policy=None, - conviction_policy_factory=None, - metrics_enabled=False, - connection_class=None, - sockopts=None, - executor_threads=2, - max_schema_agreement_wait=10): - """ - Any of the mutable Cluster attributes may be set as keyword arguments - to the constructor. - """ - - self.contact_points = contact_points - self.port = port - self.compression = compression + # indicates if lbp was set explicitly or uses default values + _load_balancing_policy_explicit = False + _consistency_level_explicit = False - if auth_provider is not None: - if not callable(auth_provider): - raise ValueError("auth_provider must be callable") - self.auth_provider = auth_provider + def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None, + continuous_paging_options=None): - if load_balancing_policy is not None: + if load_balancing_policy is _NOT_SET: + self._load_balancing_policy_explicit = False + self.load_balancing_policy = default_lbp_factory() + else: + self._load_balancing_policy_explicit = True self.load_balancing_policy = load_balancing_policy - if reconnection_policy is not None: - self.reconnection_policy = reconnection_policy + if consistency_level is _NOT_SET: + self._consistency_level_explicit = False + self.consistency_level = ConsistencyLevel.LOCAL_ONE + else: + self._consistency_level_explicit = True + self.consistency_level = consistency_level - if default_retry_policy is not None: - self.default_retry_policy = default_retry_policy + self.retry_policy = retry_policy or RetryPolicy() - if conviction_policy_factory is not None: - if not callable(conviction_policy_factory): - raise ValueError("conviction_policy_factory must be callable") - self.conviction_policy_factory = conviction_policy_factory + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): + raise ValueError("serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") + self.serial_consistency_level = serial_consistency_level - if connection_class is not None: - self.connection_class = connection_class + self.request_timeout = request_timeout + self.row_factory = row_factory + self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() + self.continuous_paging_options = continuous_paging_options - self.metrics_enabled = metrics_enabled - self.sockopts = sockopts - self.max_schema_agreement_wait = max_schema_agreement_wait - # let Session objects be GC'ed (and shutdown) when the user no longer - # holds a reference. Normally the cycle detector would handle this, - # but implementing __del__ prevents that. - self.sessions = WeakSet() - self.metadata = Metadata(self) - self.control_connection = None - self._prepared_statements = {} +class GraphExecutionProfile(ExecutionProfile): + graph_options = None + """ + :class:`.GraphOptions` to use with this execution - self._min_requests_per_connection = { - HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, - HostDistance.REMOTE: DEFAULT_MIN_REQUESTS - } + Default options for graph queries, initialized as follows by default:: - self._max_requests_per_connection = { - HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, - HostDistance.REMOTE: DEFAULT_MAX_REQUESTS - } + GraphOptions(graph_language=b'gremlin-groovy') - self._core_connections_per_host = { - HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, - HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST - } + See cassandra.graph.GraphOptions + """ - self._max_connections_per_host = { - HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, - HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST - } + def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=30.0, row_factory=None, + graph_options=None, continuous_paging_options=_NOT_SET): + """ + Default execution profile for graph execution. - self.executor = ThreadPoolExecutor(max_workers=executor_threads) - self.scheduler = _Scheduler(self.executor) + See :class:`.ExecutionProfile` for base attributes. Note that if not explicitly set, + the row_factory and graph_options.graph_protocol are resolved during the query execution. + These options will resolve to graph_graphson3_row_factory and GraphProtocol.GRAPHSON_3_0 + for the core graph engine (DSE 6.8+), otherwise graph_object_row_factory and GraphProtocol.GRAPHSON_1_0 - self._lock = RLock() + In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to + :class:`cassandra.policies.NeverRetryPolicy`. + """ + retry_policy = retry_policy or NeverRetryPolicy() + super(GraphExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, + serial_consistency_level, request_timeout, row_factory, + continuous_paging_options=continuous_paging_options) + self.graph_options = graph_options or GraphOptions(graph_source=b'g', + graph_language=b'gremlin-groovy') - if self.metrics_enabled: - self.metrics = Metrics(weakref.proxy(self)) - self.control_connection = ControlConnection(self) - for address in contact_points: - self.add_host(address, signal=True) +class GraphAnalyticsExecutionProfile(GraphExecutionProfile): - def get_min_requests_per_connection(self, host_distance): - return self._min_requests_per_connection[host_distance] + def __init__(self, load_balancing_policy=None, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=3600. * 24. * 7., row_factory=None, + graph_options=None): + """ + Execution profile with timeout and load balancing appropriate for graph analytics queries. - def set_min_requests_per_connection(self, host_distance, min_requests): - self._min_requests_per_connection[host_distance] = min_requests + See also :class:`~.GraphExecutionPolicy`. - def get_max_requests_per_connection(self, host_distance): - return self._max_requests_per_connection[host_distance] + In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to + :class:`cassandra.policies.NeverRetryPolicy`, and ``load_balancing_policy`` to one that targets the current Spark + master. - def set_max_requests_per_connection(self, host_distance, max_requests): - self._max_requests_per_connection[host_distance] = max_requests + Note: The graph_options.graph_source is set automatically to b'a' (analytics) + when using GraphAnalyticsExecutionProfile. This is mandatory to target analytics nodes. + """ + load_balancing_policy = load_balancing_policy or DefaultLoadBalancingPolicy(default_lbp_factory()) + graph_options = graph_options or GraphOptions(graph_language=b'gremlin-groovy') + super(GraphAnalyticsExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, + serial_consistency_level, request_timeout, row_factory, + graph_options) + # ensure the graph_source is analytics, since this is the purpose of the GraphAnalyticsExecutionProfile + self.graph_options.set_source_analytics() + + +class ProfileManager(object): + + def __init__(self): + self.profiles = dict() + + def _profiles_without_explicit_lbps(self): + names = (profile_name for + profile_name, profile in self.profiles.items() + if not profile._load_balancing_policy_explicit) + return tuple( + 'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n + for n in names + ) + + def distance(self, host): + distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) + return HostDistance.LOCAL if HostDistance.LOCAL in distances else \ + HostDistance.REMOTE if HostDistance.REMOTE in distances else \ + HostDistance.IGNORED + + def populate(self, cluster, hosts): + for p in self.profiles.values(): + p.load_balancing_policy.populate(cluster, hosts) + + def check_supported(self): + for p in self.profiles.values(): + p.load_balancing_policy.check_supported() - def get_core_connections_per_host(self, host_distance): - return self._core_connections_per_host[host_distance] + def on_up(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_up(host) - def set_core_connections_per_host(self, host_distance, core_connections): - old = self._core_connections_per_host[host_distance] - self._core_connections_per_host[host_distance] = core_connections - if old < core_connections: - self.ensure_core_connections() + def on_down(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_down(host) - def get_max_connections_per_host(self, host_distance): - return self._max_connections_per_host[host_distance] + def on_add(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_add(host) - def set_max_connections_per_host(self, host_distance, max_connections): - self._max_connections_per_host[host_distance] = max_connections + def on_remove(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_remove(host) - def connection_factory(self, address, *args, **kwargs): + @property + def default(self): """ - Called to create a new connection with proper configuration. - Intended for internal use only. + internal-only; no checks are done because this entry is populated on cluster init """ - if self.auth_provider: - kwargs['credentials'] = self.auth_provider(address) + return self.profiles[EXEC_PROFILE_DEFAULT] - kwargs['port'] = self.port - kwargs['compression'] = self.compression - kwargs['sockopts'] = self.sockopts - return self.connection_class.factory(address, *args, **kwargs) +EXEC_PROFILE_DEFAULT = object() +""" +Key for the ``Cluster`` default execution profile, used when no other profile is selected in +``Session.execute(execution_profile)``. - def _make_connection_factory(self, host, *args, **kwargs): - if self.auth_provider: - kwargs['credentials'] = self.auth_provider(host) +Use this as the key in ``Cluster(execution_profiles)`` to override the default profile. +""" - kwargs['port'] = self.port - kwargs['compression'] = self.compression - kwargs['sockopts'] = self.sockopts +EXEC_PROFILE_GRAPH_DEFAULT = object() +""" +Key for the default graph execution profile, used when no other profile is selected in +``Session.execute_graph(execution_profile)``. - return partial(self.connection_class.factory, host.address, *args, **kwargs) +Use this as the key in :doc:`Cluster(execution_profiles) ` +to override the default graph profile. +""" - def connect(self, keyspace=None): - """ - Creates and returns a new :class:`~.Session` object. If `keyspace` - is specified, that keyspace will be the default keyspace for - operations on the ``Session``. - """ - with self._lock: - if self._is_shutdown: - raise Exception("Cluster is already shut down") +EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT = object() +""" +Key for the default graph system execution profile. This can be used for graph statements using the DSE graph +system API. - if not self._is_setup: - self.load_balancing_policy.populate( - weakref.proxy(self), self.metadata.all_hosts()) - self._is_setup = True +Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)``. +""" - if self.control_connection: - try: - self.control_connection.connect() - log.debug("Control connection created") - except Exception: - log.exception("Control connection failed to connect, " - "shutting down Cluster:") - self.shutdown() - raise +EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT = object() +""" +Key for the default graph analytics execution profile. This can be used for graph statements intended to +use Spark/analytics as the traversal source. - session = self._new_session() - if keyspace: - session.set_keyspace(keyspace) - return session +Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT)``. +""" - def shutdown(self): - """ - Closes all sessions and connection associated with this Cluster. - Once shutdown, a Cluster should not be used for any purpose. - """ - with self._lock: - if self._is_shutdown: - raise Exception("The Cluster was already shutdown") - else: - self._is_shutdown = True - if self.scheduler: - self.scheduler.shutdown() +class _ConfigMode(object): + UNCOMMITTED = 0 + LEGACY = 1 + PROFILES = 2 - if self.control_connection: - self.control_connection.shutdown() - if self.sessions: - for session in self.sessions: - session.shutdown() +class Cluster(object): + """ + The main class to use when interacting with a Cassandra cluster. + Typically, one instance of this class will be created for each + separate Cassandra cluster that your application interacts with. - if self.executor: - self.executor.shutdown() + Example usage:: - def __del__(self): - # we don't use shutdown() because we want to avoid shutting down - # Sessions while they are still being used (in case there are no - # longer any references to this Cluster object, but there are - # still references to the Session object) - if not self._is_shutdown: - if self.scheduler: - self.scheduler.shutdown() - if self.control_connection: - self.control_connection.shutdown() - if self.executor: - self.executor.shutdown(wait=False) - - def _new_session(self): - session = Session(self, self.metadata.all_hosts()) - self.sessions.add(session) - return session + >>> from cassandra.cluster import Cluster + >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) + >>> session = cluster.connect() + >>> session.execute("CREATE KEYSPACE ...") + >>> ... + >>> cluster.shutdown() - def on_up(self, host): - """ - Called when a host is marked up by its :class:`~.HealthMonitor`. - Intended for internal use only. - """ - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() + ``Cluster`` and ``Session`` also provide context management functions + which implicitly handle shutdown when leaving scope. + """ - self._prepare_all_queries(host) + contact_points = ['127.0.0.1'] + """ + The list of contact points to try connecting for cluster discovery. A + contact point can be a string (ip or hostname), a tuple (ip/hostname, port) or a + :class:`.connection.EndPoint` instance. - self.control_connection.on_up(host) - for session in self.sessions: - session.on_up(host) + Defaults to loopback interface. - def on_down(self, host): - """ - Called when a host is marked down by its :class:`~.HealthMonitor`. - Intended for internal use only. - """ - self.control_connection.on_down(host) - for session in self.sessions: + Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit + local_dc set (as is the default), the DC is chosen from an arbitrary + host in contact_points. In this case, contact_points should contain + only nodes from a single, local DC. + + Note: In the next major version, if you specify contact points, you will + also be required to also explicitly specify a load-balancing policy. This + change will help prevent cases where users had hard-to-debug issues + surrounding unintuitive default load-balancing policy behavior. + """ + # tracks if contact_points was set explicitly or with default values + _contact_points_explicit = None + + port = 9042 + """ + The server-side port to open connections to. Defaults to 9042. + """ + + cql_version = None + """ + If a specific version of CQL should be used, this may be set to that + string version. Otherwise, the highest CQL version supported by the + server will be automatically used. + """ + + protocol_version = ProtocolVersion.DSE_V2 + """ + The maximum version of the native protocol to use. + + See :class:`.ProtocolVersion` for more information about versions. + + If not set in the constructor, the driver will automatically downgrade + version based on a negotiation with the server, but it is most efficient + to set this to the maximum supported by your version of Cassandra. + Setting this will also prevent conflicting versions negotiated if your + cluster is upgraded. + + """ + + allow_beta_protocol_version = False + + no_compact = False + + """ + Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version. + Used for testing new protocol features incrementally before the new version is complete. + """ + + compression = True + """ + Controls compression for communications between the driver and Cassandra. + If left as the default of :const:`True`, either lz4 or snappy compression + may be used, depending on what is supported by both the driver + and Cassandra. If both are fully supported, lz4 will be preferred. + + You may also set this to 'snappy' or 'lz4' to request that specific + compression type. + + Setting this to :const:`False` disables compression. + """ + + _auth_provider = None + _auth_provider_callable = None + + @property + def auth_provider(self): + """ + When :attr:`~.Cluster.protocol_version` is 2 or higher, this should + be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, + such as :class:`~.PlainTextAuthProvider`. + + When :attr:`~.Cluster.protocol_version` is 1, this should be + a function that accepts one argument, the IP address of a node, + and returns a dict of credentials for that node. + + When not using authentication, this should be left as :const:`None`. + """ + return self._auth_provider + + @auth_provider.setter # noqa + def auth_provider(self, value): + if not value: + self._auth_provider = value + return + + try: + self._auth_provider_callable = value.new_authenticator + except AttributeError: + if self.protocol_version > 1: + raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " + "interface when protocol_version >= 2") + elif not callable(value): + raise TypeError("auth_provider must be callable when protocol_version == 1") + self._auth_provider_callable = value + + self._auth_provider = value + + _load_balancing_policy = None + @property + def load_balancing_policy(self): + """ + An instance of :class:`.policies.LoadBalancingPolicy` or + one of its subclasses. + + .. versionchanged:: 2.6.0 + + Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`). + when using CPython (where the murmur3 extension is available). :class:`~.DCAwareRoundRobinPolicy` + otherwise. Default local DC will be chosen from contact points. + + **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to + DC locality and remote nodes.** + """ + return self._load_balancing_policy + + @load_balancing_policy.setter + def load_balancing_policy(self, lbp): + if self._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.") + self._load_balancing_policy = lbp + self._config_mode = _ConfigMode.LEGACY + + @property + def _default_load_balancing_policy(self): + return self.profile_manager.default.load_balancing_policy + + reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) + """ + An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance + of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and + a max delay of ten minutes. + """ + + _default_retry_policy = RetryPolicy() + @property + def default_retry_policy(self): + """ + A default :class:`.policies.RetryPolicy` instance to use for all + :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` + explicitly set. + """ + return self._default_retry_policy + + @default_retry_policy.setter + def default_retry_policy(self, policy): + if self._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.") + self._default_retry_policy = policy + self._config_mode = _ConfigMode.LEGACY + + conviction_policy_factory = SimpleConvictionPolicy + """ + A factory function which creates instances of + :class:`.policies.ConvictionPolicy`. Defaults to + :class:`.policies.SimpleConvictionPolicy`. + """ + + address_translator = IdentityTranslator() + """ + :class:`.policies.AddressTranslator` instance to be used in translating server node addresses + to driver connection addresses. + """ + + connect_to_remote_hosts = True + """ + If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` + by the :attr:`~.Cluster.load_balancing_policy` will have a connection + opened to them. Otherwise, they will not have a connection opened to them. + + Note that the default load balancing policy ignores remote hosts by default. + + .. versionadded:: 2.1.0 + """ + + metrics_enabled = False + """ + Whether or not metric collection is enabled. If enabled, :attr:`.metrics` + will be an instance of :class:`~cassandra.metrics.Metrics`. + """ + + metrics = None + """ + An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is + :const:`True`, else :const:`None`. + """ + + ssl_options = None + """ + Using ssl_options without ssl_context is deprecated and will be removed in the + next major release. + + An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` + when new sockets are created. This should be used when client encryption is enabled + in Cassandra. + + The following documentation only applies when ssl_options is used without ssl_context. + + By default, a ``ca_certs`` value should be supplied (the value should be + a string pointing to the location of the CA certs file), and you probably + want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match + Cassandra's default protocol. + + .. versionchanged:: 3.3.0 + + In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname + as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so + should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into + Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` + with a custom or `back-ported function `_. + + .. versionchanged:: 3.29.0 + + ``ssl.match_hostname`` has been deprecated since Python 3.7 (and removed in Python 3.12). This functionality is now implemented + via ``ssl.SSLContext.check_hostname``. All options specified above (including ``check_hostname``) should continue to behave in a + way that is consistent with prior implementations. + """ + + ssl_context = None + """ + An optional ``ssl.SSLContext`` instance which will be used when new sockets are created. + This should be used when client encryption is enabled in Cassandra. + + ``wrap_socket`` options can be set using :attr:`~Cluster.ssl_options`. ssl_options will + be used as kwargs for ``ssl.SSLContext.wrap_socket``. + + .. versionadded:: 3.17.0 + """ + + sockopts = None + """ + An optional list of tuples which will be used as arguments to + ``socket.setsockopt()`` for all created sockets. + + Note: some drivers find setting TCPNODELAY beneficial in the context of + their execution model. It was not found generally beneficial for this driver. + To try with your own workload, set ``sockopts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` + """ + + max_schema_agreement_wait = 10 + """ + The maximum duration (in seconds) that the driver will wait for schema + agreement across the cluster. Defaults to ten seconds. + If set <= 0, the driver will bypass schema agreement waits altogether. + """ + + metadata = None + """ + An instance of :class:`cassandra.metadata.Metadata`. + """ + + connection_class = DefaultConnection + """ + This determines what event loop system will be used for managing + I/O with Cassandra. These are the current options: + + * :class:`cassandra.io.asyncorereactor.AsyncoreConnection` + * :class:`cassandra.io.libevreactor.LibevConnection` + * :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details) + * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details) + * :class:`cassandra.io.twistedreactor.TwistedConnection` + * EXPERIMENTAL: :class:`cassandra.io.asyncioreactor.AsyncioConnection` + + By default, ``AsyncoreConnection`` will be used, which uses + the ``asyncore`` module in the Python standard library. + + If ``libev`` is installed, ``LibevConnection`` will be used instead. + + If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding + connection class will be used automatically. + + ``AsyncioConnection``, which uses the ``asyncio`` module in the Python + standard library, is also available, but currently experimental. Note that + it requires ``asyncio`` features that were only introduced in the 3.4 line + in 3.4.6, and in the 3.5 line in 3.5.1. + """ + + control_connection_timeout = 2.0 + """ + A timeout, in seconds, for queries made by the control connection, such + as querying the current schema and information about nodes in the cluster. + If set to :const:`None`, there will be no timeout for these queries. + """ + + idle_heartbeat_interval = 30 + """ + Interval, in seconds, on which to heartbeat idle connections. This helps + keep connections open through network devices that expire idle connections. + It also helps discover bad connections early in low-traffic scenarios. + Setting to zero disables heartbeats. + """ + + idle_heartbeat_timeout = 30 + """ + Timeout, in seconds, on which the heartbeat wait for idle connection responses. + Lowering this value can help to discover bad connections earlier. + """ + + schema_event_refresh_window = 2 + """ + Window, in seconds, within which a schema component will be refreshed after + receiving a schema_change event. + + The driver delays a random amount of time in the range [0.0, window) + before executing the refresh. This serves two purposes: + + 1.) Spread the refresh for deployments with large fanout from C* to client tier, + preventing a 'thundering herd' problem with many clients refreshing simultaneously. + + 2.) Remove redundant refreshes. Redundant events arriving within the delay period + are discarded, and only one refresh is executed. + + Setting this to zero will execute refreshes immediately. + + Setting this negative will disable schema refreshes in response to push events + (refreshes will still occur in response to schema change responses to DDL statements + executed by Sessions of this Cluster). + """ + + topology_event_refresh_window = 10 + """ + Window, in seconds, within which the node and token list will be refreshed after + receiving a topology_change event. + + Setting this to zero will execute refreshes immediately. + + Setting this negative will disable node refreshes in response to push events. + + See :attr:`.schema_event_refresh_window` for discussion of rationale + """ + + status_event_refresh_window = 2 + """ + Window, in seconds, within which the driver will start the reconnect after + receiving a status_change event. + + Setting this to zero will connect immediately. + + This is primarily used to avoid 'thundering herd' in deployments with large fanout from cluster to clients. + When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and + establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor. + """ + + prepare_on_all_hosts = True + """ + Specifies whether statements should be prepared on all hosts, or just one. + + This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup, + where a randomized initial condition of the load balancing policy can be expected to distribute prepares from + different clients across the cluster. + """ + + reprepare_on_up = True + """ + Specifies whether all known prepared statements should be prepared on a node when it comes up. + + May be used to avoid overwhelming a node on return, or if it is supposed that the node was only marked down due to + network. If statements are not reprepared, they are prepared on the first execution, causing + an extra roundtrip for one or more client requests. + """ + + connect_timeout = 5 + """ + Timeout, in seconds, for creating new connections. + + This timeout covers the entire connection negotiation, including TCP + establishment, options passing, and authentication. + """ + + timestamp_generator = None + """ + An object, shared between all sessions created by this cluster instance, + that generates timestamps when client-side timestamp generation is enabled. + By default, each :class:`Cluster` uses a new + :class:`~.MonotonicTimestampGenerator`. + + Applications can set this value for custom timestamp behavior. See the + documentation for :meth:`Session.timestamp_generator`. + """ + + monitor_reporting_enabled = True + """ + A boolean indicating if monitor reporting, which sends gathered data to + Insights when running against DSE 6.8 and higher. + """ + + monitor_reporting_interval = 30 + """ + A boolean indicating if monitor reporting, which sends gathered data to + Insights when running against DSE 6.8 and higher. + """ + + client_id = None + """ + A UUID that uniquely identifies this Cluster object to Insights. This will + be generated automatically unless the user provides one. + """ + + application_name = '' + """ + A string identifying this application to Insights. + """ + + application_version = '' + """ + A string identifying this application's version to Insights + """ + + cloud = None + """ + A dict of the cloud configuration. Example:: + + { + # path to the secure connect bundle + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + + # optional config options + 'use_default_tempdir': True # use the system temp dir for the zip extraction + } + + The zip file will be temporarily extracted in the same directory to + load the configuration and certificates. + """ + + column_encryption_policy = None + """ + An instance of :class:`cassandra.policies.ColumnEncryptionPolicy` specifying encryption materials to be + used for columns in this cluster. + """ + + @property + def schema_metadata_enabled(self): + """ + Flag indicating whether internal schema metadata is updated. + + When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This + can be used to speed initial connection, and reduce load on client and server during operation. Turning this off + gives away token aware request routing, and programmatic inspection of the metadata model. + """ + return self.control_connection._schema_meta_enabled + + @schema_metadata_enabled.setter + def schema_metadata_enabled(self, enabled): + self.control_connection._schema_meta_enabled = bool(enabled) + + @property + def token_metadata_enabled(self): + """ + Flag indicating whether internal token metadata is updated. + + When disabled, the driver does not query node token information on connect, or on topology change events. This + can be used to speed initial connection, and reduce load on client and server during operation. It is most useful + in large clusters using vnodes, where the token map can be expensive to compute. Turning this off + gives away token aware request routing, and programmatic inspection of the token ring. + """ + return self.control_connection._token_meta_enabled + + @token_metadata_enabled.setter + def token_metadata_enabled(self, enabled): + self.control_connection._token_meta_enabled = bool(enabled) + + endpoint_factory = None + """ + An :class:`~.connection.EndPointFactory` instance to use internally when creating + a socket connection to a node. You can ignore this unless you need a special + connection mechanism. + """ + + profile_manager = None + _config_mode = _ConfigMode.UNCOMMITTED + + sessions = None + control_connection = None + scheduler = None + executor = None + is_shutdown = False + _is_setup = False + _prepared_statements = None + _prepared_statement_lock = None + _idle_heartbeat = None + _protocol_version_explicit = False + _discount_down_events = True + + _user_types = None + """ + A map of {keyspace: {type_name: UserType}} + """ + + _listeners = None + _listener_lock = None + + def __init__(self, + contact_points=_NOT_SET, + port=9042, + compression=True, + auth_provider=None, + load_balancing_policy=None, + reconnection_policy=None, + default_retry_policy=None, + conviction_policy_factory=None, + metrics_enabled=False, + connection_class=None, + ssl_options=None, + sockopts=None, + cql_version=None, + protocol_version=_NOT_SET, + executor_threads=2, + max_schema_agreement_wait=10, + control_connection_timeout=2.0, + idle_heartbeat_interval=30, + schema_event_refresh_window=2, + topology_event_refresh_window=10, + connect_timeout=5, + schema_metadata_enabled=True, + token_metadata_enabled=True, + address_translator=None, + status_event_refresh_window=2, + prepare_on_all_hosts=True, + reprepare_on_up=True, + execution_profiles=None, + allow_beta_protocol_version=False, + timestamp_generator=None, + idle_heartbeat_timeout=30, + no_compact=False, + ssl_context=None, + endpoint_factory=None, + application_name=None, + application_version=None, + monitor_reporting_enabled=True, + monitor_reporting_interval=30, + client_id=None, + cloud=None, + column_encryption_policy=None): + """ + ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as + establishing connection pools or refreshing metadata. + + Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. + """ + if connection_class is not None: + self.connection_class = connection_class + + if cloud is not None: + self.cloud = cloud + if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options: + raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options " + "cannot be specified with a cloud configuration") + + uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection) + uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection) + cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet) + + ssl_context = cloud_config.ssl_context + ssl_options = {'check_hostname': True} + if (auth_provider is None and cloud_config.username + and cloud_config.password): + auth_provider = PlainTextAuthProvider(cloud_config.username, cloud_config.password) + + endpoint_factory = SniEndPointFactory(cloud_config.sni_host, cloud_config.sni_port) + contact_points = [ + endpoint_factory.create_from_sni(host_id) + for host_id in cloud_config.host_ids + ] + + if contact_points is not None: + if contact_points is _NOT_SET: + self._contact_points_explicit = False + contact_points = ['127.0.0.1'] + else: + self._contact_points_explicit = True + + if isinstance(contact_points, str): + raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings") + + if None in contact_points: + raise ValueError("contact_points should not contain None (it can resolve to localhost)") + self.contact_points = contact_points + + self.port = port + + if column_encryption_policy is not None: + self.column_encryption_policy = column_encryption_policy + + self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) + self.endpoint_factory.configure(self) + + raw_contact_points = [] + for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]: + raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, port)) + + self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] + self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port) + for ep in self.endpoints_resolved} + + strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points) + self.endpoints_resolved.extend(list(chain( + *[ + [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] + for xs in strs_resolved_map.values() if xs is not None + ] + ))) + + self._endpoint_map_for_insights.update( + {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value] + for key, value in strs_resolved_map.items() if value is not None} + ) + + if contact_points and (not self.endpoints_resolved): + # only want to raise here if the user specified CPs but resolution failed + raise UnresolvableContactPoints(self._endpoint_map_for_insights) + + self.compression = compression + + if protocol_version is not _NOT_SET: + self.protocol_version = protocol_version + self._protocol_version_explicit = True + self.allow_beta_protocol_version = allow_beta_protocol_version + + self.no_compact = no_compact + + self.auth_provider = auth_provider + + if load_balancing_policy is not None: + if isinstance(load_balancing_policy, type): + raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") + self.load_balancing_policy = load_balancing_policy + else: + self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode + + if reconnection_policy is not None: + if isinstance(reconnection_policy, type): + raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") + self.reconnection_policy = reconnection_policy + + if default_retry_policy is not None: + if isinstance(default_retry_policy, type): + raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") + self.default_retry_policy = default_retry_policy + + if conviction_policy_factory is not None: + if not callable(conviction_policy_factory): + raise ValueError("conviction_policy_factory must be callable") + self.conviction_policy_factory = conviction_policy_factory + + if address_translator is not None: + if isinstance(address_translator, type): + raise TypeError("address_translator should not be a class, it should be an instance of that class") + self.address_translator = address_translator + + if timestamp_generator is not None: + if not callable(timestamp_generator): + raise ValueError("timestamp_generator must be callable") + self.timestamp_generator = timestamp_generator + else: + self.timestamp_generator = MonotonicTimestampGenerator() + + self.profile_manager = ProfileManager() + self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( + self.load_balancing_policy, + self.default_retry_policy, + request_timeout=Session._default_timeout, + row_factory=Session._row_factory + ) + + # legacy mode if either of these is not default + if load_balancing_policy or default_retry_policy: + if execution_profiles: + raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters " + "load_balancing_policy or default_retry_policy. Configure this in a profile instead.") + + self._config_mode = _ConfigMode.LEGACY + warn("Legacy execution parameters will be removed in 4.0. Consider using " + "execution profiles.", DeprecationWarning) + + else: + profiles = self.profile_manager.profiles + if execution_profiles: + profiles.update(execution_profiles) + self._config_mode = _ConfigMode.PROFILES + + lbp = DefaultLoadBalancingPolicy(self.profile_manager.default.load_balancing_policy) + profiles.setdefault(EXEC_PROFILE_GRAPH_DEFAULT, GraphExecutionProfile(load_balancing_policy=lbp)) + profiles.setdefault(EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + GraphExecutionProfile(load_balancing_policy=lbp, request_timeout=60. * 3.)) + profiles.setdefault(EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + GraphAnalyticsExecutionProfile(load_balancing_policy=lbp)) + + if self._contact_points_explicit and not self.cloud: # avoid this warning for cloud users. + if self._config_mode is _ConfigMode.PROFILES: + default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps() + if default_lbp_profiles: + log.warning( + 'Cluster.__init__ called with contact_points ' + 'specified, but load-balancing policies are not ' + 'specified in some ExecutionProfiles. In the next ' + 'major version, this will raise an error; please ' + 'specify a load-balancing policy. ' + '(contact_points = {cp}, ' + 'EPs without explicit LBPs = {eps})' + ''.format(cp=contact_points, eps=default_lbp_profiles)) + else: + if load_balancing_policy is None: + log.warning( + 'Cluster.__init__ called with contact_points ' + 'specified, but no load_balancing_policy. In the next ' + 'major version, this will raise an error; please ' + 'specify a load-balancing policy. ' + '(contact_points = {cp}, lbp = {lbp})' + ''.format(cp=contact_points, lbp=load_balancing_policy)) + + self.metrics_enabled = metrics_enabled + + if ssl_options and not ssl_context: + warn('Using ssl_options without ssl_context is ' + 'deprecated and will result in an error in ' + 'the next major release. Please use ssl_context ' + 'to prepare for that release.', + DeprecationWarning) + + self.ssl_options = ssl_options + self.ssl_context = ssl_context + self.sockopts = sockopts + self.cql_version = cql_version + self.max_schema_agreement_wait = max_schema_agreement_wait + self.control_connection_timeout = control_connection_timeout + self.idle_heartbeat_interval = idle_heartbeat_interval + self.idle_heartbeat_timeout = idle_heartbeat_timeout + self.schema_event_refresh_window = schema_event_refresh_window + self.topology_event_refresh_window = topology_event_refresh_window + self.status_event_refresh_window = status_event_refresh_window + self.connect_timeout = connect_timeout + self.prepare_on_all_hosts = prepare_on_all_hosts + self.reprepare_on_up = reprepare_on_up + self.monitor_reporting_enabled = monitor_reporting_enabled + self.monitor_reporting_interval = monitor_reporting_interval + + self._listeners = set() + self._listener_lock = Lock() + + # let Session objects be GC'ed (and shutdown) when the user no longer + # holds a reference. + self.sessions = WeakSet() + self.metadata = Metadata() + self.control_connection = None + self._prepared_statements = WeakValueDictionary() + self._prepared_statement_lock = Lock() + + self._user_types = defaultdict(dict) + + self._min_requests_per_connection = { + HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, + HostDistance.REMOTE: DEFAULT_MIN_REQUESTS + } + + self._max_requests_per_connection = { + HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, + HostDistance.REMOTE: DEFAULT_MAX_REQUESTS + } + + self._core_connections_per_host = { + HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, + HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST + } + + self._max_connections_per_host = { + HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, + HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST + } + + self.executor = self._create_thread_pool_executor(max_workers=executor_threads) + self.scheduler = _Scheduler(self.executor) + + self._lock = RLock() + + if self.metrics_enabled: + from cassandra.metrics import Metrics + self.metrics = Metrics(weakref.proxy(self)) + + self.control_connection = ControlConnection( + self, self.control_connection_timeout, + self.schema_event_refresh_window, self.topology_event_refresh_window, + self.status_event_refresh_window, + schema_metadata_enabled, token_metadata_enabled) + + if client_id is None: + self.client_id = uuid.uuid4() + if application_name is not None: + self.application_name = application_name + if application_version is not None: + self.application_version = application_version + + def _create_thread_pool_executor(self, **kwargs): + """ + Create a ThreadPoolExecutor for the cluster. In most cases, the built-in + `concurrent.futures.ThreadPoolExecutor` is used. + + Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` + to hang indefinitely. In that case, the user needs to have the `futurist` + package so we can use the `futurist.GreenThreadPoolExecutor` class instead. + + :param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor. + :return: A ThreadPoolExecutor instance. + """ + tpe_class = ThreadPoolExecutor + if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: + try: + from cassandra.io.eventletreactor import EventletConnection + is_eventlet = issubclass(self.connection_class, EventletConnection) + except: + # Eventlet is not available or can't be detected + return tpe_class(**kwargs) + + if is_eventlet: + try: + from futurist import GreenThreadPoolExecutor + tpe_class = GreenThreadPoolExecutor + except ImportError: + # futurist is not available + raise ImportError( + ("Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " + "to hang indefinitely. If you want to use the Eventlet reactor, you " + "need to install the `futurist` package to allow the driver to use " + "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " + "for more details.")) + + return tpe_class(**kwargs) + + def register_user_type(self, keyspace, user_type, klass): + """ + Registers a class to use to represent a particular user-defined type. + Query parameters for this user-defined type will be assumed to be + instances of `klass`. Result sets for this user-defined type will + be instances of `klass`. If no class is registered for a user-defined + type, a namedtuple will be used for result sets, and non-prepared + statements may not encode parameters for this type correctly. + + `keyspace` is the name of the keyspace that the UDT is defined in. + + `user_type` is the string name of the UDT to register the mapping + for. + + `klass` should be a class with attributes whose names match the + fields of the user-defined type. The constructor must accept kwargs + for each of the fields in the UDT. + + This method should only be called after the type has been created + within Cassandra. + + Example:: + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)") + + # create a class to map to the "address" UDT + class Address(object): + + def __init__(self, street, zipcode): + self.street = street + self.zipcode = zipcode + + cluster.register_user_type('mykeyspace', 'address', Address) + + # insert a row using an instance of Address + session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", + (0, Address("123 Main St.", 78723))) + + # results will include Address instances + results = session.execute("SELECT * FROM users") + row = results[0] + print(row.id, row.location.street, row.location.zipcode) + + """ + if self.protocol_version < 3: + log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " + "CQL encoding for simple statements will still work, but named tuples will " + "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) + + self._user_types[keyspace][user_type] = klass + for session in tuple(self.sessions): + session.user_type_registered(keyspace, user_type, klass) + UserType.evict_udt_class(keyspace, user_type) + + def add_execution_profile(self, name, profile, pool_wait_timeout=5): + """ + Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute` + and :meth:`.Session.execute_async`. This method will raise if the profile already exists. + + Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method + provides a way of adding them dynamically. + + Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``. By default, + this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately + upon return. This behavior can be controlled using ``pool_wait_timeout`` (see + `concurrent.futures.wait `_ + for timeout semantics). + """ + if not isinstance(profile, ExecutionProfile): + raise TypeError("profile must be an instance of ExecutionProfile") + if self._config_mode == _ConfigMode.LEGACY: + raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") + if name in self.profile_manager.profiles: + raise ValueError("Profile {} already exists".format(name)) + contact_points_but_no_lbp = ( + self._contact_points_explicit and not + profile._load_balancing_policy_explicit) + if contact_points_but_no_lbp: + log.warning( + 'Tried to add an ExecutionProfile with name {name}. ' + '{self} was explicitly configured with contact_points, but ' + '{ep} was not explicitly configured with a ' + 'load_balancing_policy. In the next major version, trying to ' + 'add an ExecutionProfile without an explicitly configured LBP ' + 'to a cluster with explicitly configured contact_points will ' + 'raise an exception; please specify a load-balancing policy ' + 'in the ExecutionProfile.' + ''.format(name=_execution_profile_to_string(name), self=self, ep=profile)) + + self.profile_manager.profiles[name] = profile + profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) + # on_up after populate allows things like DCA LBP to choose default local dc + for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): + profile.load_balancing_policy.on_up(host) + futures = set() + for session in tuple(self.sessions): + self._set_default_dbaas_consistency(session) + futures.update(session.update_created_pools()) + _, not_done = wait_futures(futures, pool_wait_timeout) + if not_done: + raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout.") + + def get_min_requests_per_connection(self, host_distance): + return self._min_requests_per_connection[host_distance] + + def set_min_requests_per_connection(self, host_distance, min_requests): + """ + Sets a threshold for concurrent requests per connection, below which + connections will be considered for disposal (down to core connections; + see :meth:`~Cluster.set_core_connections_per_host`). + + Pertains to connection pool management in protocol versions {1,2}. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_min_requests_per_connection() only has an effect " + "when using protocol_version 1 or 2.") + if min_requests < 0 or min_requests > 126 or \ + min_requests >= self._max_requests_per_connection[host_distance]: + raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" % + (self._min_requests_per_connection[host_distance],)) + self._min_requests_per_connection[host_distance] = min_requests + + def get_max_requests_per_connection(self, host_distance): + return self._max_requests_per_connection[host_distance] + + def set_max_requests_per_connection(self, host_distance, max_requests): + """ + Sets a threshold for concurrent requests per connection, above which new + connections will be created to a host (up to max connections; + see :meth:`~Cluster.set_max_connections_per_host`). + + Pertains to connection pool management in protocol versions {1,2}. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_max_requests_per_connection() only has an effect " + "when using protocol_version 1 or 2.") + if max_requests < 1 or max_requests > 127 or \ + max_requests <= self._min_requests_per_connection[host_distance]: + raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" % + (self._min_requests_per_connection[host_distance],)) + self._max_requests_per_connection[host_distance] = max_requests + + def get_core_connections_per_host(self, host_distance): + """ + Gets the minimum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for + :attr:`~HostDistance.REMOTE`. + + This property is ignored if :attr:`~.Cluster.protocol_version` is + 3 or higher. + """ + return self._core_connections_per_host[host_distance] + + def set_core_connections_per_host(self, host_distance, core_connections): + """ + Sets the minimum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for + :attr:`~HostDistance.REMOTE`. + + Protocol version 1 and 2 are limited in the number of concurrent + requests they can send per connection. The driver implements connection + pooling to support higher levels of concurrency. + + If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this + is not supported (there is always one connection per host, unless + the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) + and using this will result in an :exc:`~.UnsupportedOperation`. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_core_connections_per_host() only has an effect " + "when using protocol_version 1 or 2.") + old = self._core_connections_per_host[host_distance] + self._core_connections_per_host[host_distance] = core_connections + if old < core_connections: + self._ensure_core_connections() + + def get_max_connections_per_host(self, host_distance): + """ + Gets the maximum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for + :attr:`~HostDistance.REMOTE`. + + This property is ignored if :attr:`~.Cluster.protocol_version` is + 3 or higher. + """ + return self._max_connections_per_host[host_distance] + + def set_max_connections_per_host(self, host_distance, max_connections): + """ + Sets the maximum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for + :attr:`~HostDistance.REMOTE`. + + If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this + is not supported (there is always one connection per host, unless + the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) + and using this will result in an :exc:`~.UnsupportedOperation`. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_max_connections_per_host() only has an effect " + "when using protocol_version 1 or 2.") + self._max_connections_per_host[host_distance] = max_connections + + def connection_factory(self, endpoint, *args, **kwargs): + """ + Called to create a new connection with proper configuration. + Intended for internal use only. + """ + kwargs = self._make_connection_kwargs(endpoint, kwargs) + return self.connection_class.factory(endpoint, self.connect_timeout, *args, **kwargs) + + def _make_connection_factory(self, host, *args, **kwargs): + kwargs = self._make_connection_kwargs(host.endpoint, kwargs) + return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) + + def _make_connection_kwargs(self, endpoint, kwargs_dict): + if self._auth_provider_callable: + kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) + + kwargs_dict.setdefault('port', self.port) + kwargs_dict.setdefault('compression', self.compression) + kwargs_dict.setdefault('sockopts', self.sockopts) + kwargs_dict.setdefault('ssl_options', self.ssl_options) + kwargs_dict.setdefault('ssl_context', self.ssl_context) + kwargs_dict.setdefault('cql_version', self.cql_version) + kwargs_dict.setdefault('protocol_version', self.protocol_version) + kwargs_dict.setdefault('user_type_map', self._user_types) + kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) + kwargs_dict.setdefault('no_compact', self.no_compact) + + return kwargs_dict + + def protocol_downgrade(self, host_endpoint, previous_version): + if self._protocol_version_explicit: + raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) + new_version = ProtocolVersion.get_lower_supported(previous_version) + if new_version < ProtocolVersion.MIN_SUPPORTED: + raise DriverException( + "Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,)) + + log.warning("Downgrading core protocol version from %d to %d for %s. " + "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " + "https://docs.datastax.com/en/developer/python-driver/latest/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) + self.protocol_version = new_version + + def connect(self, keyspace=None, wait_for_all_pools=False): + """ + Creates and returns a new :class:`~.Session` object. + + If `keyspace` is specified, that keyspace will be the default keyspace for + operations on the ``Session``. + + `wait_for_all_pools` specifies whether this call should wait for all connection pools to be + established or attempted. Default is `False`, which means it will return when the first + successful connection is established. Remaining pools are added asynchronously. + """ + with self._lock: + if self.is_shutdown: + raise DriverException("Cluster is already shut down") + + if not self._is_setup: + log.debug("Connecting to cluster, contact points: %s; protocol version: %s", + self.contact_points, self.protocol_version) + self.connection_class.initialize_reactor() + _register_cluster_shutdown(self) + for endpoint in self.endpoints_resolved: + host, new = self.add_host(endpoint, signal=False) + if new: + host.set_up() + for listener in self.listeners: + listener.on_add(host) + + self.profile_manager.populate( + weakref.proxy(self), self.metadata.all_hosts()) + self.load_balancing_policy.populate( + weakref.proxy(self), self.metadata.all_hosts() + ) + + try: + self.control_connection.connect() + + # we set all contact points up for connecting, but we won't infer state after this + for endpoint in self.endpoints_resolved: + h = self.metadata.get_host(endpoint) + if h and self.profile_manager.distance(h) == HostDistance.IGNORED: + h.is_up = None + + log.debug("Control connection created") + except Exception: + log.exception("Control connection failed to connect, " + "shutting down Cluster:") + self.shutdown() + raise + + self.profile_manager.check_supported() # todo: rename this method + + if self.idle_heartbeat_interval: + self._idle_heartbeat = ConnectionHeartbeat( + self.idle_heartbeat_interval, + self.get_connection_holders, + timeout=self.idle_heartbeat_timeout + ) + self._is_setup = True + + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + + return session + + def _set_default_dbaas_consistency(self, session): + if session.cluster.metadata.dbaas: + for profile in self.profile_manager.profiles.values(): + if not profile._consistency_level_explicit: + profile.consistency_level = ConsistencyLevel.LOCAL_QUORUM + session._default_consistency_level = ConsistencyLevel.LOCAL_QUORUM + + def get_connection_holders(self): + holders = [] + for s in tuple(self.sessions): + holders.extend(s.get_pools()) + holders.append(self.control_connection) + return holders + + def shutdown(self): + """ + Closes all sessions and connection associated with this Cluster. + To ensure all connections are properly closed, **you should always + call shutdown() on a Cluster instance when you are done with it**. + + Once shutdown, a Cluster should not be used for any purpose. + """ + with self._lock: + if self.is_shutdown: + return + else: + self.is_shutdown = True + + if self._idle_heartbeat: + self._idle_heartbeat.stop() + + self.scheduler.shutdown() + + self.control_connection.shutdown() + + for session in tuple(self.sessions): + session.shutdown() + + self.executor.shutdown() + + _discard_cluster_shutdown(self) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown() + + def _new_session(self, keyspace): + session = Session(self, self.metadata.all_hosts(), keyspace) + self._session_register_user_types(session) + self.sessions.add(session) + return session + + def _session_register_user_types(self, session): + for keyspace, type_map in self._user_types.items(): + for udt_name, klass in type_map.items(): + session.user_type_registered(keyspace, udt_name, klass) + + def _cleanup_failed_on_up_handling(self, host): + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + for session in tuple(self.sessions): + session.remove_pool(host) + + self._start_reconnector(host, is_host_addition=False) + + def _on_up_future_completed(self, host, futures, results, lock, finished_future): + with lock: + futures.discard(finished_future) + + try: + results.append(finished_future.result()) + except Exception as exc: + results.append(exc) + + if futures: + return + + try: + # all futures have completed at this point + for exc in [f for f in results if isinstance(f, Exception)]: + log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) + self._cleanup_failed_on_up_handling(host) + return + + if not all(results): + log.debug("Connection pool could not be created, not marking node %s up", host) + self._cleanup_failed_on_up_handling(host) + return + + log.info("Connection pools established for node %s", host) + # mark the host as up and notify all listeners + host.set_up() + for listener in self.listeners: + listener.on_up(host) + finally: + with host.lock: + host._currently_handling_node_up = False + + # see if there are any pools to add or remove now that the host is marked up + for session in tuple(self.sessions): + session.update_created_pools() + + def on_up(self, host): + """ + Intended for internal use only. + """ + if self.is_shutdown: + return + + log.debug("Waiting to acquire lock for handling up status of node %s", host) + with host.lock: + if host._currently_handling_node_up: + log.debug("Another thread is already handling up status of node %s", host) + return + + if host.is_up: + log.debug("Host %s was already marked up", host) + return + + host._currently_handling_node_up = True + log.debug("Starting to handle up status of node %s", host) + + have_future = False + futures = set() + try: + log.info("Host %s may be up; will prepare queries and open connection pool", host) + + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + log.debug("Now that host %s is up, cancelling the reconnection handler", host) + reconnector.cancel() + + if self.profile_manager.distance(host) != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing all queries for host %s, ", host) + + for session in tuple(self.sessions): + session.remove_pool(host) + + log.debug("Signalling to load balancing policies that host %s is up", host) + self.profile_manager.on_up(host) + + log.debug("Signalling to control connection that host %s is up", host) + self.control_connection.on_up(host) + + log.debug("Attempting to open new connection pools for host %s", host) + futures_lock = Lock() + futures_results = [] + callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + for session in tuple(self.sessions): + future = session.add_or_renew_pool(host, is_host_addition=False) + if future is not None: + have_future = True + future.add_done_callback(callback) + futures.add(future) + except Exception: + log.exception("Unexpected failure handling node %s being marked up:", host) + for future in futures: + future.cancel() + + self._cleanup_failed_on_up_handling(host) + + with host.lock: + host._currently_handling_node_up = False + raise + else: + if not have_future: + with host.lock: + host.set_up() + host._currently_handling_node_up = False + + # for testing purposes + return futures + + def _start_reconnector(self, host, is_host_addition): + if self.profile_manager.distance(host) == HostDistance.IGNORED: + return + + schedule = self.reconnection_policy.new_schedule() + + # in order to not hold references to this Cluster open and prevent + # proper shutdown when the program ends, we'll just make a closure + # of the current Cluster attributes to create new Connections with + conn_factory = self._make_connection_factory(host) + + reconnector = _HostReconnectionHandler( + host, conn_factory, is_host_addition, self.on_add, self.on_up, + self.scheduler, schedule, host.get_and_set_reconnection_handler, + new_handler=None) + + old_reconnector = host.get_and_set_reconnection_handler(reconnector) + if old_reconnector: + log.debug("Old host reconnector found for %s, cancelling", host) + old_reconnector.cancel() + + log.debug("Starting reconnector for host %s", host) + reconnector.start() + + @run_in_executor + def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + """ + Intended for internal use only. + """ + if self.is_shutdown: + return + + with host.lock: + was_up = host.is_up + + # ignore down signals if we have open pools to the host + # this is to avoid closing pools when a control connection host became isolated + if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: + connected = False + for session in tuple(self.sessions): + pool_states = session.get_pool_state() + pool_state = pool_states.get(host) + if pool_state: + connected |= pool_state['open_count'] > 0 + if connected: + return + + host.set_down() + if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + return + + log.warning("Host %s has been marked down", host) + + self.profile_manager.on_down(host) + self.control_connection.on_down(host) + for session in tuple(self.sessions): session.on_down(host) - schedule = self.reconnection_policy.new_schedule() + for listener in self.listeners: + listener.on_down(host) + + self._start_reconnector(host, is_host_addition) + + def on_add(self, host, refresh_nodes=True): + if self.is_shutdown: + return + + log.debug("Handling new host %r and notifying listeners", host) + + distance = self.profile_manager.distance(host) + if distance != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing queries for new host %r", host) + + self.profile_manager.on_add(host) + self.control_connection.on_add(host, refresh_nodes) + + if distance == HostDistance.IGNORED: + log.debug("Not adding connection pool for new host %r because the " + "load balancing policy has marked it as IGNORED", host) + self._finalize_add(host, set_up=False) + return + + futures_lock = Lock() + futures_results = [] + futures = set() + + def future_completed(future): + with futures_lock: + futures.discard(future) + + try: + futures_results.append(future.result()) + except Exception as exc: + futures_results.append(exc) + + if futures: + return + + log.debug('All futures have completed for added host %s', host) + + for exc in [f for f in futures_results if isinstance(f, Exception)]: + log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + return + + if not all(futures_results): + log.warning("Connection pool could not be created, not marking node %s up", host) + return + + self._finalize_add(host) + + have_future = False + for session in tuple(self.sessions): + future = session.add_or_renew_pool(host, is_host_addition=True) + if future is not None: + have_future = True + futures.add(future) + future.add_done_callback(future_completed) + + if not have_future: + self._finalize_add(host) + + def _finalize_add(self, host, set_up=True): + if set_up: + host.set_up() + + for listener in self.listeners: + listener.on_add(host) + + # see if there are any pools to add or remove now that the host is marked up + for session in tuple(self.sessions): + session.update_created_pools() + + def on_remove(self, host): + if self.is_shutdown: + return + + log.debug("Removing host %s", host) + host.set_down() + self.profile_manager.on_remove(host) + for session in tuple(self.sessions): + session.on_remove(host) + for listener in self.listeners: + listener.on_remove(host) + self.control_connection.on_remove(host) + + reconnection_handler = host.get_and_set_reconnection_handler(None) + if reconnection_handler: + reconnection_handler.cancel() + + def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): + is_down = host.signal_connection_failure(connection_exc) + if is_down: + self.on_down(host, is_host_addition, expect_host_to_be_down) + return is_down + + def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True): + """ + Called when adding initial contact points and when the control + connection subsequently discovers a new node. + Returns a Host instance, and a flag indicating whether it was new in + the metadata. + Intended for internal use only. + """ + host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack)) + if new and signal: + log.info("New Cassandra host %r discovered", host) + self.on_add(host, refresh_nodes) + + return host, new + + def remove_host(self, host): + """ + Called when the control connection observes that a node has left the + ring. Intended for internal use only. + """ + if host and self.metadata.remove_host(host): + log.info("Cassandra host %s removed", host) + self.on_remove(host) + + def register_listener(self, listener): + """ + Adds a :class:`cassandra.policies.HostStateListener` subclass instance to + the list of listeners to be notified when a host is added, removed, + marked up, or marked down. + """ + with self._listener_lock: + self._listeners.add(listener) + + def unregister_listener(self, listener): + """ Removes a registered listener. """ + with self._listener_lock: + self._listeners.remove(listener) + + @property + def listeners(self): + with self._listener_lock: + return self._listeners.copy() + + def _ensure_core_connections(self): + """ + If any host has fewer than the configured number of core connections + open, attempt to open connections until that number is met. + """ + for session in tuple(self.sessions): + for pool in tuple(session._pools.values()): + pool.ensure_core_connections() + + @staticmethod + def _validate_refresh_schema(keyspace, table, usertype, function, aggregate): + if any((table, usertype, function, aggregate)): + if not keyspace: + raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}") + if sum(1 for e in (table, usertype, function) if e) > 1: + raise ValueError("{table, usertype, function, aggregate} are mutually exclusive") + + @staticmethod + def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate): + if aggregate: + return SchemaTargetType.AGGREGATE + elif function: + return SchemaTargetType.FUNCTION + elif usertype: + return SchemaTargetType.TYPE + elif table: + return SchemaTargetType.TABLE + elif keyspace: + return SchemaTargetType.KEYSPACE + return None + + def get_control_connection_host(self): + """ + Returns the control connection host metadata. + """ + connection = self.control_connection._connection + endpoint = connection.endpoint if connection else None + return self.metadata.get_host(endpoint) if endpoint else None + + def refresh_schema_metadata(self, max_schema_agreement_wait=None): + """ + Synchronously refresh all schema metadata. + + By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` + and :attr:`~.Cluster.control_connection_timeout`. + + Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. + + Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. + + An Exception is raised if schema refresh fails for any reason. + """ + if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Schema metadata was not refreshed. See log for details.") + + def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): + """ + Synchronously refresh keyspace metadata. This applies to keyspace-level information such as replication + and durability settings. It does not refresh tables, types, etc. contained in the keyspace. + + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior + """ + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Keyspace metadata was not refreshed. See log for details.") + + def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): + """ + Synchronously refresh table metadata. This applies to a table, and any triggers or indexes attached + to the table. - # in order to not hold references to this Cluster open and prevent - # proper shutdown when the program ends, we'll just make a closure - # of the current Cluster attributes to create new Connections with - conn_factory = self._make_connection_factory(host) + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior + """ + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Table metadata was not refreshed. See log for details.") - reconnector = _HostReconnectionHandler( - host, conn_factory, self.scheduler, schedule, - host.get_and_set_reconnection_handler, new_handler=None) + def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): + """ + Synchronously refresh materialized view metadata. - old_reconnector = host.get_and_set_reconnection_handler(reconnector) - if old_reconnector: - old_reconnector.cancel() + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior + """ + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("View metadata was not refreshed. See log for details.") - reconnector.start() + def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): + """ + Synchronously refresh user defined type metadata. - def add_host(self, address, signal): + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - Called when adding initial contact points and when the control - connection subsequently discovers a new node. Intended for internal - use only. + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Type metadata was not refreshed. See log for details.") + + def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ - log.info("Now considering host %s for new connections", address) - new_host = self.metadata.add_host(address) - if new_host and signal: - self._prepare_all_queries(new_host) - self.control_connection.on_add(new_host) - for session in self.sessions: # TODO need to copy/lock? - session.on_add(new_host) + Synchronously refresh user defined function metadata. - return new_host + ``function`` is a :class:`cassandra.UserFunctionDescriptor`. - def remove_host(self, host): + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - Called when the control connection observes that a node has left the - ring. Intended for internal use only. + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Function metadata was not refreshed. See log for details.") + + def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ - log.info("Host %s will no longer be considered for new connections", host) - if host and self.metadata.remove_host(host): - self.control_connection.on_remove(host) - for session in self.sessions: - session.on_remove(host) + Synchronously refresh user defined aggregate metadata. + + ``aggregate`` is a :class:`cassandra.UserAggregateDescriptor`. - def ensure_core_connections(self): + See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - If any host has fewer than the configured number of core connections - open, attempt to open connections until that number is met. + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Aggregate metadata was not refreshed. See log for details.") + + def refresh_nodes(self, force_token_rebuild=False): """ - for session in self.sessions: - for pool in session._pools.values(): - pool.ensure_core_connections() + Synchronously refresh the node list and token metadata - def submit_schema_refresh(self, keyspace=None, table=None): + `force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered. + + An Exception is raised if node refresh fails for any reason. """ - Schedule a refresh of the internal representation of the current - schema for this cluster. If `keyspace` is specified, only that - keyspace will be refreshed, and likewise for `table`. + if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): + raise DriverException("Node list was not refreshed. See log for details.") + + def set_meta_refresh_enabled(self, enabled): """ - return self.executor.submit( - self.control_connection.refresh_schema, keyspace, table) + *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` :attr:`~.Cluster.token_metadata_enabled` instead - def _prepare_all_queries(self, host): - if not self._prepared_statements: - return + Sets a flag to enable (True) or disable (False) all metadata refresh queries. + This applies to both schema and node topology. - log.debug("Preparing all known prepared statements against host %s" % (host,)) - try: - connection = self.connection_factory(host.address) - try: - self.control_connection.wait_for_schema_agreement(connection) - except Exception: - pass + Disabling this is useful to minimize refreshes during multiple changes. - statements = self._prepared_statements.values() - for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): - if keyspace is not None: - connection.set_keyspace(keyspace) + Meta refresh must be enabled for the driver to become aware of any cluster + topology changes or schema updates. + """ + warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " + "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning) + self.schema_metadata_enabled = enabled + self.token_metadata_enabled = enabled + + @classmethod + def _send_chunks(cls, connection, host, chunks, set_keyspace=False): + for ks_chunk in chunks: + messages = [PrepareMessage(query=s.query_string, + keyspace=s.keyspace if set_keyspace else None) + for s in ks_chunk] + # TODO: make this timeout configurable somehow? + responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) + for success, response in responses: + if not success: + log.debug("Got unexpected response when preparing " + "statement on host %s: %r", host, response) - # note: we could potentially prepare some of these in parallel, - # but at the same time, we don't want to put too much load on - # the server at once - for statement in ks_statements: - message = PrepareMessage(query=statement.query_string) - try: - response = connection.wait_for_response(message) - if (not isinstance(response, ResultMessage) or - response.kind != ResultMessage.KIND_PREPARED): - log.debug("Got unexpected response when preparing " - "statement on host %s: %r" % (host, response)) - except Exception: - log.exception("Error trying to prepare statement on " - "host %s" % (host,)) + def _prepare_all_queries(self, host): + if not self._prepared_statements or not self.reprepare_on_up: + return + log.debug("Preparing all known prepared statements against host %s", host) + connection = None + try: + connection = self.connection_factory(host.endpoint) + statements = list(self._prepared_statements.values()) + if ProtocolVersion.uses_keyspace_flag(self.protocol_version): + # V5 protocol and higher, no need to set the keyspace + chunks = [] + for i in range(0, len(statements), 10): + chunks.append(statements[i:i + 10]) + self._send_chunks(connection, host, chunks, True) + else: + for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): + if keyspace is not None: + connection.set_keyspace_blocking(keyspace) + + # prepare 10 statements at a time + ks_statements = list(ks_statements) + chunks = [] + for i in range(0, len(ks_statements), 10): + chunks.append(ks_statements[i:i + 10]) + self._send_chunks(connection, host, chunks) + + log.debug("Done preparing all known prepared statements against host %s", host) + except OperationTimedOut as timeout: + log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) + except (ConnectionException, socket.error) as exc: + log.warning("Error trying to prepare all statements on host %s: %r", host, exc) except Exception: - # log and ignore - log.exception("Error trying to prepare all statements on host %s" % (host,)) + log.exception("Error trying to prepare all statements on host %s", host) + finally: + if connection: + connection.close() - def prepare_on_all_sessions(self, query_id, prepared_statement, excluded_host): - self._prepared_statements[query_id] = prepared_statement - for session in self.sessions: - session.prepare_on_all_hosts(prepared_statement.query_string, excluded_host) + def add_prepared(self, query_id, prepared_statement): + with self._prepared_statement_lock: + self._prepared_statements[query_id] = prepared_statement class Session(object): @@ -537,7 +2359,7 @@ class Session(object): Queries and statements can be executed through ``Session`` instances using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()` - method. + methods. Example usage:: @@ -551,175 +2373,883 @@ class Session(object): hosts = None keyspace = None is_shutdown = False + session_id = None + _monitor_reporter = None + + _row_factory = staticmethod(named_tuple_factory) + @property + def row_factory(self): + """ + The format to return row results in. By default, each + returned row will be a named tuple. You can alternatively + use any of the following: + + - :func:`cassandra.query.tuple_factory` - return a result row as a tuple + - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple + - :func:`cassandra.query.dict_factory` - return a result row as a dict + - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict + + """ + return self._row_factory + + @row_factory.setter + def row_factory(self, rf): + self._validate_set_legacy_config('row_factory', rf) + + _default_timeout = 10.0 + + @property + def default_timeout(self): + """ + A default timeout, measured in seconds, for queries executed through + :meth:`.execute()` or :meth:`.execute_async()`. This default may be + overridden with the `timeout` parameter for either of those methods. + + Setting this to :const:`None` will cause no timeouts to be set by default. + + Please see :meth:`.ResponseFuture.result` for details on the scope and + effect of this timeout. + + .. versionadded:: 2.0.0 + """ + return self._default_timeout + + @default_timeout.setter + def default_timeout(self, timeout): + self._validate_set_legacy_config('default_timeout', timeout) + + _default_consistency_level = ConsistencyLevel.LOCAL_ONE + + @property + def default_consistency_level(self): + """ + *Deprecated:* use execution profiles instead + The default :class:`~ConsistencyLevel` for operations executed through + this session. This default may be overridden by setting the + :attr:`~.Statement.consistency_level` on individual statements. + + .. versionadded:: 1.2.0 + + .. versionchanged:: 3.0.0 + + default changed from ONE to LOCAL_ONE + """ + return self._default_consistency_level + + @default_consistency_level.setter + def default_consistency_level(self, cl): + """ + *Deprecated:* use execution profiles instead + """ + warn("Setting the consistency level at the session level will be removed in 4.0. Consider using " + "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile." + , DeprecationWarning) + self._validate_set_legacy_config('default_consistency_level', cl) + + _default_serial_consistency_level = None + + @property + def default_serial_consistency_level(self): + """ + The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through + this session. This default may be overridden by setting the + :attr:`~.Statement.serial_consistency_level` on individual statements. + + Only valid for ``protocol_version >= 2``. + """ + return self._default_serial_consistency_level + + @default_serial_consistency_level.setter + def default_serial_consistency_level(self, cl): + if (cl is not None and + not ConsistencyLevel.is_serial(cl)): + raise ValueError("default_serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") + + self._validate_set_legacy_config('default_serial_consistency_level', cl) + + max_trace_wait = 2.0 + """ + The maximum amount of time (in seconds) the driver will wait for trace + details to be populated server-side for a query before giving up. + If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()` + is :const:`True`, the driver will repeatedly attempt to fetch trace + details for the query (using exponential backoff) until this limit is + hit. If the limit is passed, an error will be logged and the + :attr:`.Statement.trace` will be left as :const:`None`. """ + + default_fetch_size = 5000 + """ + By default, this many rows will be fetched at a time. Setting + this to :const:`None` will disable automatic paging for large query + results. The fetch size can be also specified per-query through + :attr:`.Statement.fetch_size`. + + This only takes effect when protocol version 2 or higher is used. + See :attr:`.Cluster.protocol_version` for details. + + .. versionadded:: 2.0.0 + """ + + use_client_timestamp = True + """ + When using protocol version 3 or higher, write timestamps may be supplied + client-side at the protocol level. (Normally they are generated + server-side by the coordinator node.) Note that timestamps specified + within a CQL query will override this timestamp. + + .. versionadded:: 2.1.0 + """ + + timestamp_generator = None + """ + When :attr:`use_client_timestamp` is set, sessions call this object and use + the result as the timestamp. (Note that timestamps specified within a CQL + query will override this timestamp.) By default, a new + :class:`~.MonotonicTimestampGenerator` is created for + each :class:`Cluster` instance. + + Applications can set this value for custom timestamp behavior. For + example, an application could share a timestamp generator across + :class:`Cluster` objects to guarantee that the application will use unique, + increasing timestamps across clusters, or set it to to ``lambda: + int(time.time() * 1e6)`` if losing records over clock inconsistencies is + acceptable for the application. Custom :attr:`timestamp_generator` s should + be callable, and calling them should return an integer representing microseconds + since some point in time, typically UNIX epoch. + + .. versionadded:: 3.8.0 + """ + + encoder = None + """ + A :class:`~cassandra.encoder.Encoder` instance that will be used when + formatting query parameters for non-prepared statements. This is not used + for prepared statements (because prepared statements give the driver more + information about what CQL types are expected, allowing it to accept a + wider range of python types). + + The encoder uses a mapping from python types to encoder methods (for + specific CQL types). This mapping can be be modified by users as they see + fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping + values if possible, because they take precautions to avoid injections and + properly sanitize data. + + Example:: + + cluster = Cluster() + session = cluster.connect("mykeyspace") + session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple + + session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple)") + session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) + + .. versionadded:: 2.1.0 + """ + + client_protocol_handler = ProtocolHandler + """ + Specifies a protocol handler that will be used for client-initiated requests (i.e. no + internal driver requests). This can be used to override or extend features such as + message or type ser/des. + + The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`. + + When compiled with Cython, there are also built-in faster alternatives. See :ref:`faster_deser` + """ + + session_id = None + """ + A UUID that uniquely identifies this Session to Insights. This will be + generated automatically. + """ + + _lock = None + _pools = None + _profile_manager = None + _metrics = None + _request_init_callbacks = None + _graph_paging_available = False + + def __init__(self, cluster, hosts, keyspace=None): + self.cluster = cluster + self.hosts = hosts + self.keyspace = keyspace + + self._lock = RLock() + self._pools = {} + self._profile_manager = cluster.profile_manager + self._metrics = cluster.metrics + self._request_init_callbacks = [] + self._protocol_version = self.cluster.protocol_version + + self.encoder = Encoder() + + # create connection pools in parallel + self._initial_connect_futures = set() + for host in hosts: + future = self.add_or_renew_pool(host, is_host_addition=False) + if future: + self._initial_connect_futures.add(future) + + futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + while futures.not_done and not any(f.result() for f in futures.done): + futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) + + if not any(f.result() for f in self._initial_connect_futures): + msg = "Unable to connect to any servers" + if self.keyspace: + msg += " using keyspace '%s'" % self.keyspace + raise NoHostAvailable(msg, [h.address for h in hosts]) + + self.session_id = uuid.uuid4() + self._graph_paging_available = self._check_graph_paging_available() + + if self.cluster.column_encryption_policy is not None: + try: + self.client_protocol_handler = type( + str(self.session_id) + "-ProtocolHandler", + (ProtocolHandler,), + {"column_encryption_policy": self.cluster.column_encryption_policy}) + except AttributeError: + log.info("Unable to set column encryption policy for session") + + if self.cluster.monitor_reporting_enabled: + cc_host = self.cluster.get_control_connection_host() + valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version)) + if valid_insights_version: + self._monitor_reporter = MonitorReporter( + interval_sec=self.cluster.monitor_reporting_interval, + session=self, + ) + else: + if cc_host: + log.debug('Not starting MonitorReporter thread for Insights; ' + 'not supported by server version {v} on ' + 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host)) + + log.debug('Started Session with client_id {} and session_id {}'.format(self.cluster.client_id, + self.session_id)) + + def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, + custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None, execute_as=None): + """ + Execute the given query and synchronously wait for the response. + + If an error is encountered while executing the query, an Exception + will be raised. + + `query` may be a query string or an instance of :class:`cassandra.query.Statement`. + + `parameters` may be a sequence or dict of parameters to bind. If a + sequence is used, ``%s`` should be used the placeholder for each + argument. If a dict is used, ``%(name)s`` style placeholders must + be used. + + `timeout` should specify a floating-point timeout (in seconds) after + which an :exc:`.OperationTimedOut` exception will be raised if the query + has not completed. If not set, the timeout defaults to the request_timeout of the selected ``execution_profile``. + If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on + the scope and effect of this timeout. + + If `trace` is set to :const:`True`, the query will be sent with tracing enabled. + The trace details can be obtained using the returned :class:`.ResultSet` object. + + `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. + If `query` is a Statement with its own custom_payload. The message payload + will be a union of the two, with the values specified here taking precedence. + + `execution_profile` is the execution profile to use for this request. It can be a key to a profile configured + via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`, + for example + + `paging_state` is an optional paging state, reused from a previous :class:`ResultSet`. + + `host` is the :class:`cassandra.pool.Host` that should handle the query. If the host specified is down or + not yet connected, the query will fail with :class:`NoHostAvailable`. Using this is + discouraged except in a few cases, e.g., querying node-local tables and applying schema changes. + + `execute_as` the user that will be used on the server to execute the request. This is only available + on a DSE cluster. + """ + + return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host, execute_as).result() + + def execute_async(self, query, parameters=None, trace=False, custom_payload=None, + timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None, execute_as=None): + """ + Execute the given query and return a :class:`~.ResponseFuture` object + which callbacks may be attached to for asynchronous response + delivery. You may also call :meth:`~.ResponseFuture.result()` + on the :class:`.ResponseFuture` to synchronously block for results at + any time. + + See :meth:`Session.execute` for parameter definitions. + + Example usage:: + + >>> session = cluster.connect() + >>> future = session.execute_async("SELECT * FROM mycf") + + >>> def log_results(results): + ... for row in results: + ... log.info("Results: %s", row) + + >>> def log_error(exc): + >>> log.error("Operation failed: %s", exc) + + >>> future.add_callbacks(log_results, log_error) + + Async execution with blocking wait for results:: + + >>> future = session.execute_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... results = future.result() + ... except Exception: + ... log.exception("Operation failed:") + + """ + custom_payload = custom_payload if custom_payload else {} + if execute_as: + custom_payload[_proxy_execute_key] = execute_as.encode() + + future = self._create_response_future( + query, parameters, trace, custom_payload, timeout, + execution_profile, paging_state, host) + future._protocol_handler = self.client_protocol_handler + self._on_request(future) + future.send_request() + return future + + def execute_concurrent(self, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + Executes a sequence of (statement, parameters) tuples concurrently. Each + ``parameters`` item must be a sequence or :const:`None`. + + The `concurrency` parameter controls how many statements will be executed + concurrently. When :attr:`.Cluster.protocol_version` is set to 1 or 2, + it is recommended that this be kept below 100 times the number of + core connections per host times the number of connected hosts (see + :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded, + the event loop thread may attempt to block on new connection creation, + substantially impacting throughput. If :attr:`~.Cluster.protocol_version` + is 3 or higher, you can safely experiment with higher levels of concurrency. + + If `raise_on_first_error` is left as :const:`True`, execution will stop + after the first failed statement and the corresponding exception will be + raised. + + `results_generator` controls how the results are returned. + + * If :const:`False`, the results are returned only after all requests have completed. + * If :const:`True`, a generator expression is returned. Using a generator results in a constrained + memory footprint when the results set will be large -- results are yielded + as they return instead of materializing the entire list at once. The trade for lower memory + footprint is marginal CPU overhead (more thread coordination and sorting out-of-order results + on-the-fly). + + `execution_profile` argument is the execution profile to use for this + request, it is passed directly to :meth:`Session.execute_async`. + + A sequence of ``ExecutionResult(success, result_or_exc)`` namedtuples is returned + in the same order that the statements were passed in. If ``success`` is :const:`False`, + there was an error executing the statement, and ``result_or_exc`` + will be an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc`` + will be the query result. + + Example usage:: + + select_statement = session.prepare("SELECT * FROM users WHERE id=?") + + statements_and_params = [] + for user_id in user_ids: + params = (user_id, ) + statements_and_params.append((select_statement, params)) + + results = session.execute_concurrent(statements_and_params, raise_on_first_error=False) + + for (success, result) in results: + if not success: + handle_error(result) # result will be an Exception + else: + process_user(result[0]) # result will be a list of rows + + Note: in the case that `generators` are used, it is important to ensure the consumers do not + block or attempt further synchronous requests, because no further IO will be processed until + the consumer returns. This may also produce a deadlock in the IO event thread. + """ + from cassandra.concurrent import execute_concurrent + return execute_concurrent(self, statements_and_parameters, concurrency, raise_on_first_error, results_generator, execution_profile) + + def execute_concurrent_with_args(self, statement, parameters, *args, **kwargs): + """ + Like :meth:`~cassandra.concurrent.execute_concurrent()`, but takes a single + statement and a sequence of parameters. Each item in ``parameters`` + should be a sequence or :const:`None`. + + Example usage:: + + statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)") + parameters = [(x,) for x in range(1000)] + session.execute_concurrent_with_args(statement, parameters, concurrency=50) + """ + from cassandra.concurrent import execute_concurrent_with_args + return execute_concurrent_with_args(self, statement, parameters, *args, **kwargs) + + def execute_concurrent_async(self, statements_and_parameters, concurrency=100, raise_on_first_error=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + Asynchronously executes a sequence of (statement, parameters) tuples concurrently. + + Args: + session: Cassandra session object. + statement_and_parameters: Iterable of (prepared CQL statement, bind parameters) tuples. + concurrency (int, optional): Number of concurrent operations. Default is 100. + raise_on_first_error (bool, optional): If True, execution stops on the first error. Default is True. + execution_profile (ExecutionProfile, optional): Execution profile to use. Default is EXEC_PROFILE_DEFAULT. + + Returns: + A `Future` object that will be completed when all operations are done. + """ + from cassandra.concurrent import execute_concurrent_async + return execute_concurrent_async(self, statements_and_parameters, concurrency, raise_on_first_error, execution_profile) + + def execute_graph(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + """ + Executes a Gremlin query string or GraphStatement synchronously, + and returns a ResultSet from this execution. + + `parameters` is dict of named parameters to bind. The values must be + JSON-serializable. + + `execution_profile`: Selects an execution profile for the request. + + `execute_as` the user that will be used on the server to execute the request. + """ + return self.execute_graph_async(query, parameters, trace, execution_profile, execute_as).result() + + def execute_graph_async(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + """ + Execute the graph query and return a :class:`ResponseFuture` + object which callbacks may be attached to for asynchronous response delivery. You may also call ``ResponseFuture.result()`` to synchronously block for + results at any time. + """ + if self.cluster._config_mode is _ConfigMode.LEGACY: + raise ValueError(("Cannot execute graph queries using Cluster legacy parameters. " + "Consider using Execution profiles: " + "https://docs.datastax.com/en/developer/python-driver/latest/execution_profiles/#execution-profiles")) + + if not isinstance(query, GraphStatement): + query = SimpleGraphStatement(query) + + # Clone and look up instance here so we can resolve and apply the extended attributes + execution_profile = self.execution_profile_clone_update(execution_profile) + + if not hasattr(execution_profile, 'graph_options'): + raise ValueError( + "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options") + + self._resolve_execution_profile_options(execution_profile) + + # make sure the graphson context row factory is binded to this cluster + try: + if issubclass(execution_profile.row_factory, _GraphSONContextRowFactory): + execution_profile.row_factory = execution_profile.row_factory(self.cluster) + except TypeError: + # issubclass might fail if arg1 is an instance + pass + + # set graph paging if needed + self._maybe_set_graph_paging(execution_profile) + + graph_parameters = None + if parameters: + graph_parameters = self._transform_params(parameters, graph_options=execution_profile.graph_options) + + custom_payload = execution_profile.graph_options.get_options_map() + if execute_as: + custom_payload[_proxy_execute_key] = execute_as.encode() + custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000)) + + future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, + timeout=_NOT_SET, execution_profile=execution_profile) + + future.message.query_params = graph_parameters + future._protocol_handler = self.client_protocol_handler + + if execution_profile.graph_options.is_analytics_source and \ + isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy): + self._target_analytics_master(future) + else: + future.send_request() + return future + + def _maybe_set_graph_paging(self, execution_profile): + graph_paging = execution_profile.continuous_paging_options + if execution_profile.continuous_paging_options is _NOT_SET: + graph_paging = ContinuousPagingOptions() if self._graph_paging_available else None + + execution_profile.continuous_paging_options = graph_paging + + def _check_graph_paging_available(self): + """Verify if we can enable graph paging. This executed only once when the session is created.""" + + if not ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version): + return False + + for host in self.cluster.metadata.all_hosts(): + if host.dse_version is None: + return False + + version = Version(host.dse_version) + if version < _GRAPH_PAGING_MIN_DSE_VERSION: + return False + + return True + + def _resolve_execution_profile_options(self, execution_profile): + """ + Determine the GraphSON protocol and row factory for a graph query. This is useful + to configure automatically the execution profile when executing a query on a + core graph. + + If `graph_protocol` is not explicitly specified, the following rules apply: + - Default to GraphProtocol.GRAPHSON_1_0, or GRAPHSON_2_0 if the `graph_language` is not gremlin-groovy. + - If `graph_options.graph_name` is specified and is a Core graph, set GraphSON_3_0. + If `row_factory` is not explicitly specified, the following rules apply: + - Default to graph_object_row_factory. + - If `graph_options.graph_name` is specified and is a Core graph, set graph_graphson3_row_factory. + """ + if execution_profile.graph_options.graph_protocol is not None and \ + execution_profile.row_factory is not None: + return + + graph_options = execution_profile.graph_options + + is_core_graph = False + if graph_options.graph_name: + # graph_options.graph_name is bytes ... + name = graph_options.graph_name.decode('utf-8') + if name in self.cluster.metadata.keyspaces: + ks_metadata = self.cluster.metadata.keyspaces[name] + if ks_metadata.graph_engine == 'Core': + is_core_graph = True + + if is_core_graph: + graph_protocol = GraphProtocol.GRAPHSON_3_0 + row_factory = graph_graphson3_row_factory + else: + if graph_options.graph_language == GraphOptions.DEFAULT_GRAPH_LANGUAGE: + graph_protocol = GraphOptions.DEFAULT_GRAPH_PROTOCOL + row_factory = graph_object_row_factory + else: + # if not gremlin-groovy, GraphSON_2_0 + graph_protocol = GraphProtocol.GRAPHSON_2_0 + row_factory = graph_graphson2_row_factory + + # Only apply if not set explicitly + if graph_options.graph_protocol is None: + graph_options.graph_protocol = graph_protocol + if execution_profile.row_factory is None: + execution_profile.row_factory = row_factory + + def _transform_params(self, parameters, graph_options): + if not isinstance(parameters, dict): + raise ValueError('The parameters must be a dictionary. Unnamed parameters are not allowed.') + + # Serialize python types to graphson + serializer = GraphSON1Serializer + if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0: + serializer = GraphSON2Serializer() + elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0: + # only required for core graphs + context = { + 'cluster': self.cluster, + 'graph_name': graph_options.graph_name.decode('utf-8') if graph_options.graph_name else None + } + serializer = GraphSON3Serializer(context) + + serialized_parameters = serializer.serialize(parameters) + return [json.dumps(serialized_parameters).encode('utf-8')] + + def _target_analytics_master(self, future): + future._start_timer() + master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()", + parameters=None, trace=False, + custom_payload=None, timeout=future.timeout) + master_query_future.row_factory = tuple_factory + master_query_future.send_request() + + cb = self._on_analytics_master_result + args = (master_query_future, future) + master_query_future.add_callbacks(callback=cb, callback_args=args, errback=cb, errback_args=args) + + def _on_analytics_master_result(self, response, master_future, query_future): + try: + row = master_future.result()[0] + addr = row[0]['location'] + delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided + if delimiter_index > 0: + addr = addr[:delimiter_index] + targeted_query = HostTargetingStatement(query_future.query, addr) + query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + except Exception: + log.debug("Failed querying analytics master (request might not be routed optimally). " + "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) - row_factory = staticmethod(named_tuple_factory) - """ - The format to return row results in. By default, each - returned row will be a named tuple. You can alternatively - use any of the following: + self.submit(query_future.send_request) - - :func:`cassandra.decoder.tuple_factory` - - :func:`cassandra.decoder.named_tuple_factory` - - :func:`cassandra.decoder.dict_factory` - - :func:`cassandra.decoder.ordered_dict_factory` + def _create_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): + """ Returns the ResponseFuture before calling send_request() on it """ - """ + prepared_statement = None - _lock = None - _pools = None - _load_balancer = None - _metrics = None + if isinstance(query, str): + query = SimpleStatement(query) + elif isinstance(query, PreparedStatement): + query = query.bind(parameters) - def __init__(self, cluster, hosts): - self.cluster = cluster - self.hosts = hosts + if self.cluster._config_mode == _ConfigMode.LEGACY: + if execution_profile is not EXEC_PROFILE_DEFAULT: + raise ValueError("Cannot specify execution_profile while using legacy parameters.") - self._lock = RLock() - self._pools = {} - self._load_balancer = cluster.load_balancing_policy - self._metrics = cluster.metrics + if timeout is _NOT_SET: + timeout = self.default_timeout - for host in hosts: - self.add_host(host) + cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level + serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level - def execute(self, query, parameters=None, trace=False): - """ - Execute the given query and synchronously wait for the response. + retry_policy = query.retry_policy or self.cluster.default_retry_policy + row_factory = self.row_factory + load_balancing_policy = self.cluster.load_balancing_policy + spec_exec_policy = None + continuous_paging_options = None + else: + execution_profile = self._maybe_get_execution_profile(execution_profile) - If an error is encountered while executing the query, an Exception - will be raised. + if timeout is _NOT_SET: + timeout = execution_profile.request_timeout - `query` may be a query string or an instance of :class:`cassandra.query.Statement`. + cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level + serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level + continuous_paging_options = execution_profile.continuous_paging_options - `parameters` may be a sequence or dict of parameters to bind. If a - sequence is used, ``%s`` should be used the placeholder for each - argument. If a dict is used, ``%(name)s`` style placeholders must - be used. + retry_policy = query.retry_policy or execution_profile.retry_policy + row_factory = execution_profile.row_factory + load_balancing_policy = execution_profile.load_balancing_policy + spec_exec_policy = execution_profile.speculative_execution_policy - If `trace` is set to :const:`True`, an attempt will be made to - fetch the trace details and attach them to the `query`'s - :attr:`~.Statement.trace` attribute in the form of a :class:`.QueryTrace` - instance. This requires that `query` be a :class:`.Statement` subclass - instance and not just a string. If there is an error fetching the - trace details, the :attr:`~.Statement.trace` attribute will be left as - :const:`None`. - """ - if trace and not isinstance(query, Statement): - raise TypeError( - "The query argument must be an instance of a subclass of " - "cassandra.query.Statement when trace=True") + fetch_size = query.fetch_size + if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: + fetch_size = self.default_fetch_size + elif self._protocol_version == 1: + fetch_size = None - future = self.execute_async(query, parameters, trace) - try: - result = future.result() - finally: - if trace: - try: - query.trace = future.get_query_trace() - except Exception: - log.exception("Unable to fetch query trace:") + start_time = time.time() + if self._protocol_version >= 3 and self.use_client_timestamp: + timestamp = self.cluster.timestamp_generator() + else: + timestamp = None - return result + supports_continuous_paging_state = ( + ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version) + ) + if continuous_paging_options and supports_continuous_paging_state: + continuous_paging_state = ContinuousPagingState(continuous_paging_options.max_queue_size) + else: + continuous_paging_state = None - def execute_async(self, query, parameters=None, trace=False): + if isinstance(query, SimpleStatement): + query_string = query.query_string + statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None + if parameters: + query_string = bind_params(query_string, parameters, self.encoder) + message = QueryMessage( + query_string, cl, serial_cl, + fetch_size, paging_state, timestamp, + continuous_paging_options, statement_keyspace) + elif isinstance(query, BoundStatement): + prepared_statement = query.prepared_statement + message = ExecuteMessage( + prepared_statement.query_id, query.values, cl, + serial_cl, fetch_size, paging_state, timestamp, + skip_meta=bool(prepared_statement.result_metadata), + continuous_paging_options=continuous_paging_options, + result_metadata_id=prepared_statement.result_metadata_id) + elif isinstance(query, BatchStatement): + if self._protocol_version < 2: + raise UnsupportedOperation( + "BatchStatement execution is only supported with protocol version " + "2 or higher (supported in Cassandra 2.0 and higher). Consider " + "setting Cluster.protocol_version to 2 to support this operation.") + statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None + message = BatchMessage( + query.batch_type, query._statements_and_parameters, cl, + serial_cl, timestamp, statement_keyspace) + elif isinstance(query, GraphStatement): + # the statement_keyspace is not aplicable to GraphStatement + message = QueryMessage(query.query, cl, serial_cl, fetch_size, + paging_state, timestamp, + continuous_paging_options) + + message.tracing = trace + message.update_custom_payload(query.custom_payload) + message.update_custom_payload(custom_payload) + message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version + + spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None + return ResponseFuture( + self, message, query, timeout, metrics=self._metrics, + prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, + load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan, + continuous_paging_state=continuous_paging_state, host=host) + + def get_execution_profile(self, name): """ - Execute the given query and return a :class:`~.ResponseFuture` object - which callbacks may be attached to for asynchronous response - delivery. You may also call :meth:`~.ResponseFuture.result()` - on the ``ResponseFuture`` to syncronously block for results at - any time. - - If `trace` is set to :const:`True`, you may call - :meth:`.ResponseFuture.get_query_trace()` after the request - completes to retrieve a :class:`.QueryTrace` instance. + Returns the execution profile associated with the provided ``name``. - Example usage:: + :param name: The name (or key) of the execution profile. + """ + profiles = self.cluster.profile_manager.profiles + try: + return profiles[name] + except KeyError: + eps = [_execution_profile_to_string(ep) for ep in profiles.keys()] + raise ValueError("Invalid execution_profile: %s; valid profiles are: %s." % ( + _execution_profile_to_string(name), ', '.join(eps))) - >>> session = cluster.connect() - >>> future = session.execute_async("SELECT * FROM mycf") + def _maybe_get_execution_profile(self, ep): + return ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep) - >>> def log_results(results): - ... for row in results: - ... log.info("Results: %s", row) + def execution_profile_clone_update(self, ep, **kwargs): + """ + Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes + of the returned profile. - >>> def log_error(exc): - >>> log.error("Operation failed: %s", exc) + This is a shallow clone, so any objects referenced by the profile are shared. This means Load Balancing Policy + is maintained by inclusion in the active profiles. It also means updating any other rich objects will be seen + by the active profile. In cases where this is not desirable, be sure to replace the instance instead of manipulating + the shared object. + """ + clone = copy(self._maybe_get_execution_profile(ep)) + for attr, value in kwargs.items(): + setattr(clone, attr, value) + return clone - >>> future.add_callbacks(log_results, log_error) + def add_request_init_listener(self, fn, *args, **kwargs): + """ + Adds a callback with arguments to be called when any request is created. - Async execution with blocking wait for results:: + It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created, + and before the request is sent. This can be used to create extensions by adding result callbacks to the + response future. - >>> future = session.execute_async("SELECT * FROM mycf") - >>> # do other stuff... + `response_future` is the :class:`.ResponseFuture` for the request. - >>> try: - ... results = future.result() - ... except Exception: - ... log.exception("Operation failed:") + Note that the init callback is done on the client thread creating the request, so you may need to consider + synchronization if you have multiple threads. Any callbacks added to the response future will be executed + on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking apply (see Note in + :meth:`.ResponseFuture.add_callbacks`. + See `this example `_ in the + source tree for an example. """ - if isinstance(query, basestring): - query = SimpleStatement(query) - elif isinstance(query, PreparedStatement): - query = query.bind(parameters) + self._request_init_callbacks.append((fn, args, kwargs)) - if isinstance(query, BoundStatement): - message = ExecuteMessage( - query_id=query.prepared_statement.query_id, - query_params=query.values, - consistency_level=query.consistency_level) - else: - query_string = query.query_string - if parameters: - query_string = bind_params(query.query_string, parameters) - message = QueryMessage(query=query_string, consistency_level=query.consistency_level) + def remove_request_init_listener(self, fn, *args, **kwargs): + """ + Removes a callback and arguments from the list. - if trace: - message.tracing = True + See :meth:`.Session.add_request_init_listener`. + """ + self._request_init_callbacks.remove((fn, args, kwargs)) - future = ResponseFuture(self, message, query, metrics=self._metrics) - future.send_request() - return future + def _on_request(self, response_future): + for fn, args, kwargs in self._request_init_callbacks: + fn(response_future, *args, **kwargs) - def prepare(self, query): + def prepare(self, query, custom_payload=None, keyspace=None): """ - Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement` + Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: >>> session = cluster.connect("mykeyspace") >>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)" >>> prepared = session.prepare(query) - >>> session.execute(prepared.bind((user.id, user.name, user.age))) + >>> session.execute(prepared, (user.id, user.name, user.age)) + Or you may bind values to the prepared statement ahead of time:: + + >>> prepared = session.prepare(query) + >>> bound_stmt = prepared.bind((user.id, user.name, user.age)) + >>> session.execute(bound_stmt) + + Of course, prepared statements may (and should) be reused:: + + >>> prepared = session.prepare(query) + >>> for user in users: + ... bound = prepared.bind((user.id, user.name, user.age)) + ... session.execute(bound) + + Alternatively, if :attr:`~.Cluster.protocol_version` is 5 or higher + (requires Cassandra 4.0+), the keyspace can be specified as a + parameter. This will allow you to avoid specifying the keyspace in the + query without specifying a keyspace in :meth:`~.Cluster.connect`. It + even will let you prepare and use statements against a keyspace other + than the one originally specified on connection: + + >>> analyticskeyspace_prepared = session.prepare( + ... "INSERT INTO user_activity id, last_activity VALUES (?, ?)", + ... keyspace="analyticskeyspace") # note the different keyspace + + **Important**: PreparedStatements should be prepared only once. + Preparing the same query more than once will likely affect performance. + + `custom_payload` is a key value map to be passed along with the prepare + message. See :ref:`custom_payload`. """ - message = PrepareMessage(query=query) - future = ResponseFuture(self, message, query=None) + message = PrepareMessage(query=query, keyspace=keyspace) + future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() - query_id, column_metadata = future.result() + response = future.result().one() except Exception: log.exception("Error preparing query:") raise + prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( - query_id, column_metadata, self.cluster.metadata, query, self.keyspace) + response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, + self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy) + prepared_statement.custom_payload = future.custom_payload - host = future._current_host - try: - self.cluster.prepare_on_all_sessions(query_id, prepared_statement, host) - except Exception: - log.exception("Error preparing query on all hosts:") + self.cluster.add_prepared(response.query_id, prepared_statement) + + if self.cluster.prepare_on_all_hosts: + host = future._current_host + try: + self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) + except Exception: + log.exception("Error preparing query on all hosts:") return prepared_statement - def prepare_on_all_hosts(self, query, excluded_host): + def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. """ - for host, pool in self._pools.items(): - if host != excluded_host: - future = ResponseFuture(self, PrepareMessage(query=query), None) + futures = [] + for host in tuple(self._pools.keys()): + if host != excluded_host and host.is_up: + future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), + None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared @@ -727,18 +3257,22 @@ def prepare_on_all_hosts(self, query, excluded_host): try: request_id = future._query(host) except Exception: - log.exception("Error preparing query for host %s:" % (host,)) + log.exception("Error preparing query for host %s:", host) continue if request_id is None: - # the error has already been logged by ResponsFuture - log.debug("Failed to prepare query for host %s" % (host,)) + # the error has already been logged by ResponseFuture + log.debug("Failed to prepare query for host %s: %r", + host, future._errors.get(host)) continue - try: - future.result() - except Exception: - log.exception("Error preparing query for host %s:" % (host,)) + futures.append((host, future)) + + for host, future in futures: + try: + future.result() + except Exception: + log.exception("Error preparing query for host %s:", host) def shutdown(self): """ @@ -751,92 +3285,233 @@ def shutdown(self): else: self.is_shutdown = True - for pool in self._pools.values(): + # PYTHON-673. If shutdown was called shortly after session init, avoid + # a race by cancelling any initial connection attempts haven't started, + # then blocking on any that have. + for future in self._initial_connect_futures: + future.cancel() + wait_futures(self._initial_connect_futures) + + if self._monitor_reporter: + self._monitor_reporter.stop() + + for pool in tuple(self._pools.values()): pool.shutdown() + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown() + def __del__(self): try: + # Ensure all connections are closed, in case the Session object is deleted by the GC self.shutdown() - del self.cluster - except TypeError: + except: + # Ignore all errors. Shutdown errors can be caught by the user + # when cluster.shutdown() is called explicitly. pass - def add_host(self, host): - """ Internal """ - distance = self._load_balancer.distance(host) + def add_or_renew_pool(self, host, is_host_addition): + """ + For internal use only. + """ + distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: - return self._pools.get(host) - else: + return None + + def run_add_or_renew_pool(): try: - new_pool = HostConnectionPool(host, distance, self) + if self._protocol_version >= 3: + new_pool = HostConnection(host, distance, self) + else: + # TODO remove host pool again ??? + new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: - conn_exc = ConnectionException(str(auth_exc), host=host) - host.monitor.signal_connection_failure(conn_exc) - return self._pools.get(host) + conn_exc = ConnectionException(str(auth_exc), endpoint=host) + self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + return False except Exception as conn_exc: - host.monitor.signal_connection_failure(conn_exc) - return self._pools.get(host) + log.warning("Failed to create connection pool for new host %s:", + host, exc_info=conn_exc) + # the host itself will still be marked down, so we need to pass + # a special flag to make sure the reconnector is created + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, expect_host_to_be_down=True) + return False previous = self._pools.get(host) - self._pools[host] = new_pool - return previous + with self._lock: + while new_pool._keyspace != self.keyspace: + self._lock.release() + set_keyspace_event = Event() + errors_returned = [] + + def callback(pool, errors): + errors_returned.extend(errors) + set_keyspace_event.set() + + new_pool._set_keyspace_for_all_conns(self.keyspace, callback) + set_keyspace_event.wait(self.cluster.connect_timeout) + if not set_keyspace_event.is_set() or errors_returned: + log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) + self.cluster.on_down(host, is_host_addition) + new_pool.shutdown() + self._lock.acquire() + return False + self._lock.acquire() + self._pools[host] = new_pool + + log.debug("Added pool for host %s to session", host) + if previous: + previous.shutdown() - def on_up(self, host): - """ - Called by the parent Cluster instance when a host's :class:`HealthMonitor` - marks it up. Only intended for internal use. - """ - previous_pool = self.add_host(host) - self._load_balancer.on_up(host) - if previous_pool: - previous_pool.shutdown() + return True - def on_down(self, host): - """ - Called by the parent Cluster instance when a host's :class:`HealthMonitor` - marks it down. Only intended for internal use. - """ - self._load_balancer.on_down(host) + return self.submit(run_add_or_renew_pool) + + def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: - pool.shutdown() + log.debug("Removed connection pool for %r", host) + return self.submit(pool.shutdown) + else: + return None - for host in self.cluster.metadata.all_hosts(): - if not host.monitor.is_up: - continue + def update_created_pools(self): + """ + When the set of live nodes change, the loadbalancer will change its + mind on host distances. It might change it on the node that came/left + but also on other nodes (for instance, if a node dies, another + previously ignored node may be now considered). - distance = self._load_balancer.distance(host) - if distance != HostDistance.IGNORED: - pool = self._pools.get(host) - if not pool: - self.add_host(host) + This method ensures that all hosts for which a pool should exist + have one, and hosts that shouldn't don't. + + For internal use only. + """ + futures = set() + for host in self.cluster.metadata.all_hosts(): + distance = self._profile_manager.distance(host) + pool = self._pools.get(host) + future = None + if not pool or pool.is_shutdown: + # we don't eagerly set is_up on previously ignored hosts. None is included here + # to allow us to attempt connections to hosts that have gone from ignored to something + # else. + if distance != HostDistance.IGNORED and host.is_up in (True, None): + future = self.add_or_renew_pool(host, False) + elif distance != pool.host_distance: + # the distance has changed + if distance == HostDistance.IGNORED: + future = self.remove_pool(host) else: pool.host_distance = distance + if future: + futures.add(future) + return futures - def on_add(self, host): - """ Internal """ - previous_pool = self.add_host(host) - self._load_balancer.on_add(host) - if previous_pool: - previous_pool.shutdown() + def on_down(self, host): + """ + Called by the parent Cluster instance when a node is marked down. + Only intended for internal use. + """ + future = self.remove_pool(host) + if future: + future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): """ Internal """ - self._load_balancer.on_remove(host) - pool = self._pools.pop(host) - if pool: - pool.shutdown() + self.on_down(host) def set_keyspace(self, keyspace): """ Set the default keyspace for all queries made through this Session. This operation blocks until complete. """ - self.execute('USE "%s"' % (keyspace,)) + self.execute('USE %s' % (protect_name(keyspace),)) + + def _set_keyspace_for_all_pools(self, keyspace, callback): + """ + Asynchronously sets the keyspace on all pools. When all + pools have set all of their connections, `callback` will be + called with a dictionary of all errors that occurred, keyed + by the `Host` that they occurred against. + """ + with self._lock: + self.keyspace = keyspace + remaining_callbacks = set(self._pools.values()) + errors = {} + + if not remaining_callbacks: + callback(errors) + return + + def pool_finished_setting_keyspace(pool, host_errors): + remaining_callbacks.remove(pool) + if host_errors: + errors[pool.host] = host_errors + + if not remaining_callbacks: + callback(host_errors) + + for pool in tuple(self._pools.values()): + pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) + + def user_type_registered(self, keyspace, user_type, klass): + """ + Called by the parent Cluster instance when the user registers a new + mapping from a user-defined type to a class. Intended for internal + use only. + """ + try: + ks_meta = self.cluster.metadata.keyspaces[keyspace] + except KeyError: + raise UserTypeDoesNotExist( + 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) + + try: + type_meta = ks_meta.user_types[user_type] + except KeyError: + raise UserTypeDoesNotExist( + 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) + + field_names = type_meta.field_names + + def encode(val): + return '{ %s }' % ' , '.join('%s : %s' % ( + field_name, + self.encoder.cql_encode_all_types(getattr(val, field_name, None)) + ) for field_name in field_names) + + self.encoder.mapping[klass] = encode def submit(self, fn, *args, **kwargs): """ Internal """ - return self.cluster.executor.submit(fn, *args, **kwargs) + if not self.is_shutdown: + return self.cluster.executor.submit(fn, *args, **kwargs) + + def get_pool_state(self): + return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + + def get_pools(self): + return self._pools.values() + + def _validate_set_legacy_config(self, attr_name, value): + if self.cluster._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,)) + setattr(self, '_' + attr_name, value) + self.cluster._config_mode = _ConfigMode.LEGACY + + +class UserTypeDoesNotExist(Exception): + """ + An attempt was made to use a user-defined type that does not exist. + + .. versionadded:: 2.1.0 + """ + pass class _ControlReconnectionHandler(_ReconnectionHandler): @@ -849,7 +3524,6 @@ def __init__(self, control_connection, *args, **kwargs): self.control_connection = weakref.proxy(control_connection) def try_reconnect(self): - # we'll either get back a new Connection or a NoHostAvailable return self.control_connection._reconnect_internal() def on_reconnection(self, connection): @@ -860,36 +3534,91 @@ def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: - log.debug("Error trying to reconnect control connection: %r" % (exc,)) + log.debug("Error trying to reconnect control connection: %r", exc) return True +def _watch_callback(obj_weakref, method_name, *args, **kwargs): + """ + A callback handler for the ControlConnection that tolerates + weak references. + """ + obj = obj_weakref() + if obj is None: + return + getattr(obj, method_name)(*args, **kwargs) + + +def _clear_watcher(conn, expiring_weakref): + """ + Called when the ControlConnection object is about to be finalized. + This clears watchers on the underlying Connection object. + """ + try: + conn.control_conn_disposed() + except ReferenceError: + pass + + class ControlConnection(object): """ Internal """ - _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" - _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" - _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" + _SELECT_PEERS = "SELECT * FROM system.peers" + _SELECT_PEERS_NO_TOKENS_TEMPLATE = "SELECT host_id, peer, data_center, rack, rpc_address, {nt_col_name}, release_version, schema_version FROM system.peers" + _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" + _SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + # Used only when token_metadata_enabled is set to False + _SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'" - _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address FROM system.peers" - _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner FROM system.local WHERE key='local'" - - _SELECT_SCHEMA_PEERS = "SELECT rpc_address, schema_version FROM system.peers" + _SELECT_SCHEMA_PEERS_TEMPLATE = "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" + _SELECT_PEERS_V2 = "SELECT * FROM system.peers_v2" + _SELECT_PEERS_NO_TOKENS_V2 = "SELECT host_id, peer, peer_port, data_center, rack, native_address, native_port, release_version, schema_version FROM system.peers_v2" + _SELECT_SCHEMA_PEERS_V2 = "SELECT host_id, peer, peer_port, native_address, native_port, schema_version FROM system.peers_v2" + + _MINIMUM_NATIVE_ADDRESS_DSE_VERSION = Version("6.0.0") + + class PeersQueryType(object): + """internal Enum for _peers_query""" + PEERS = 0 + PEERS_SCHEMA = 1 + + _is_shutdown = False + _timeout = None + _protocol_version = None + + _schema_event_refresh_window = None + _topology_event_refresh_window = None + _status_event_refresh_window = None + + _schema_meta_enabled = True + _token_meta_enabled = True + + _uses_peers_v2 = True + # for testing purposes _time = time - def __init__(self, cluster): + def __init__(self, cluster, timeout, + schema_event_refresh_window, + topology_event_refresh_window, + status_event_refresh_window, + schema_meta_enabled=True, + token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) - self._balancing_policy = cluster.load_balancing_policy - self._balancing_policy.populate(cluster, []) - self._reconnection_policy = cluster.reconnection_policy self._connection = None + self._timeout = timeout + + self._schema_event_refresh_window = schema_event_refresh_window + self._topology_event_refresh_window = topology_event_refresh_window + self._status_event_refresh_window = status_event_refresh_window + self._schema_meta_enabled = schema_meta_enabled + self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() @@ -897,14 +3626,17 @@ def __init__(self, cluster): self._reconnection_handler = None self._reconnection_lock = RLock() - self._is_shutdown = False + self._event_schedule_times = {} def connect(self): if self._is_shutdown: return + self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) + self._cluster.metadata.dbaas = self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE + def _set_new_connection(self, conn): """ Replace existing connection (if there is one) and close it. @@ -914,6 +3646,7 @@ def _set_new_connection(self, conn): self._connection = conn if old: + log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() def _reconnect_internal(self): @@ -926,16 +3659,24 @@ def _reconnect_internal(self): a connection to that host. """ errors = {} - for host in self._balancing_policy.make_query_plan(): + lbp = ( + self._cluster.load_balancing_policy + if self._cluster._config_mode == _ConfigMode.LEGACY else + self._cluster._default_load_balancing_policy + ) + + for host in lbp.make_query_plan(): try: return self._try_connect(host) except ConnectionException as exc: - errors[host.address] = exc - host.monitor.signal_connection_failure(exc) - log.warn("[control connection] Error connecting to %s:", host, exc_info=True) + errors[str(host.endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: - errors[host.address] = exc - log.warn("[control connection] Error connecting to %s:", host, exc_info=True) + errors[str(host.endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + if self._is_shutdown: + raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) @@ -945,20 +3686,62 @@ def _try_connect(self, host): node/token and schema metadata. """ log.debug("[control connection] Opening new connection to %s", host) - connection = self._cluster.connection_factory(host.address) - log.debug("[control connection] Established new connection to %s, " + while True: + try: + connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True) + if self._is_shutdown: + connection.close() + raise DriverException("Reconnecting during shutdown") + break + except ProtocolVersionUnsupported as e: + self._cluster.protocol_downgrade(host.endpoint, e.startup_version) + except ProtocolException as e: + # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver + # protocol version. If the protocol version was not explicitly specified, + # and that the server raises a beta protocol error, we should downgrade. + if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: + self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) + else: + raise + + log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", - host) + connection) + + # use weak references in both directions + # _clear_watcher will be called when this ControlConnection is about to be finalized + # _watch_callback will get the actual callback from the Connection and relay it to + # this object (after a dereferencing a weakref) + self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: connection.register_watchers({ - "TOPOLOGY_CHANGE": self._handle_topology_change, - "STATUS_CHANGE": self._handle_status_change, - "SCHEMA_CHANGE": self._handle_schema_change - }) - - self._refresh_node_list_and_token_map(connection) - self._refresh_schema(connection) + "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), + "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), + "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') + }, register_timeout=self._timeout) + + sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) + sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS + peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) + local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) + (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout, fail_on_error=False) + + if not local_success: + raise local_result + + if not peers_success: + # error with the peers v2 query, fallback to peers v1 + self._uses_peers_v2 = False + sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) + peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) + peers_result = connection.wait_for_response( + peers_query, timeout=self._timeout) + + shared_results = (peers_result, local_result) + self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) + self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: connection.close() raise @@ -969,7 +3752,7 @@ def reconnect(self): if self._is_shutdown: return - self._cluster.executor.submit(self._reconnect) + self._submit(self._reconnect) def _reconnect(self): log.debug("[control connection] Attempting to reconnect") @@ -977,7 +3760,7 @@ def _reconnect(self): self._set_new_connection(self._reconnect_internal()) except NoHostAvailable: # make a retry schedule (which includes backoff) - schedule = self._reconnection_policy.new_schedule() + schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: @@ -1008,262 +3791,499 @@ def _get_and_set_reconnection_handler(self, new_handler): self._reconnection_handler = new_handler return old + def _submit(self, *args, **kwargs): + try: + if not self._cluster.is_shutdown: + return self._cluster.executor.submit(*args, **kwargs) + except ReferenceError: + pass + return None + def shutdown(self): + # stop trying to reconnect (if we are) + with self._reconnection_lock: + if self._reconnection_handler: + self._reconnection_handler.cancel() + with self._lock: if self._is_shutdown: return else: self._is_shutdown = True - # stop trying to reconnect (if we are) - if self._reconnection_handler: - self._reconnection_handler.cancel() - - if self._connection: - self._connection.close() + log.debug("Shutting down control connection") + if self._connection: + self._connection.close() + self._connection = None - def refresh_schema(self, keyspace=None, table=None): + def refresh_schema(self, force=False, **kwargs): try: if self._connection: - self._refresh_schema(self._connection, keyspace, table) + return self._refresh_schema(self._connection, force=force, **kwargs) + except ReferenceError: + pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing schema", exc_info=True) self._signal_error() + return False + + def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): + if self._cluster.is_shutdown: + return False - def _refresh_schema(self, connection, keyspace=None, table=None): - self.wait_for_schema_agreement(connection) + agreed = self.wait_for_schema_agreement(connection, + preloaded_results=preloaded_results, + wait_time=schema_agreement_wait) - where_clause = "" - if keyspace: - where_clause = " WHERE keyspace_name = '%s'" % (keyspace,) - if table: - where_clause += " AND columnfamily_name = '%s'" % (table,) + if not self._schema_meta_enabled and not force: + log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") + return False - cl = ConsistencyLevel.ONE - if table: - ks_query = None - else: - ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl) - cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) - col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) - - if ks_query: - ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query) - ks_result = dict_factory(*ks_result.results) - cf_result = dict_factory(*cf_result.results) - col_result = dict_factory(*col_result.results) - else: - ks_result = None - cf_result, col_result = connection.wait_for_responses(cf_query, col_query) - cf_result = dict_factory(*cf_result.results) - col_result = dict_factory(*col_result.results) + if not agreed: + log.debug("Skipping schema refresh due to lack of schema agreement") + return False + + self._cluster.metadata.refresh(connection, self._timeout, **kwargs) - self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result) + return True - def refresh_node_list_and_token_map(self): + def refresh_node_list_and_token_map(self, force_token_rebuild=False): try: if self._connection: - self._refresh_node_list_and_token_map(self._connection) + self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) + return True + except ReferenceError: + pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing node list and token map", exc_info=True) self._signal_error() + return False + + def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, + force_token_rebuild=False): + if preloaded_results: + log.debug("[control connection] Refreshing node list and token map using preloaded results") + peers_result = preloaded_results[0] + local_result = preloaded_results[1] + else: + cl = ConsistencyLevel.ONE + sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) + if not self._token_meta_enabled: + log.debug("[control connection] Refreshing node list without token map") + sel_local = self._SELECT_LOCAL_NO_TOKENS + else: + log.debug("[control connection] Refreshing node list and token map") + sel_local = self._SELECT_LOCAL + peers_query = QueryMessage(query=sel_peers, consistency_level=cl) + local_query = QueryMessage(query=sel_local, consistency_level=cl) + peers_result, local_result = connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout) - def _refresh_node_list_and_token_map(self, connection): - log.debug("[control connection] Refreshing node list and token map") - cl = ConsistencyLevel.ONE - peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl) - peers_result, local_result = connection.wait_for_responses(peers_query, local_query) - peers_result = dict_factory(*peers_result.results) - log.debug("[control connection] Got system table results to refresh node list and token map") + peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) partitioner = None token_map = {} - if local_result.results: - local_rows = dict_factory(*(local_result.results)) + found_hosts = set() + if local_result.parsed_rows: + found_hosts.add(connection.endpoint) + local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name - host = self._cluster.metadata.get_host(connection.host) - if host: - host.set_location_info(local_row["data_center"], local_row["rack"]) - partitioner = local_row.get("partitioner") tokens = local_row.get("tokens") - if partitioner and tokens: - token_map[host] = tokens - found_hosts = set() + host = self._cluster.metadata.get_host(connection.endpoint) + if host: + datacenter = local_row.get("data_center") + rack = local_row.get("rack") + self._update_location_info(host, datacenter, rack) + host.host_id = local_row.get("host_id") + host.listen_address = local_row.get("listen_address") + host.listen_port = local_row.get("listen_port") + host.broadcast_address = _NodeInfo.get_broadcast_address(local_row) + host.broadcast_port = _NodeInfo.get_broadcast_port(local_row) + + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row) + if host.broadcast_rpc_address is None: + if self._token_meta_enabled: + # local rpc_address is not available, use the connection endpoint + host.broadcast_rpc_address = connection.endpoint.address + host.broadcast_rpc_port = connection.endpoint.port + else: + # local rpc_address has not been queried yet, try to fetch it + # separately, which might fail because C* < 2.1.6 doesn't have rpc_address + # in system.local. See CASSANDRA-9436. + local_rpc_address_query = QueryMessage(query=self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, + consistency_level=ConsistencyLevel.ONE) + success, local_rpc_address_result = connection.wait_for_response( + local_rpc_address_query, timeout=self._timeout, fail_on_error=False) + if success: + row = dict_factory( + local_rpc_address_result.column_names, + local_rpc_address_result.parsed_rows) + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0]) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0]) + else: + host.broadcast_rpc_address = connection.endpoint.address + host.broadcast_rpc_port = connection.endpoint.port + + host.release_version = local_row.get("release_version") + host.dse_version = local_row.get("dse_version") + host.dse_workload = local_row.get("workload") + host.dse_workloads = local_row.get("workloads") + + if partitioner and tokens: + token_map[host] = tokens + + # Check metadata.partitioner to see if we haven't built anything yet. If + # every node in the cluster was in the contact points, we won't discover + # any new nodes, so we need this additional check. (See PYTHON-90) + should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None for row in peers_result: - addr = row.get("rpc_address") + if not self._is_valid_peer(row): + log.warning( + "Found an invalid row for peer (%s). Ignoring host." % + _NodeInfo.get_broadcast_rpc_address(row)) + continue + + endpoint = self._cluster.endpoint_factory.create(row) - # TODO handle ipv6 equivalent - if not addr or addr == "0.0.0.0": - addr = row.get("peer") + if endpoint in found_hosts: + log.warning("Found multiple hosts with the same endpoint (%s). Excluding peer %s", endpoint, row.get("peer")) + continue - found_hosts.add(addr) + found_hosts.add(endpoint) - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(endpoint) + datacenter = row.get("data_center") + rack = row.get("rack") if host is None: - log.debug("[control connection] Found new host to connect to: %s" % (addr,)) - host = self._cluster.add_host(addr, signal=True) - host.set_location_info(row.get("data_center"), row.get("rack")) - - tokens = row.get("tokens") - if partitioner and tokens: + log.debug("[control connection] Found new host to connect to: %s", endpoint) + host, _ = self._cluster.add_host(endpoint, datacenter, rack, signal=True, refresh_nodes=False) + should_rebuild_token_map = True + else: + should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + + host.host_id = row.get("host_id") + host.broadcast_address = _NodeInfo.get_broadcast_address(row) + host.broadcast_port = _NodeInfo.get_broadcast_port(row) + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row) + host.release_version = row.get("release_version") + host.dse_version = row.get("dse_version") + host.dse_workload = row.get("workload") + host.dse_workloads = row.get("workloads") + + tokens = row.get("tokens", None) + if partitioner and tokens and self._token_meta_enabled: token_map[host] = tokens for old_host in self._cluster.metadata.all_hosts(): - if old_host.address != connection.host and \ - old_host.address not in found_hosts: + if old_host.endpoint.address != connection.endpoint and old_host.endpoint not in found_hosts: + should_rebuild_token_map = True + log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) self._cluster.remove_host(old_host) - if partitioner: + log.debug("[control connection] Finished fetching ring info") + if partitioner and should_rebuild_token_map: + log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) + @staticmethod + def _is_valid_peer(row): + return bool(_NodeInfo.get_broadcast_rpc_address(row) and row.get("host_id") and + row.get("data_center") and row.get("rack") and + ('tokens' not in row or row.get('tokens'))) + + def _update_location_info(self, host, datacenter, rack): + if host.datacenter == datacenter and host.rack == rack: + return False + + # If the dc/rack information changes, we need to update the load balancing policy. + # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes + # that the policy will update correctly, but in practice this should work. + self._cluster.profile_manager.on_down(host) + host.set_location_info(datacenter, rack) + self._cluster.profile_manager.on_up(host) + return True + + def _delay_for_event_type(self, event_type, delay_window): + # this serves to order processing correlated events (received within the window) + # the window and randomization still have the desired effect of skew across client instances + next_time = self._event_schedule_times.get(event_type, 0) + now = self._time.time() + if now <= next_time: + this_time = next_time + 0.01 + delay = this_time - now + else: + delay = random() * delay_window + this_time = now + delay + self._event_schedule_times[event_type] = this_time + return delay + + def _refresh_nodes_if_not_up(self, host): + """ + Used to mitigate refreshes for nodes that are already known. + Some versions of the server send superfluous NEW_NODE messages in addition to UP events. + """ + if not host or not host.is_up: + self.refresh_node_list_and_token_map() + def _handle_topology_change(self, event): change_type = event["change_type"] addr, port = event["address"] - if change_type == "NEW_NODE": - self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True) + host = self._cluster.metadata.get_host(addr, port) + if change_type == "NEW_NODE" or change_type == "MOVED_NODE": + if self._topology_event_refresh_window >= 0: + delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) + self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, host) elif change_type == "REMOVED_NODE": - host = self._cluster.metadata.get_host(addr) - self._cluster.scheduler.schedule(0, self._cluster.remove_host, host) - elif change_type == "MOVED_NODE": - self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map) + self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) def _handle_status_change(self, event): change_type = event["change_type"] addr, port = event["address"] - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(addr, port) if change_type == "UP": + delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node - self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True) + self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: - self._cluster.scheduler.schedule(1, host.monitor.set_up) + self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. - # But it is unlikely, and don't have too much consequence since we'll try reconnecting + # This is unlikely, and will not have much consequence because we'll try reconnecting # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: - self._cluster.scheduler.schedule(1, host.monitor.set_down) + # this will be run by the scheduler + self._cluster.on_down(host, is_host_addition=False) def _handle_schema_change(self, event): - keyspace = event['keyspace'] or None - table = event['table'] or None - if event['change_type'] in ("CREATED", "DROPPED"): - keyspace = keyspace if table else None - self._cluster.executor.submit(self.refresh_schema, keyspace) - elif event['change_type'] == "UPDATED": - self._cluster.executor.submit(self.refresh_schema, keyspace, table) - - def wait_for_schema_agreement(self, connection=None): + if self._schema_event_refresh_window < 0: + return + delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window) + self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) + + def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): + + total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait + if total_timeout <= 0: + return True + # Each schema change typically generates two schema refreshes, one # from the response type and one from the pushed notification. Holding # a lock is just a simple way to cut down on the number of schema queries # we'll make. with self._schema_agreement_lock: - log.debug("[control connection] Waiting for schema agreement") + if self._is_shutdown: + return + if not connection: connection = self._connection - start = self._time.time() - elapsed = 0 - cl = ConsistencyLevel.ONE - while elapsed < self._cluster.max_schema_agreement_wait: - peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) - peers_result, local_result = connection.wait_for_responses(peers_query, local_query) - peers_result = dict_factory(*peers_result.results) + if preloaded_results: + log.debug("[control connection] Attempting to use preloaded results for schema agreement") + + peers_result = preloaded_results[0] + local_result = preloaded_results[1] + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + if schema_mismatches is None: + return True + + log.debug("[control connection] Waiting for schema agreement") + start = self._time.time() + elapsed = 0 + cl = ConsistencyLevel.ONE + schema_mismatches = None + select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection) + + while elapsed < total_timeout: + peers_query = QueryMessage(query=select_peers_query, consistency_level=cl) + local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) + try: + timeout = min(self._timeout, total_timeout - elapsed) + peers_result, local_result = connection.wait_for_responses( + peers_query, local_query, timeout=timeout) + except OperationTimedOut as timeout: + log.debug("[control connection] Timed out waiting for " + "response during schema agreement check: %s", timeout) + elapsed = self._time.time() - start + continue + except ConnectionShutdown: + if self._is_shutdown: + log.debug("[control connection] Aborting wait for schema match due to shutdown") + return None + else: + raise + + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + if schema_mismatches is None: + return True + + log.debug("[control connection] Schemas mismatched, trying again") + self._time.sleep(0.2) + elapsed = self._time.time() - start + + log.warning("Node %s is reporting a schema disagreement: %s", + connection.endpoint, schema_mismatches) + return False + + def _get_schema_mismatches(self, peers_result, local_result, local_address): + peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) + + versions = defaultdict(set) + if local_result.parsed_rows: + local_row = dict_factory(local_result.column_names, local_result.parsed_rows)[0] + if local_row.get("schema_version"): + versions[local_row.get("schema_version")].add(local_address) + + for row in peers_result: + schema_ver = row.get('schema_version') + if not schema_ver: + continue + endpoint = self._cluster.endpoint_factory.create(row) + peer = self._cluster.metadata.get_host(endpoint) + if peer and peer.is_up is not False: + versions[schema_ver].add(endpoint) + + if len(versions) == 1: + log.debug("[control connection] Schemas match") + return None - versions = set() - if local_result.results: - local_row = dict_factory(*local_result.results)[0] - if local_row.get("schema_version"): - versions.add(local_row.get("schema_version")) + return dict((version, list(nodes)) for version, nodes in versions.items()) - for row in peers_result: - if not row.get("rpc_address") or not row.get("schema_version"): - continue + def _get_peers_query(self, peers_query_type, connection=None): + """ + Determine the peers query to use. - rpc = row.get("rpc_address") - if rpc == "0.0.0.0": # TODO ipv6 check - rpc = row.get("peer") + :param peers_query_type: Should be one of PeersQueryType enum. - peer = self._cluster.metadata.get_host(rpc) - if peer and peer.monitor.is_up: - versions.add(row.get("schema_version")) + If _uses_peers_v2 is True, return the proper peers_v2 query (no templating). + Else, apply the logic below to choose the peers v1 address column name: - if len(versions) == 1: - return True + Given a connection: - log.debug("[control connection] Schemas mismatched, trying again") - self._time.sleep(0.2) - elapsed = self._time.time() - start + - find the server product version running on the connection's host, + - use that to choose the column name for the transport address (see APOLLO-1130), and + - use that column name in the provided peers query template. + """ + if peers_query_type not in (self.PeersQueryType.PEERS, self.PeersQueryType.PEERS_SCHEMA): + raise ValueError("Invalid peers query type: %s" % peers_query_type) - return False + if self._uses_peers_v2: + if peers_query_type == self.PeersQueryType.PEERS: + query = self._SELECT_PEERS_V2 if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS_V2 + else: + query = self._SELECT_SCHEMA_PEERS_V2 + else: + if peers_query_type == self.PeersQueryType.PEERS and self._token_meta_enabled: + query = self._SELECT_PEERS + else: + query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE + if peers_query_type == self.PeersQueryType.PEERS_SCHEMA + else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) + + host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version + host_dse_version = self._cluster.metadata.get_host(connection.endpoint).dse_version + uses_native_address_query = ( + host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) + + if uses_native_address_query: + query = query_template.format(nt_col_name="native_transport_address") + elif host_release_version: + query = query_template.format(nt_col_name="rpc_address") + else: + query = self._SELECT_PEERS + + return query def _signal_error(self): - # try just signaling the host monitor, as this will trigger a reconnect - # as part of marking the host down - if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.host) - # host may be None if it's already been removed, but that indicates - # that errors have already been reported, so we're fine - if host: - host.monitor.signal_connection_failure(self._connection.last_error) + with self._lock: + if self._is_shutdown: return + # try just signaling the cluster, as this will trigger a reconnect + # as part of marking the host down + if self._connection and self._connection.is_defunct: + host = self._cluster.metadata.get_host(self._connection.endpoint) + # host may be None if it's already been removed, but that indicates + # that errors have already been reported, so we're fine + if host: + self._cluster.signal_connection_failure( + host, self._connection.last_error, is_host_addition=False) + return + # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() - @property - def is_open(self): - conn = self._connection - return bool(conn and conn.is_open) - def on_up(self, host): - log.debug("[control connection] Host %s is considered up" % (host,)) - self._balancing_policy.on_up(host) + pass def on_down(self, host): - log.debug("[control connection] Host %s is considered down" % (host,)) - self._balancing_policy.on_down(host) conn = self._connection - if conn and conn.host == host.address and \ + if conn and conn.endpoint == host.endpoint and \ self._reconnection_handler is None: + log.debug("[control connection] Control connection host (%s) is " + "considered down, starting reconnection", host) + # this will result in a task being submitted to the executor to reconnect self.reconnect() - def on_add(self, host): - log.debug("[control connection] Adding host %r and refreshing topology" % (host,)) - self._balancing_policy.on_add(host) - self.refresh_node_list_and_token_map() + def on_add(self, host, refresh_nodes=True): + if refresh_nodes: + self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): - log.debug("[control connection] Removing host %r and refreshing topology" % (host,)) - self._balancing_policy.on_remove(host) - self.refresh_node_list_and_token_map() + c = self._connection + if c and c.endpoint == host.endpoint: + log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) + # refresh will be done on reconnect + self.reconnect() + else: + self.refresh_node_list_and_token_map(force_token_rebuild=True) + + def get_connections(self): + c = getattr(self, '_connection', None) + return [c] if c else [] + + def return_connection(self, connection): + if connection is self._connection and (connection.is_defunct or connection.is_closed): + self.reconnect() + +def _stop_scheduler(scheduler, thread): + try: + if not scheduler.is_shutdown: + scheduler.shutdown() + except ReferenceError: + pass + + thread.join() -class _Scheduler(object): - _scheduled = None +class _Scheduler(Thread): + + _queue = None + _scheduled_tasks = None _executor = None is_shutdown = False def __init__(self, executor): - self._scheduled = Queue.PriorityQueue() + self._queue = queue.PriorityQueue() + self._scheduled_tasks = set() + self._count = count() self._executor = executor - t = Thread(target=self.run, name="Task Scheduler") - t.daemon = True - t.start() + Thread.__init__(self, name="Task Scheduler") + self.daemon = True + self.start() def shutdown(self): try: @@ -1272,13 +4292,26 @@ def shutdown(self): # this can happen on interpreter shutdown pass self.is_shutdown = True + self._queue.put_nowait((0, 0, None)) + self.join() def schedule(self, delay, fn, *args, **kwargs): + self._insert_task(delay, (fn, args, tuple(kwargs.items()))) + + def schedule_unique(self, delay, fn, *args, **kwargs): + task = (fn, args, tuple(kwargs.items())) + if task not in self._scheduled_tasks: + self._insert_task(delay, task) + else: + log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + + def _insert_task(self, delay, task): if not self.is_shutdown: run_at = time.time() + delay - self._scheduled.put_nowait((run_at, (fn, args, kwargs))) + self._scheduled_tasks.add(task) + self._queue.put_nowait((run_at, next(self._count), task)) else: - log.debug("Ignoring scheduled function after shutdown: %r" % fn) + log.debug("Ignoring scheduled task after shutdown: %r", task) def run(self): while True: @@ -1287,34 +4320,45 @@ def run(self): try: while True: - run_at, task = self._scheduled.get(block=True, timeout=None) + run_at, i, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: - log.debug("Not executing scheduled task due to Scheduler shutdown") + if task: + log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): + self._scheduled_tasks.discard(task) fn, args, kwargs = task - self._executor.submit(fn, *args, **kwargs) + kwargs = dict(kwargs) + future = self._executor.submit(fn, *args, **kwargs) + future.add_done_callback(self._log_if_failed) else: - self._scheduled.put_nowait((run_at, task)) + self._queue.put_nowait((run_at, i, task)) break - except Queue.Empty: + except queue.Empty: pass time.sleep(0.1) + def _log_if_failed(self, future): + exc = future.exception() + if exc: + log.warning( + "An internally scheduled tasked failed with an unhandled exception:", + exc_info=exc) -def refresh_schema_and_set_result(keyspace, table, control_conn, response_future): + +def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: - control_conn.refresh_schema(keyspace, table) + log.debug("Refreshing schema in response to schema change. " + "%s", kwargs) + response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") + response_future.session.submit(control_conn.refresh_schema, **kwargs) finally: response_future._set_final_result(None) -_NO_RESULT_YET = object() - - class ResponseFuture(object): """ An asynchronous response delivery mechanism that is returned from calls @@ -1326,49 +4370,204 @@ class ResponseFuture(object): :meth:`.add_callback()`, :meth:`.add_errback()`, and :meth:`.add_callbacks()`. """ + + query = None + """ + The :class:`~.Statement` instance that is being executed through this + :class:`.ResponseFuture`. + """ + + is_schema_agreed = True + """ + For DDL requests, this may be set ``False`` if the schema agreement poll after the response fails. + + Always ``True`` for non-DDL requests. + """ + + request_encoded_size = None + """ + Size of the request message sent + """ + + coordinator_host = None + """ + The host from which we received a response + """ + + attempted_hosts = None + """ + A list of hosts tried, including all speculative executions, retries, and pages + """ + session = None row_factory = None message = None - query = None + default_timeout = None + + _retry_policy = None + _profile_manager = None _req_id = None - _final_result = _NO_RESULT_YET + _final_result = _NOT_SET + _col_names = None + _col_types = None _final_exception = None - _query_trace = None - _callback = None - _errback = None + _query_traces = None + _callbacks = None + _errbacks = None _current_host = None - _current_pool = None _connection = None _query_retries = 0 _start_time = None _metrics = None - - def __init__(self, session, message, query, metrics=None): + _paging_state = None + _custom_payload = None + _warnings = None + _timer = None + _protocol_handler = ProtocolHandler + _spec_execution_plan = NoSpeculativeExecutionPlan() + _continuous_paging_options = None + _continuous_paging_session = None + _host = None + + _warned_timeout = False + + def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, + retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, + speculative_execution_plan=None, continuous_paging_state=None, host=None): self.session = session - self.row_factory = session.row_factory + # TODO: normalize handling of retry policy and row factory + self.row_factory = row_factory or session.row_factory + self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy self.message = message self.query = query + self.timeout = timeout + self._retry_policy = retry_policy self._metrics = metrics - if metrics is not None: - self._start_time = time.time() - - # convert the list/generator/etc to an iterator so that subsequent - # calls to send_request (which retries may do) will resume where - # they last left off - self.query_plan = iter(session._load_balancer.make_query_plan( - session.keyspace, query)) - + self.prepared_statement = prepared_statement + self._callback_lock = Lock() + self._start_time = start_time or time.time() + self._host = host + self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan + self._make_query_plan() self._event = Event() self._errors = {} + self._callbacks = [] + self._errbacks = [] + self.attempted_hosts = [] + self._start_timer() + self._continuous_paging_state = continuous_paging_state - def __del__(self): - try: - del self.session - except AttributeError: - pass + @property + def _time_remaining(self): + if self.timeout is None: + return None + return (self._start_time + self.timeout) - time.time() + + def _start_timer(self): + if self._timer is None: + spec_delay = self._spec_execution_plan.next_execution(self._current_host) + if spec_delay >= 0: + if self._time_remaining is None or self._time_remaining > spec_delay: + self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) + return + if self._time_remaining is not None: + self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) + + def _cancel_timer(self): + if self._timer: + self._timer.cancel() + + def _on_timeout(self, _attempts=0): + """ + Called when the request associated with this ResponseFuture times out. + + This function may reschedule itself. The ``_attempts`` parameter tracks + the number of times this has happened. This parameter should only be + set in those cases, where ``_on_timeout`` reschedules itself. + """ + # PYTHON-853: for short timeouts, we sometimes race with our __init__ + if self._connection is None and _attempts < 3: + self._timer = self.session.cluster.connection_class.create_timer( + 0.01, + partial(self._on_timeout, _attempts=_attempts + 1) + ) + return + + if self._connection is not None: + try: + self._connection._requests.pop(self._req_id) + # PYTHON-1044 + # This request might have been removed from the connection after the latter was defunct by heartbeat. + # We should still raise OperationTimedOut to reject the future so that the main event thread will not + # wait for it endlessly + except KeyError: + key = "Connection defunct by heartbeat" + errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + self._set_final_exception(OperationTimedOut(errors, self._current_host)) + return - def send_request(self): + pool = self.session._pools.get(self._current_host) + if pool and not pool.is_shutdown: + # Do not return the stream ID to the pool yet. We cannot reuse it + # because the node might still be processing the query and will + # return a late response to that query - if we used such stream + # before the response to the previous query has arrived, the new + # query could get a response from the old query + with self._connection.lock: + self._connection.orphaned_request_ids.add(self._req_id) + if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + self._connection.orphaned_threshold_reached = True + + pool.return_connection(self._connection, stream_was_orphaned=True) + + errors = self._errors + if not errors: + if self.is_schema_agreed: + key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' + errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + else: + connection = self.session.cluster.control_connection._connection + host = str(connection.endpoint) if connection else 'unknown' + errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} + + self._set_final_exception(OperationTimedOut(errors, self._current_host)) + + def _on_speculative_execute(self): + self._timer = None + if not self._event.is_set(): + + # PYTHON-836, the speculative queries must be after + # the query is sent from the main thread, otherwise the + # query from the main thread may raise NoHostAvailable + # if the _query_plan has been exhausted by the speculative queries. + # This also prevents a race condition accessing the iterator. + # We reschedule this call until the main thread has succeeded + # making a query + if not self.attempted_hosts: + self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute) + return + + if self._time_remaining is not None: + if self._time_remaining <= 0: + self._on_timeout() + return + self.send_request(error_no_hosts=False) + self._start_timer() + + def _make_query_plan(self): + # set the query_plan according to the load balancing policy, + # or to the explicit host target if set + if self._host: + # returning a single value effectively disables retries + self.query_plan = [self._host] + else: + # convert the list/generator/etc to an iterator so that subsequent + # calls to send_request (which retries may do) will resume where + # they last left off + self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) + + def send_request(self, error_no_hosts=True): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times @@ -1376,68 +4575,186 @@ def send_request(self): req_id = self._query(host) if req_id is not None: self._req_id = req_id - return - - self._final_exception = NoHostAvailable( - "Unable to complete the operation against any hosts", self._errors) + return True + if self.timeout is not None and time.time() - self._start_time > self.timeout: + self._on_timeout() + return True + if error_no_hosts: + self._set_final_exception(NoHostAvailable( + "Unable to complete the operation against any hosts", self._errors)) + return False + + def _query(self, host, message=None, cb=None): + if message is None: + message = self.message - def _query(self, host): pool = self.session._pools.get(host) - if not pool or pool.is_shutdown: + if not pool: + self._errors[host] = ConnectionException("Host has been marked down or removed") + return None + elif pool.is_shutdown: self._errors[host] = ConnectionException("Pool is shutdown") return None + self._current_host = host + connection = None try: # TODO get connectTimeout from cluster settings - connection = pool.borrow_connection(timeout=2.0) - request_id = connection.send_msg(self.message, cb=self._set_result) - self._current_host = host - self._current_pool = pool + connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + + if cb is None: + cb = partial(self._set_result, host, connection, pool) + + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) + self.attempted_hosts.append(host) return request_id + except NoConnectionsAvailable as exc: + log.debug("All connections for host %s are at capacity, moving to the next host", host) + self._errors[host] = exc + except ConnectionBusy as exc: + log.debug("Connection for host %s is busy, moving to the next host", host) + self._errors[host] = exc except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc + if self._metrics is not None: + self._metrics.on_connection_error() if connection: pool.return_connection(connection) - return None - def _set_result(self, response): + return None + + @property + def has_more_pages(self): + """ + Returns :const:`True` if there are more pages left in the + query results, :const:`False` otherwise. This should only + be checked after the first page has been returned. + + .. versionadded:: 2.0.0 + """ + return self._paging_state is not None + + @property + def warnings(self): + """ + Warnings returned from the server, if any. This will only be + set for protocol_version 4+. + + Warnings may be returned for such things as oversized batches, + or too many tombstones in slice queries. + + Ensure the future is complete before trying to access this property + (call :meth:`.result()`, or after callback is invoked). + Otherwise, it may throw if the response has not been received. + """ + # TODO: When timers are introduced, just make this wait + if not self._event.is_set(): + raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") + return self._warnings + + @property + def custom_payload(self): + """ + The custom payload returned from the server, if any. This will only be + set by Cassandra servers implementing a custom QueryHandler, and only + for protocol_version 4+. + + Ensure the future is complete before trying to access this property + (call :meth:`.result()`, or after callback is invoked). + Otherwise, it may throw if the response has not been received. + + :return: :ref:`custom_payload`. + """ + # TODO: When timers are introduced, just make this wait + if not self._event.is_set(): + raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") + return self._custom_payload + + def start_fetching_next_page(self): + """ + If there are more pages left in the query result, this asynchronously + starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted` + is raised. Also see :attr:`.has_more_pages`. + + This should only be called after the first page has been returned. + + .. versionadded:: 2.0.0 + """ + if not self._paging_state: + raise QueryExhausted() + + self._make_query_plan() + self.message.paging_state = self._paging_state + self._event.clear() + self._final_result = _NOT_SET + self._final_exception = None + self._start_timer() + self.send_request() + + def _reprepare(self, prepare_message, host, connection, pool): + cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) + request_id = self._query(host, prepare_message, cb=cb) + if request_id is None: + # try to submit the original prepared statement on some other host + self.send_request() + + def _set_result(self, host, connection, pool, response): try: - if self._current_pool and self._connection: - self._current_pool.return_connection(self._connection) + self.coordinator_host = host + if pool: + pool.return_connection(connection) trace_id = getattr(response, 'trace_id', None) if trace_id: - self._query_trace = QueryTrace(trace_id, self.session) + if not self._query_traces: + self._query_traces = [] + self._query_traces.append(QueryTrace(trace_id, self.session)) + + self._warnings = getattr(response, 'warnings', None) + self._custom_payload = getattr(response, 'custom_payload', None) if isinstance(response, ResultMessage): - if response.kind == ResultMessage.KIND_SET_KEYSPACE: + if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) + # since we're running on the event loop thread, we need to + # use a non-blocking method for setting the keyspace on + # all connections in this session, otherwise the event + # loop thread will deadlock waiting for keyspaces to be + # set. This uses a callback chain which ends with + # self._set_keyspace_completed() being called in the + # event loop thread. if session: - session.keyspace = response.results - self._set_final_result(None) - elif response.kind == ResultMessage.KIND_SCHEMA_CHANGE: + session._set_keyspace_for_all_pools( + response.new_keyspace, self._set_keyspace_completed) + elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread + self.is_schema_agreed = False self.session.submit( refresh_schema_and_set_result, - response.results['keyspace'], - response.results['table'], self.session.cluster.control_connection, - self) + self, connection, **response.schema_change_event) + elif response.kind == RESULT_KIND_ROWS: + self._paging_state = response.paging_state + self._col_names = response.column_names + self._col_types = response.column_types + if getattr(self.message, 'continuous_paging_options', None): + self._handle_continuous_paging_first_response(connection, response) + else: + self._set_final_result(self.row_factory(response.column_names, response.parsed_rows)) + elif response.kind == RESULT_KIND_VOID: + self._set_final_result(None) else: - results = getattr(response, 'results', None) - if results is not None and response.kind == ResultMessage.KIND_ROWS: - results = self.row_factory(*results) - self._set_final_result(results) + self._set_final_result(response) elif isinstance(response, ErrorMessage): - retry_policy = None - if self.query: - retry_policy = self.query.retry_policy - if not retry_policy: - retry_policy = self.session.cluster.default_retry_policy + retry_policy = self._retry_policy if isinstance(response, ReadTimeoutErrorMessage): if self._metrics is not None: @@ -1454,32 +4771,41 @@ def _set_result(self, response): self._metrics.on_unavailable() retry = retry_policy.on_unavailable( self.query, retry_num=self._query_retries, **response.info) - elif isinstance(response, OverloadedErrorMessage): - if self._metrics is not None: - self._metrics.on_other_error() - # need to retry against a different host here - log.warn("Host %s is overloaded, retrying against a different " - "host" % (self._current_host)) - self._retry(reuse_connection=False, consistency_level=None) - return - elif isinstance(response, IsBootstrappingErrorMessage): + elif isinstance(response, (OverloadedErrorMessage, + IsBootstrappingErrorMessage, + TruncateError, ServerError)): + log.warning("Host %s error: %s.", host, response.summary) if self._metrics is not None: self._metrics.on_other_error() - # need to retry against a different host here - self._retry(reuse_connection=False, consistency_level=None) - return + cl = getattr(self.message, 'consistency_level', None) + retry = retry_policy.on_request_error( + self.query, cl, error=response, + retry_num=self._query_retries) elif isinstance(response, PreparedQueryNotFound): - query_id = response.info + if self.prepared_statement: + query_id = self.prepared_statement.query_id + assert query_id == response.info, \ + "Got different query ID in server response (%s) than we " \ + "had before (%s)" % (response.info, query_id) + else: + query_id = response.info + try: prepared_statement = self.session.cluster._prepared_statements[query_id] except KeyError: - log.error("Tried to execute unknown prepared statement %s" % (query_id.encode('hex'),)) - self._set_final_exception(response) - return + if not self.prepared_statement: + log.error("Tried to execute unknown prepared statement: id=%s", + query_id.encode('hex')) + self._set_final_exception(response) + return + else: + prepared_statement = self.prepared_statement + self.session.cluster._prepared_statements[query_id] = prepared_statement current_keyspace = self._connection.keyspace prepared_keyspace = prepared_statement.keyspace - if current_keyspace != prepared_keyspace: + if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ + and prepared_keyspace and current_keyspace != prepared_keyspace: self._set_final_exception( ValueError("The Session's current keyspace (%s) does " "not match the keyspace the statement was " @@ -1487,12 +4813,15 @@ def _set_result(self, response): (current_keyspace, prepared_keyspace))) return - prepare_message = PrepareMessage(query=prepared_statement.query_string) + log.debug("Re-preparing unrecognized prepared statement against host %s: %s", + host, prepared_statement.query_string) + prepared_keyspace = prepared_statement.keyspace \ + if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None + prepare_message = PrepareMessage(query=prepared_statement.query_string, + keyspace=prepared_keyspace) # since this might block, run on the executor to avoid hanging # the event loop thread - self.session.submit(self._connection.send_msg, - prepare_message, - cb=self._execute_after_prepare) + self.session.submit(self._reprepare, prepare_message, host, connection, pool) return else: if hasattr(response, 'to_exception'): @@ -1501,22 +4830,16 @@ def _set_result(self, response): self._set_final_exception(response) return - retry_type, consistency = retry - if retry_type is RetryPolicy.RETRY: - self._query_retries += 1 - self._retry(reuse_connection=True, consistency_level=consistency) - elif retry_type is RetryPolicy.RETHROW: - self._set_final_exception(response.to_exception()) - else: # IGNORE - if self._metrics is not None: - self._metrics.on_ignore() - self._set_final_result(None) + self._handle_retry_decision(retry, response, host) elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) - self._retry(reuse_connection=False, consistency_level=None) + cl = getattr(self.message, 'consistency_level', None) + retry = self._retry_policy.on_request_error( + self.query, cl, error=response, retry_num=self._query_retries) + self._handle_retry_decision(retry, response, host) elif isinstance(response, Exception): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) @@ -1525,7 +4848,8 @@ def _set_result(self, response): else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) - exc = ConnectionException(msg, self._current_host) + exc = ConnectionException(msg, host) + self._cancel_timer() self._connection.defunct(exc) self._set_final_exception(exc) except Exception as exc: @@ -1533,70 +4857,161 @@ def _set_result(self, response): log.exception("Unexpected exception while handling result in ResponseFuture:") self._set_final_exception(exc) - def _execute_after_prepare(self, response): + def _handle_continuous_paging_first_response(self, connection, response): + self._continuous_paging_session = connection.new_continuous_paging_session(response.stream_id, + self._protocol_handler.decode_message, + self.row_factory, + self._continuous_paging_state) + self._continuous_paging_session.on_message(response) + self._set_final_result(self._continuous_paging_session.results()) + + def _set_keyspace_completed(self, errors): + if not errors: + self._set_final_result(None) + else: + self._set_final_exception(ConnectionException( + "Failed to set keyspace on all hosts: %s" % (errors,))) + + def _execute_after_prepare(self, host, connection, pool, response): """ Handle the response to our attempt to prepare a statement. If it succeeded, run the original query again against the same host. """ + if pool: + pool.return_connection(connection) + + if self._final_exception: + return + if isinstance(response, ResultMessage): - if response.kind == ResultMessage.KIND_PREPARED: + if response.kind == RESULT_KIND_PREPARED: + if self.prepared_statement: + if self.prepared_statement.query_id != response.query_id: + self._set_final_exception(DriverException( + "ID mismatch while trying to reprepare (expected {expected}, got {got}). " + "This prepared statement won't work anymore. " + "This usually happens when you run a 'USE...' " + "query after the statement was prepared.".format( + expected=hexlify(self.prepared_statement.query_id), got=hexlify(response.query_id) + ) + )) + self.prepared_statement.result_metadata = response.column_metadata + new_metadata_id = response.result_metadata_id + if new_metadata_id is not None: + self.prepared_statement.result_metadata_id = new_metadata_id + # use self._query to re-use the same host and # at the same time properly borrow the connection - self._query(self._current_host) + request_id = self._query(host) + if request_id is None: + # this host errored out, move on to the next + self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " - "on host %s: %s" % (self._host, response))) + "on host %s: %s" % (host, response))) elif isinstance(response, ErrorMessage): - self._set_final_exception(response) + if hasattr(response, 'to_exception'): + self._set_final_exception(response.to_exception()) + else: + self._set_final_exception(response) + elif isinstance(response, ConnectionException): + log.debug("Connection error when preparing statement on host %s: %s", + host, response) + # try again on a different host, preparing again if necessary + self._errors[host] = response + self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " - "statement on host %s: %s" % (self._host, response))) + "statement on host %s: %s" % (host, response))) def _set_final_result(self, response): + self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) - if hasattr(self, 'session'): - try: - del self.session # clear reference cycles - except AttributeError: - pass - self._final_result = response + + with self._callback_lock: + self._final_result = response + # save off current callbacks inside lock for execution outside it + # -- prevents case where _final_result is set, then a callback is + # added and executed on the spot, then executed again as a + # registered callback + to_call = tuple( + partial(fn, response, *args, **kwargs) + for (fn, args, kwargs) in self._callbacks + ) + self._event.set() - if self._callback: - fn, args, kwargs = self._callback - fn(response, *args, **kwargs) + + # apply each callback + for callback_partial in to_call: + callback_partial() def _set_final_exception(self, response): + self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) - try: - del self.session # clear reference cycles - except AttributeError: - pass - self._final_exception = response + + with self._callback_lock: + self._final_exception = response + # save off current errbacks inside lock for execution outside it -- + # prevents case where _final_exception is set, then an errback is + # added and executed on the spot, then executed again as a + # registered errback + to_call = tuple( + partial(fn, response, *args, **kwargs) + for (fn, args, kwargs) in self._errbacks + ) self._event.set() - if self._errback: - fn, args, kwargs = self._errback - fn(response, *args, **kwargs) - def _retry(self, reuse_connection, consistency_level): + # apply each callback + for callback_partial in to_call: + callback_partial() + + def _handle_retry_decision(self, retry_decision, response, host): + + def exception_from_response(response): + if hasattr(response, 'to_exception'): + return response.to_exception() + else: + return response + + retry_type, consistency = retry_decision + if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): + self._query_retries += 1 + reuse = retry_type == RetryPolicy.RETRY + self._retry(reuse, consistency, host) + elif retry_type is RetryPolicy.RETHROW: + self._set_final_exception(exception_from_response(response)) + else: # IGNORE + if self._metrics is not None: + self._metrics.on_ignore() + self._set_final_result(None) + + self._errors[host] = exception_from_response(response) + + def _retry(self, reuse_connection, consistency_level, host): + if self._final_exception: + # the connection probably broke while we were waiting + # to retry the operation + return + if self._metrics is not None: self._metrics.on_retry() if consistency_level is not None: self.message.consistency_level = consistency_level # don't retry on the event loop thread - self.session.submit(self._retry_task, reuse_connection) + self.session.submit(self._retry_task, reuse_connection, host) - def _retry_task(self, reuse_connection): + def _retry_task(self, reuse_connection, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return - if reuse_connection and self._query(self._current_host): + if reuse_connection and self._query(host) is not None: return # otherwise, move onto another host @@ -1606,7 +5021,13 @@ def result(self): """ Return the final result or raise an Exception if errors were encountered. If the final result or error has not been set - yet, this method will block until that time. + yet, this method will block until it is set, or the timeout + set for the request expires. + + Timeout is specified in the Session request execution functions. + If the timeout is exceeded, an :exc:`cassandra.OperationTimedOut` will be raised. + This is a client-side timeout. For more information + about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`. Example usage:: @@ -1621,31 +5042,54 @@ def result(self): ... log.exception("Operation failed:") """ - if self._final_result is not _NO_RESULT_YET: - return self._final_result - elif self._final_exception: - raise self._final_exception + self._event.wait() + if self._final_result is not _NOT_SET: + return ResultSet(self, self._final_result) else: - self._event.wait() - if self._final_result is not _NO_RESULT_YET: - return self._final_result - elif self._final_exception: - raise self._final_exception - else: - assert False # shouldn't get here + raise self._final_exception - def get_query_trace(self): + def get_query_trace_ids(self): """ - Returns the :class:`~.query.QueryTrace` instance representing a trace - of the last attempt for this operation, or :const:`None` if tracing was - not enabled for this query. Note that this may raise an exception if - there are problems retrieving the trace details from Cassandra. + Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data). """ - if not self._query_trace: - return None + return [trace.trace_id for trace in self._query_traces] + + def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): + """ + Fetches and returns the query trace of the last response, or `None` if tracing was + not enabled. + + Note that this may raise an exception if there are problems retrieving the trace + details from Cassandra. If the trace is not available after `max_wait`, + :exc:`cassandra.query.TraceUnavailable` will be raised. + + If the ResponseFuture is not done (async execution) and you try to retrieve the trace, + :exc:`cassandra.query.TraceUnavailable` will be raised. + + `query_cl` is the consistency level used to poll the trace tables. + """ + if self._final_result is _NOT_SET and self._final_exception is None: + raise TraceUnavailable( + "Trace information was not available. The ResponseFuture is not done.") - self._query_trace.populate() - return self._query_trace + if self._query_traces: + return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) + + def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): + """ + Fetches and returns the query traces for all query pages, if tracing was enabled. + + See note in :meth:`~.get_query_trace` regarding possible exceptions. + """ + if self._query_traces: + return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] + return [] + + def _get_query_trace(self, i, max_wait, query_cl): + trace = self._query_traces[i] + if not trace.events: + trace.populate(max_wait=max_wait, query_cl=query_cl) + return trace def add_callback(self, fn, *args, **kwargs): """ @@ -1662,6 +5106,15 @@ def add_callback(self, fn, *args, **kwargs): If the final result has already been seen when this method is called, the callback will be called immediately (before this method returns). + Note: in the case that the result is not available when the callback is added, + the callback is executed by IO event thread. This means that the callback + should not block or attempt further synchronous requests, because no further + IO will be processed until the callback returns. + + **Important**: if the callback you attach results in an exception being + raised, **the exception will be ignored**, so please ensure your + callback handles all error cases that you care about. + Usage example:: >>> session = cluster.connect("mykeyspace") @@ -1675,10 +5128,16 @@ def add_callback(self, fn, *args, **kwargs): >>> future.add_callback(handle_results, time.time(), should_log=True) """ - if self._final_result is not _NO_RESULT_YET: + run_now = False + with self._callback_lock: + # Always add fn to self._callbacks, even when we're about to + # execute it, to prevent races with functions like + # start_fetching_next_page that reset _final_result + self._callbacks.append((fn, args, kwargs)) + if self._final_result is not _NOT_SET: + run_now = True + if run_now: fn(self._final_result, *args, **kwargs) - else: - self._callback = (fn, args, kwargs) return self def add_errback(self, fn, *args, **kwargs): @@ -1687,10 +5146,16 @@ def add_errback(self, fn, *args, **kwargs): An Exception instance will be passed as the first positional argument to `fn`. """ - if self._final_exception: + run_now = False + with self._callback_lock: + # Always add fn to self._errbacks, even when we're about to execute + # it, to prevent races with functions like start_fetching_next_page + # that reset _final_exception + self._errbacks.append((fn, args, kwargs)) + if self._final_exception: + run_now = True + if run_now: fn(self._final_exception, *args, **kwargs) - else: - self._errback = (fn, args, kwargs) return self def add_callbacks(self, callback, errback, @@ -1721,7 +5186,233 @@ def add_callbacks(self, callback, errback, self.add_callback(callback, *callback_args, **(callback_kwargs or {})) self.add_errback(errback, *errback_args, **(errback_kwargs or {})) + def clear_callbacks(self): + with self._callback_lock: + self._callbacks = [] + self._errbacks = [] + def __str__(self): - query = self.query.query_string - return "" \ - % (query, self._req_id, self._final_result, self._final_exception, self._current_host) + result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result + return "" \ + % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) + __repr__ = __str__ + + +class QueryExhausted(Exception): + """ + Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and + there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages` + before calling to avoid this. + + .. versionadded:: 2.0.0 + """ + pass + + +class ResultSet(object): + """ + An iterator over the rows from a query result. Also supplies basic equality + and indexing methods for backward-compatability. These methods materialize + the entire result set (loading all pages), and should only be used if the + total result size is understood. Warnings are emitted when paged results + are materialized in this fashion. + + You can treat this as a normal iterator over rows:: + + >>> from cassandra.query import SimpleStatement + >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) + >>> for user_row in session.execute(statement): + ... process_user(user_row) + + Whenever there are no more rows in the current page, the next page will + be fetched transparently. However, note that it *is* possible for + an :class:`Exception` to be raised while fetching the next page, just + like you might see on a normal call to ``session.execute()``. + """ + + def __init__(self, response_future, initial_response): + self.response_future = response_future + self.column_names = response_future._col_names + self.column_types = response_future._col_types + self._set_current_rows(initial_response) + self._page_iter = None + self._list_mode = False + + @property + def has_more_pages(self): + """ + True if the last response indicated more pages; False otherwise + """ + return self.response_future.has_more_pages + + @property + def current_rows(self): + """ + The list of current page rows. May be empty if the result was empty, + or this is the last page. + """ + return self._current_rows or [] + + def all(self): + """ + Returns all the remaining rows as a list. This is basically + a convenient shortcut to `list(result_set)`. + + This function is not recommended for queries that return a large number of elements. + """ + return list(self) + + def one(self): + """ + Return a single row of the results or None if empty. This is basically + a shortcut to `result_set.current_rows[0]` and should only be used when + you know a query returns a single row. Consider using an iterator if the + ResultSet contains more than one row. + """ + row = None + if self._current_rows: + try: + row = self._current_rows[0] + except TypeError: # generator object is not subscriptable, PYTHON-1026 + row = next(iter(self._current_rows)) + + return row + + def __iter__(self): + if self._list_mode: + return iter(self._current_rows) + self._page_iter = iter(self._current_rows) + return self + + def next(self): + try: + return next(self._page_iter) + except StopIteration: + if not self.response_future.has_more_pages: + if not self._list_mode: + self._current_rows = [] + raise + + if not self.response_future._continuous_paging_session: + self.fetch_next_page() + self._page_iter = iter(self._current_rows) + + # Some servers can return empty pages in this case; Scylla is known to do + # so in some circumstances. Guard against this by recursing to handle + # the next(iter) call. If we have an empty page in that case it will + # get handled by the StopIteration handler when we recurse. + return self.next() + + return next(self._page_iter) + + __next__ = next + + def fetch_next_page(self): + """ + Manually, synchronously fetch the next page. Supplied for manually retrieving pages + and inspecting :meth:`~.current_page`. It is not necessary to call this when iterating + through results; paging happens implicitly in iteration. + """ + if self.response_future.has_more_pages: + self.response_future.start_fetching_next_page() + result = self.response_future.result() + self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form + else: + self._current_rows = [] + + def _set_current_rows(self, result): + if isinstance(result, Mapping): + self._current_rows = [result] if result else [] + return + try: + iter(result) # can't check directly for generator types because cython generators are different + self._current_rows = result + except TypeError: + self._current_rows = [result] if result else [] + + def _fetch_all(self): + self._current_rows = list(self) + self._page_iter = None + + def _enter_list_mode(self, operator): + if self._list_mode: + return + if self._page_iter: + raise RuntimeError("Cannot use %s when results have been iterated." % operator) + if self.response_future.has_more_pages: + log.warning("Using %s on paged results causes entire result set to be materialized.", operator) + self._fetch_all() # done regardless of paging status in case the row factory produces a generator + self._list_mode = True + + def __eq__(self, other): + self._enter_list_mode("equality operator") + return self._current_rows == other + + def __getitem__(self, i): + if i == 0: + warn("ResultSet indexing support will be removed in 4.0. Consider using " + "ResultSet.one() to get a single row.", DeprecationWarning) + self._enter_list_mode("index operator") + return self._current_rows[i] + + def __nonzero__(self): + return bool(self._current_rows) + + __bool__ = __nonzero__ + + def get_query_trace(self, max_wait_sec=None): + """ + Gets the last query trace from the associated future. + See :meth:`.ResponseFuture.get_query_trace` for details. + """ + return self.response_future.get_query_trace(max_wait_sec) + + def get_all_query_traces(self, max_wait_sec_per=None): + """ + Gets all query traces from the associated future. + See :meth:`.ResponseFuture.get_all_query_traces` for details. + """ + return self.response_future.get_all_query_traces(max_wait_sec_per) + + def cancel_continuous_paging(self): + try: + self.response_future._continuous_paging_session.cancel() + except AttributeError: + raise DriverException("Attempted to cancel paging with no active session. This is only for requests with ContinuousPagingOptions.") + + @property + def was_applied(self): + """ + For LWT results, returns whether the transaction was applied. + + Result is indeterminate if called on a result that was not an LWT request or on + a :class:`.query.BatchStatement` containing LWT. In the latter case either all the batch + succeeds or fails. + + Only valid when one of the internal row factories is in use. + """ + if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): + raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) + + is_batch_statement = isinstance(self.response_future.query, BatchStatement) + if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): + raise RuntimeError("No LWT were present in the BatchStatement") + + if not is_batch_statement and len(self.current_rows) != 1: + raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) + + row = self.current_rows[0] + if isinstance(row, tuple): + return row[0] + else: + return row['[applied]'] + + @property + def paging_state(self): + """ + Server paging state of the query. Can be `None` if the query was not paged. + + The driver treats paging state as opaque, but it may contain primary key data, so applications may want to + avoid sending this to untrusted parties. + """ + return self.response_future._paging_state diff --git a/cassandra/cmurmur3.c b/cassandra/cmurmur3.c new file mode 100644 index 0000000000..4affdad46c --- /dev/null +++ b/cassandra/cmurmur3.c @@ -0,0 +1,255 @@ +/* + * The majority of this code was taken from the python-smhasher library, + * which can be found here: https://github.com/phensley/python-smhasher + * + * That library is under the MIT license with the following copyright: + * + * Copyright (c) 2011 Austin Appleby (Murmur3 routine) + * Copyright (c) 2011 Patrick Hensley (Python wrapper, packaging) + * Copyright DataStax (Minor modifications to match Cassandra's MM3 hashes) + * + */ + +#define PY_SSIZE_T_CLEAN 1 +#include +#include + +#ifdef PYPY_VERSION +#define COMPILING_IN_PYPY 1 +#define COMPILING_IN_CPYTHON 0 +#else +#define COMPILING_IN_PYPY 0 +#define COMPILING_IN_CPYTHON 1 +#endif +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned __int64 uint64_t; + +typedef char int8_t; +typedef long int32_t; +typedef __int64 int64_t; + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x,y) _rotl(x,y) +#define ROTL64(x,y) _rotl64(x,y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#include + +#define FORCE_INLINE inline __attribute__((always_inline)) + +inline uint32_t rotl32 ( int32_t x, int8_t r ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + return (x << r) | ((int32_t) (((uint32_t) x) >> (32 - r))); +} + +inline int64_t rotl64 ( int64_t x, int8_t r ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + return (x << r) | ((int64_t) (((uint64_t) x) >> (64 - r))); +} + +#define ROTL32(x,y) rotl32(x,y) +#define ROTL64(x,y) rotl64(x,y) + +#define BIG_CONSTANT(x) (x##LL) + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- +// Block read - if your platform needs to do endian-swapping or can only +// handle aligned reads, do the conversion here + +// TODO 32bit? + +FORCE_INLINE int64_t getblock ( const int64_t * p, int i ) +{ + return p[i]; +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE int64_t fmix ( int64_t k ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + k ^= ((uint64_t) k) >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= ((uint64_t) k) >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= ((uint64_t) k) >> 33; + + return k; +} + +int64_t MurmurHash3_x64_128 (const void * key, const int len, + const uint32_t seed) +{ + const int8_t * data = (const int8_t*)key; + const int nblocks = len / 16; + + int64_t h1 = seed; + int64_t h2 = seed; + + int64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + int64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + int64_t k1 = 0; + int64_t k2 = 0; + + const int64_t * blocks = (const int64_t *)(data); + const int8_t * tail = (const int8_t*)(data + nblocks*16); + + //---------- + // body + + int i; + for(i = 0; i < nblocks; i++) + { + int64_t k1 = getblock(blocks,i*2+0); + int64_t k2 = getblock(blocks,i*2+1); + + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + + //---------- + // tail + switch(len & 15) + { + case 15: k2 ^= ((int64_t) (tail[14])) << 48; + case 14: k2 ^= ((int64_t) (tail[13])) << 40; + case 13: k2 ^= ((int64_t) (tail[12])) << 32; + case 12: k2 ^= ((int64_t) (tail[11])) << 24; + case 11: k2 ^= ((int64_t) (tail[10])) << 16; + case 10: k2 ^= ((int64_t) (tail[ 9])) << 8; + case 9: k2 ^= ((int64_t) (tail[ 8])) << 0; + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + case 8: k1 ^= ((int64_t) (tail[ 7])) << 56; + case 7: k1 ^= ((int64_t) (tail[ 6])) << 48; + case 6: k1 ^= ((int64_t) (tail[ 5])) << 40; + case 5: k1 ^= ((int64_t) (tail[ 4])) << 32; + case 4: k1 ^= ((int64_t) (tail[ 3])) << 24; + case 3: k1 ^= ((int64_t) (tail[ 2])) << 16; + case 2: k1 ^= ((int64_t) (tail[ 1])) << 8; + case 1: k1 ^= ((int64_t) (tail[ 0])) << 0; + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix(h1); + h2 = fmix(h2); + + h1 += h2; + h2 += h1; + + return h1; +} + + +struct module_state { + PyObject *error; +}; + +// pypy3 doesn't have GetState yet. +#if COMPILING_IN_CPYTHON && PY_MAJOR_VERSION >= 3 +#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) +#else +#define GETSTATE(m) (&_state) +static struct module_state _state; +#endif + +static PyObject * +murmur3(PyObject *self, PyObject *args) +{ + const char *key; + Py_ssize_t len; + uint32_t seed = 0; + int64_t result = 0; + + + if (!PyArg_ParseTuple(args, "s#|I", &key, &len, &seed)) { + return NULL; + } + + // TODO handle x86 version? + result = MurmurHash3_x64_128((void *)key, len, seed); + return (PyObject *) PyLong_FromLongLong(result); +} + +static PyMethodDef cmurmur3_methods[] = { + {"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"}, + {NULL, NULL, 0, NULL} +}; + +static int cmurmur3_traverse(PyObject *m, visitproc visit, void *arg) { + Py_VISIT(GETSTATE(m)->error); + return 0; +} + +static int cmurmur3_clear(PyObject *m) { + Py_CLEAR(GETSTATE(m)->error); + return 0; +} + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "cmurmur3", + NULL, + sizeof(struct module_state), + cmurmur3_methods, + NULL, + cmurmur3_traverse, + cmurmur3_clear, + NULL +}; + +#define INITERROR return NULL + +PyObject * +PyInit_cmurmur3(void) + +{ + PyObject *module = PyModule_Create(&moduledef); + struct module_state *st = NULL; + + if (module == NULL) + INITERROR; + st = GETSTATE(module); + + st->error = PyErr_NewException("cmurmur3.Error", NULL, NULL); + if (st->error == NULL) { + Py_DECREF(module); + INITERROR; + } + + return module; +} diff --git a/cassandra/column_encryption/_policies.py b/cassandra/column_encryption/_policies.py new file mode 100644 index 0000000000..e1519f6b79 --- /dev/null +++ b/cassandra/column_encryption/_policies.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from functools import lru_cache + +import logging +import os + +log = logging.getLogger(__name__) + +from cassandra.cqltypes import _cqltypes +from cassandra.policies import ColumnEncryptionPolicy + +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +AES256_BLOCK_SIZE = 128 +AES256_BLOCK_SIZE_BYTES = int(AES256_BLOCK_SIZE / 8) +AES256_KEY_SIZE = 256 +AES256_KEY_SIZE_BYTES = int(AES256_KEY_SIZE / 8) + +ColData = namedtuple('ColData', ['key','type']) + +class AES256ColumnEncryptionPolicy(ColumnEncryptionPolicy): + + # Fix block cipher mode for now. IV size is a function of block cipher used + # so fixing this avoids (possibly unnecessary) validation logic here. + mode = modes.CBC + + # "iv" param here expects a bytearray that's the same size as the block + # size for AES-256 (128 bits or 16 bytes). If none is provided a new one + # will be randomly generated, but in this case the IV should be recorded and + # preserved or else you will not be able to decrypt any data encrypted by this + # policy. + def __init__(self, iv=None): + + # CBC uses an IV that's the same size as the block size + # + # Avoid defining IV with a default arg in order to stay away from + # any issues around the caching of default args + self.iv = iv + if self.iv: + if not len(self.iv) == AES256_BLOCK_SIZE_BYTES: + raise ValueError("This policy uses AES-256 with CBC mode and therefore expects a 128-bit initialization vector") + else: + self.iv = os.urandom(AES256_BLOCK_SIZE_BYTES) + + # ColData for a given ColDesc is always preserved. We only create a Cipher + # when there's an actual need to for a given ColDesc + self.coldata = {} + self.ciphers = {} + + def encrypt(self, coldesc, obj_bytes): + + # AES256 has a 128-bit block size so if the input bytes don't align perfectly on + # those blocks we have to pad them. There's plenty of room for optimization here: + # + # * Instances of the PKCS7 padder should be managed in a bounded pool + # * It would be nice if we could get a flag from encrypted data to indicate + # whether it was padded or not + # * Might be able to make this happen with a leading block of flags in encrypted data + padder = padding.PKCS7(AES256_BLOCK_SIZE).padder() + padded_bytes = padder.update(obj_bytes) + padder.finalize() + + cipher = self._get_cipher(coldesc) + encryptor = cipher.encryptor() + return self.iv + encryptor.update(padded_bytes) + encryptor.finalize() + + def decrypt(self, coldesc, bytes): + + iv = bytes[:AES256_BLOCK_SIZE_BYTES] + encrypted_bytes = bytes[AES256_BLOCK_SIZE_BYTES:] + cipher = self._get_cipher(coldesc, iv=iv) + decryptor = cipher.decryptor() + padded_bytes = decryptor.update(encrypted_bytes) + decryptor.finalize() + + unpadder = padding.PKCS7(AES256_BLOCK_SIZE).unpadder() + return unpadder.update(padded_bytes) + unpadder.finalize() + + def add_column(self, coldesc, key, type): + + if not coldesc: + raise ValueError("ColDesc supplied to add_column cannot be None") + if not key: + raise ValueError("Key supplied to add_column cannot be None") + if not type: + raise ValueError("Type supplied to add_column cannot be None") + if type not in _cqltypes.keys(): + raise ValueError("Type %s is not a supported type".format(type)) + if not len(key) == AES256_KEY_SIZE_BYTES: + raise ValueError("AES256 column encryption policy expects a 256-bit encryption key") + self.coldata[coldesc] = ColData(key, _cqltypes[type]) + + def contains_column(self, coldesc): + return coldesc in self.coldata + + def encode_and_encrypt(self, coldesc, obj): + if not coldesc: + raise ValueError("ColDesc supplied to encode_and_encrypt cannot be None") + if not obj: + raise ValueError("Object supplied to encode_and_encrypt cannot be None") + coldata = self.coldata.get(coldesc) + if not coldata: + raise ValueError("Could not find ColData for ColDesc %s".format(coldesc)) + return self.encrypt(coldesc, coldata.type.serialize(obj, None)) + + def cache_info(self): + return AES256ColumnEncryptionPolicy._build_cipher.cache_info() + + def column_type(self, coldesc): + return self.coldata[coldesc].type + + def _get_cipher(self, coldesc, iv=None): + """ + Access relevant state from this instance necessary to create a Cipher and then get one, + hopefully returning a cached instance if we've already done so (and it hasn't been evicted) + """ + try: + coldata = self.coldata[coldesc] + return AES256ColumnEncryptionPolicy._build_cipher(coldata.key, iv or self.iv) + except KeyError: + raise ValueError("Could not find column {}".format(coldesc)) + + # Explicitly use a class method here to avoid caching self + @lru_cache(maxsize=128) + def _build_cipher(key, iv): + return Cipher(algorithms.AES256(key), AES256ColumnEncryptionPolicy.mode(iv)) diff --git a/cassandra/column_encryption/policies.py b/cassandra/column_encryption/policies.py new file mode 100644 index 0000000000..a1bd25d3e6 --- /dev/null +++ b/cassandra/column_encryption/policies.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import cryptography + from cassandra.column_encryption._policies import * +except ImportError: + # Cryptography is not installed + pass diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py new file mode 100644 index 0000000000..012f52f954 --- /dev/null +++ b/cassandra/concurrent.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from collections import namedtuple +from concurrent.futures import Future +from heapq import heappush, heappop +from itertools import cycle +from threading import Condition + +from cassandra.cluster import ResultSet, EXEC_PROFILE_DEFAULT + +log = logging.getLogger(__name__) + + +ExecutionResult = namedtuple('ExecutionResult', ['success', 'result_or_exc']) + +def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + See :meth:`.Session.execute_concurrent`. + """ + if concurrency <= 0: + raise ValueError("concurrency must be greater than 0") + + if not statements_and_parameters: + return [] + + executor = ConcurrentExecutorGenResults(session, statements_and_parameters, execution_profile) \ + if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters, execution_profile) + return executor.execute(concurrency, raise_on_first_error) + + +class _ConcurrentExecutor(object): + + max_error_recursion = 100 + + def __init__(self, session, statements_and_params, execution_profile): + self.session = session + self._enum_statements = enumerate(iter(statements_and_params)) + self._execution_profile = execution_profile + self._condition = Condition() + self._fail_fast = False + self._results_queue = [] + self._current = 0 + self._exec_count = 0 + self._exec_depth = 0 + + def execute(self, concurrency, fail_fast): + self._fail_fast = fail_fast + self._results_queue = [] + self._current = 0 + self._exec_count = 0 + with self._condition: + for n in range(concurrency): + if not self._execute_next(): + break + return self._results() + + def _execute_next(self): + # lock must be held + try: + (idx, (statement, params)) = next(self._enum_statements) + self._exec_count += 1 + self._execute(idx, statement, params) + return True + except StopIteration: + pass + + def _execute(self, idx, statement, params): + self._exec_depth += 1 + try: + future = self.session.execute_async(statement, params, execution_profile=self._execution_profile) + args = (future, idx) + future.add_callbacks( + callback=self._on_success, callback_args=args, + errback=self._on_error, errback_args=args) + except Exception as exc: + # If we're not failing fast and all executions are raising, there is a chance of recursing + # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry + # and let the event loop thread return. + if self._exec_depth < self.max_error_recursion: + self._put_result(exc, idx, False) + else: + self.session.submit(self._put_result, exc, idx, False) + self._exec_depth -= 1 + + def _on_success(self, result, future, idx): + future.clear_callbacks() + self._put_result(ResultSet(future, result), idx, True) + + def _on_error(self, result, future, idx): + self._put_result(result, idx, False) + + +class ConcurrentExecutorGenResults(_ConcurrentExecutor): + + def _put_result(self, result, idx, success): + with self._condition: + heappush(self._results_queue, (idx, ExecutionResult(success, result))) + self._execute_next() + self._condition.notify() + + def _results(self): + with self._condition: + while self._current < self._exec_count: + while not self._results_queue or self._results_queue[0][0] != self._current: + self._condition.wait() + while self._results_queue and self._results_queue[0][0] == self._current: + _, res = heappop(self._results_queue) + try: + self._condition.release() + if self._fail_fast and not res[0]: + raise res[1] + yield res + finally: + self._condition.acquire() + self._current += 1 + + +class ConcurrentExecutorListResults(_ConcurrentExecutor): + + _exception = None + + def execute(self, concurrency, fail_fast): + self._exception = None + return super(ConcurrentExecutorListResults, self).execute(concurrency, fail_fast) + + def _put_result(self, result, idx, success): + self._results_queue.append((idx, ExecutionResult(success, result))) + with self._condition: + self._current += 1 + if not success and self._fail_fast: + if not self._exception: + self._exception = result + self._condition.notify() + elif not self._execute_next() and self._current == self._exec_count: + self._condition.notify() + + def _results(self): + with self._condition: + while self._current < self._exec_count: + self._condition.wait() + if self._exception and self._fail_fast: + raise self._exception + if self._exception and self._fail_fast: # raise the exception even if there was no wait + raise self._exception + return [r[1] for r in sorted(self._results_queue)] + + + +def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs): + """ + See :meth:`.Session.execute_concurrent_with_args`. + """ + return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs) + + +class ConcurrentExecutorFutureResults(ConcurrentExecutorListResults): + def __init__(self, session, statements_and_params, execution_profile, future): + super().__init__(session, statements_and_params, execution_profile) + self.future = future + + def _put_result(self, result, idx, success): + super()._put_result(result, idx, success) + with self._condition: + if self._current == self._exec_count: + if self._exception and self._fail_fast: + self.future.set_exception(self._exception) + else: + sorted_results = [r[1] for r in sorted(self._results_queue)] + self.future.set_result(sorted_results) + + +def execute_concurrent_async( + session, + statements_and_parameters, + concurrency=100, + raise_on_first_error=False, + execution_profile=EXEC_PROFILE_DEFAULT +): + """ + See :meth:`.Session.execute_concurrent_async`. + """ + # Create a Future object and initialize the custom ConcurrentExecutor with the Future + future = Future() + executor = ConcurrentExecutorFutureResults( + session=session, + statements_and_params=statements_and_parameters, + execution_profile=execution_profile, + future=future + ) + + # Execute concurrently + try: + executor.execute(concurrency=concurrency, fail_fast=raise_on_first_error) + except Exception as e: + future.set_exception(e) + + return future diff --git a/cassandra/connection.py b/cassandra/connection.py index f795b178af..3ceaa08afc 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,38 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict, deque import errno -from functools import wraps, partial +from functools import wraps, partial, total_ordering +from heapq import heappush, heappop +import io import logging -from threading import Event, Lock, RLock -from Queue import Queue +import socket +import struct +import sys +from threading import Thread, Event, RLock, Condition +import time +import ssl +import weakref -from cassandra import ConsistencyLevel, AuthenticationFailed -from cassandra.marshal import int8_unpack, int32_pack -from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage, - StartupMessage, ErrorMessage, CredentialsMessage, - QueryMessage, ResultMessage, decode_response, - InvalidRequestException, SupportedMessage) + +if 'gevent.monkey' in sys.modules: + from gevent.queue import Queue, Empty +else: + from queue import Queue, Empty # noqa + +from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion +from cassandra.marshal import int32_pack +from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, + StartupMessage, ErrorMessage, CredentialsMessage, + QueryMessage, ResultMessage, ProtocolHandler, + InvalidRequestException, SupportedMessage, + AuthResponseMessage, AuthChallengeMessage, + AuthSuccessMessage, ProtocolException, + RegisterMessage, ReviseRequestMessage) +from cassandra.segment import SegmentCodec, CrcException +from cassandra.util import OrderedDict log = logging.getLogger(__name__) -locally_supported_compressions = {} +segment_codec_no_compression = SegmentCodec() +segment_codec_lz4 = None -try: - import snappy -except ImportError: - pass -else: - # work around apparently buggy snappy decompress - def decompress(byts): - if byts == '\x00': - return '' - return snappy.decompress(byts) - locally_supported_compressions['snappy'] = (snappy.compress, decompress) +# We use an ordered dictionary and specifically add lz4 before +# snappy so that lz4 will be preferred. Changing the order of this +# will change the compression preferences for the driver. +locally_supported_compressions = OrderedDict() try: import lz4 except ImportError: pass else: + # The compress and decompress functions we need were moved from the lz4 to + # the lz4.block namespace, so we try both here. + try: + from lz4 import block as lz4_block + except ImportError: + lz4_block = lz4 + + try: + lz4_block.compress + lz4_block.decompress + except AttributeError: + raise ImportError( + 'lz4 not imported correctly. Imported object should have ' + '.compress and and .decompress attributes but does not. ' + 'Please file a bug report on JIRA. (Imported object was ' + '{lz4_block})'.format(lz4_block=repr(lz4_block)) + ) # Cassandra writes the uncompressed message length in big endian order, # but the lz4 lib requires little endian order, so we wrap these @@ -40,24 +86,313 @@ def decompress(byts): def lz4_compress(byts): # write length in big-endian instead of little-endian - return int32_pack(len(byts)) + lz4.compress(byts)[4:] + return int32_pack(len(byts)) + lz4_block.compress(byts)[4:] def lz4_decompress(byts): # flip from big-endian to little-endian - return lz4.decompress(byts[3::-1] + byts[4:]) + return lz4_block.decompress(byts[3::-1] + byts[4:]) locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress) + segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress) +try: + import snappy +except ImportError: + pass +else: + # work around apparently buggy snappy decompress + def decompress(byts): + if byts == '\x00': + return '' + return snappy.decompress(byts) + locally_supported_compressions['snappy'] = (snappy.compress, decompress) -MAX_STREAM_PER_CONNECTION = 128 +DRIVER_NAME, DRIVER_VERSION = 'Apache Cassandra Python Driver', sys.modules['cassandra'].__version__ -PROTOCOL_VERSION = 0x01 PROTOCOL_VERSION_MASK = 0x7f HEADER_DIRECTION_FROM_CLIENT = 0x00 HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_MASK = 0x80 +frame_header_v1_v2 = struct.Struct('>BbBi') +frame_header_v3 = struct.Struct('>BhBi') + + +class EndPoint(object): + """ + Represents the information to connect to a cassandra node. + """ + + @property + def address(self): + """ + The IP address of the node. This is the RPC address the driver uses when connecting to the node + """ + raise NotImplementedError() + + @property + def port(self): + """ + The port of the node. + """ + raise NotImplementedError() + + @property + def ssl_options(self): + """ + SSL options specific to this endpoint. + """ + return None + + @property + def socket_family(self): + """ + The socket family of the endpoint. + """ + return socket.AF_UNSPEC + + def resolve(self): + """ + Resolve the endpoint to an address/port. This is called + only on socket connection. + """ + raise NotImplementedError() + + +class EndPointFactory(object): + + cluster = None + + def configure(self, cluster): + """ + This is called by the cluster during its initialization. + """ + self.cluster = cluster + return self + + def create(self, row): + """ + Create an EndPoint from a system.peers row. + """ + raise NotImplementedError() + + +@total_ordering +class DefaultEndPoint(EndPoint): + """ + Default EndPoint implementation, basically just an address and port. + """ + + def __init__(self, address, port=9042): + self._address = address + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, DefaultEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return hash((self.address, self.port)) + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + def __str__(self): + return str("%s:%d" % (self.address, self.port)) + + def __repr__(self): + return "<%s: %s:%d>" % (self.__class__.__name__, self.address, self.port) + + +class DefaultEndPointFactory(EndPointFactory): + + port = None + """ + If no port is discovered in the row, this is the default port + used for endpoint creation. + """ + + def __init__(self, port=None): + self.port = port + + def create(self, row): + # TODO next major... move this class so we don't need this kind of hack + from cassandra.metadata import _NodeInfo + addr = _NodeInfo.get_broadcast_rpc_address(row) + port = _NodeInfo.get_broadcast_rpc_port(row) + if port is None: + port = self.port if self.port else 9042 + + # create the endpoint with the translated address + # TODO next major, create a TranslatedEndPoint type + return DefaultEndPoint( + self.cluster.address_translator.translate(addr), + port) + + +@total_ordering +class SniEndPoint(EndPoint): + """SNI Proxy EndPoint implementation.""" + + def __init__(self, proxy_address, server_name, port=9042, init_index=0): + self._proxy_address = proxy_address + self._index = init_index + self._resolved_address = None # resolved address + self._port = port + self._server_name = server_name + self._ssl_options = {'server_hostname': server_name} + + @property + def address(self): + return self._proxy_address + + @property + def port(self): + return self._port + + @property + def ssl_options(self): + return self._ssl_options + + def resolve(self): + try: + resolved_addresses = self._resolve_proxy_addresses() + except socket.gaierror: + log.debug('Could not resolve sni proxy hostname "%s" ' + 'with port %d' % (self._proxy_address, self._port)) + raise + + # round-robin pick + self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)] + self._index += 1 + + return self._resolved_address, self._port + + def _resolve_proxy_addresses(self): + return socket.getaddrinfo(self._proxy_address, self._port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + + def __eq__(self, other): + return (isinstance(other, SniEndPoint) and + self.address == other.address and self.port == other.port and + self._server_name == other._server_name) + + def __hash__(self): + return hash((self.address, self.port, self._server_name)) + + def __lt__(self, other): + return ((self.address, self.port, self._server_name) < + (other.address, other.port, self._server_name)) + + def __str__(self): + return str("%s:%d:%s" % (self.address, self.port, self._server_name)) + + def __repr__(self): + return "<%s: %s:%d:%s>" % (self.__class__.__name__, + self.address, self.port, self._server_name) + + +class SniEndPointFactory(EndPointFactory): + + def __init__(self, proxy_address, port): + self._proxy_address = proxy_address + self._port = port + # Initial lookup index to prevent all SNI endpoints to be resolved + # into the same starting IP address (which might not be available currently). + # If SNI resolves to 3 IPs, first endpoint will connect to first + # IP address, and subsequent resolutions to next IPs in round-robin + # fusion. + self._init_index = -1 + + def create(self, row): + host_id = row.get("host_id") + if host_id is None: + raise ValueError("No host_id to create the SniEndPoint") + + self._init_index += 1 + return SniEndPoint(self._proxy_address, str(host_id), self._port, self._init_index) + + def create_from_sni(self, sni): + self._init_index += 1 + return SniEndPoint(self._proxy_address, sni, self._port, self._init_index) + + +@total_ordering +class UnixSocketEndPoint(EndPoint): + """ + Unix Socket EndPoint implementation. + """ + + def __init__(self, unix_socket_path): + self._unix_socket_path = unix_socket_path + + @property + def address(self): + return self._unix_socket_path + + @property + def port(self): + return None + + @property + def socket_family(self): + return socket.AF_UNIX + + def resolve(self): + return self.address, None + + def __eq__(self, other): + return (isinstance(other, UnixSocketEndPoint) and + self._unix_socket_path == other._unix_socket_path) + + def __hash__(self): + return hash(self._unix_socket_path) + + def __lt__(self, other): + return self._unix_socket_path < other._unix_socket_path + + def __str__(self): + return str("%s" % (self._unix_socket_path,)) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) + + +class _Frame(object): + def __init__(self, version, flags, stream, opcode, body_offset, end_pos): + self.version = version + self.flags = flags + self.stream = stream + self.opcode = opcode + self.body_offset = body_offset + self.end_pos = end_pos + + def __eq__(self, other): # facilitates testing + if isinstance(other, _Frame): + return (self.version == other.version and + self.flags == other.flags and + self.stream == other.stream and + self.opcode == other.opcode and + self.body_offset == other.body_offset and + self.end_pos == other.end_pos) + return NotImplemented + + def __str__(self): + return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset) + + NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) @@ -67,18 +402,32 @@ class ConnectionException(Exception): or the connection was already closed or defunct. """ - def __init__(self, message, host=None): + def __init__(self, message, endpoint=None): Exception.__init__(self, message) - self.host = host + self.endpoint = endpoint + + @property + def host(self): + return self.endpoint.address class ConnectionShutdown(ConnectionException): """ - Raised when a connection has been defuncted or closed. + Raised when a connection has been marked as defunct or has been closed. """ pass +class ProtocolVersionUnsupported(ConnectionException): + """ + Server rejected startup message due to unsupported protocol version + """ + def __init__(self, endpoint, startup_version): + msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version) + super(ProtocolVersionUnsupported, self).__init__(msg, endpoint) + self.startup_version = startup_version + + class ConnectionBusy(Exception): """ An attempt was made to send a message through a :class:`.Connection` that @@ -94,6 +443,165 @@ class ProtocolError(Exception): pass +class CrcMismatchException(ConnectionException): + pass + + +class ContinuousPagingState(object): + """ + A class for specifying continuous paging state, only supported starting with DSE_V2. + """ + + num_pages_requested = None + """ + How many pages we have already requested + """ + + num_pages_received = None + """ + How many pages we have already received + """ + + max_queue_size = None + """ + The max queue size chosen by the user via the options + """ + + def __init__(self, max_queue_size): + self.num_pages_requested = max_queue_size # the initial query requests max_queue_size + self.num_pages_received = 0 + self.max_queue_size = max_queue_size + + +class ContinuousPagingSession(object): + def __init__(self, stream_id, decoder, row_factory, connection, state): + self.stream_id = stream_id + self.decoder = decoder + self.row_factory = row_factory + self.connection = connection + self._condition = Condition() + self._stop = False + self._page_queue = deque() + self._state = state + self.released = False + + def on_message(self, result): + if isinstance(result, ResultMessage): + self.on_page(result) + elif isinstance(result, ErrorMessage): + self.on_error(result) + + def on_page(self, result): + with self._condition: + if self._state: + self._state.num_pages_received += 1 + self._page_queue.appendleft((result.column_names, result.parsed_rows, None)) + self._stop |= result.continuous_paging_last + self._condition.notify() + + if result.continuous_paging_last: + self.released = True + + def on_error(self, error): + if isinstance(error, ErrorMessage): + error = error.to_exception() + + log.debug("Got error %s for session %s", error, self.stream_id) + + with self._condition: + self._page_queue.appendleft((None, None, error)) + self._stop = True + self._condition.notify() + + self.released = True + + def results(self): + try: + self._condition.acquire() + while True: + while not self._page_queue and not self._stop: + self._condition.wait(timeout=5) + while self._page_queue: + names, rows, err = self._page_queue.pop() + if err: + raise err + self.maybe_request_more() + self._condition.release() + for row in self.row_factory(names, rows): + yield row + self._condition.acquire() + if self._stop: + break + finally: + try: + self._condition.release() + except RuntimeError: + # This exception happens if the CP results are not entirely consumed + # and the session is terminated by the runtime. See PYTHON-1054 + pass + + def maybe_request_more(self): + if not self._state: + return + + max_queue_size = self._state.max_queue_size + num_in_flight = self._state.num_pages_requested - self._state.num_pages_received + space_in_queue = max_queue_size - len(self._page_queue) - num_in_flight + log.debug("Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", + self.stream_id, self.connection.host, space_in_queue, self._state.num_pages_requested, + self._state.num_pages_received, num_in_flight) + + if space_in_queue >= max_queue_size / 2: + self.update_next_pages(space_in_queue) + + def update_next_pages(self, num_next_pages): + try: + self._state.num_pages_requested += num_next_pages + log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host) + with self.connection.lock: + self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, + self.stream_id, + next_pages=num_next_pages), + self.connection.get_request_id(), + self._on_backpressure_response) + except ConnectionShutdown as ex: + log.debug("Failed to update backpressure for session %s from %s, connection is shutdown", + self.stream_id, self.connection.host) + self.on_error(ex) + + def _on_backpressure_response(self, response): + if isinstance(response, ResultMessage): + log.debug("Paging session %s backpressure updated.", self.stream_id) + else: + log.error("Failed updating backpressure for session %s from %s: %s", self.stream_id, self.connection.host, + response.to_exception() if hasattr(response, 'to_exception') else response) + self.on_error(response) + + def cancel(self): + try: + log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host) + with self.connection.lock: + self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL, + self.stream_id), + self.connection.get_request_id(), + self._on_cancel_response) + except ConnectionShutdown: + log.debug("Failed to cancel session %s from %s, connection is shutdown", + self.stream_id, self.connection.host) + + with self._condition: + self._stop = True + self._condition.notify() + + def _on_cancel_response(self, response): + if isinstance(response, ResultMessage): + log.debug("Paging session %s canceled.", self.stream_id) + else: + log.error("Failed canceling streaming session %s from %s: %s", self.stream_id, self.connection.host, + response.to_exception() if hasattr(response, 'to_exception') else response) + self.released = True + + def defunct_on_error(f): @wraps(f) @@ -102,114 +610,720 @@ def wrapper(self, *args, **kwargs): return f(self, *args, **kwargs) except Exception as exc: self.defunct(exc) - return wrapper +DEFAULT_CQL_VERSION = '3.0.0' + + +class _ConnectionIOBuffer(object): + """ + Abstraction class to ease the use of the different connection io buffers. With + protocol V5 and checksumming, the data is read, validated and copied to another + cql frame buffer. + """ + _io_buffer = None + _cql_frame_buffer = None + _connection = None + _segment_consumed = False + + def __init__(self, connection): + self._io_buffer = io.BytesIO() + self._connection = weakref.proxy(connection) + + @property + def io_buffer(self): + return self._io_buffer + + @property + def cql_frame_buffer(self): + return self._cql_frame_buffer if self.is_checksumming_enabled else \ + self._io_buffer + + def set_checksumming_buffer(self): + self.reset_io_buffer() + self._cql_frame_buffer = io.BytesIO() + + @property + def is_checksumming_enabled(self): + return self._connection._is_checksumming_enabled + + @property + def has_consumed_segment(self): + return self._segment_consumed; + + def readable_io_bytes(self): + return self.io_buffer.tell() + + def readable_cql_frame_bytes(self): + return self.cql_frame_buffer.tell() + + def reset_io_buffer(self): + self._io_buffer = io.BytesIO(self._io_buffer.read()) + self._io_buffer.seek(0, 2) # 2 == SEEK_END + + def reset_cql_frame_buffer(self): + if self.is_checksumming_enabled: + self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read()) + self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END + else: + self.reset_io_buffer() + + class Connection(object): + CALLBACK_ERR_THREAD_THRESHOLD = 100 + in_buffer_size = 4096 out_buffer_size = 4096 cql_version = None + no_compact = False + protocol_version = ProtocolVersion.MAX_SUPPORTED keyspace = None compression = True + _compression_type = None compressor = None decompressor = None + endpoint = None + ssl_options = None + ssl_context = None last_error = None + + # The current number of operations that are in flight. More precisely, + # the number of request IDs that are currently in use. + # This includes orphaned requests. in_flight = 0 + + # Max concurrent requests allowed per connection. This is set optimistically high, allowing + # all request ids to be used in protocol version 3+. Normally concurrency would be controlled + # at a higher level by the application or concurrent.execute_concurrent. This attribute + # is for lower-level integrations that want some upper bound without reimplementing. + max_in_flight = 2 ** 15 + + # A set of available request IDs. When using the v3 protocol or higher, + # this will not initially include all request IDs in order to save memory, + # but the set will grow if it is exhausted. + request_ids = None + + # Tracks the highest used request ID in order to help with growing the + # request_ids set + highest_request_id = 0 + + # Tracks the request IDs which are no longer waited on (timed out), but + # cannot be reused yet because the node might still send a response + # on this stream + orphaned_request_ids = None + + # Set to true if the orphaned stream ID count cross configured threshold + # and the connection will be replaced + orphaned_threshold_reached = False + + # If the number of orphaned streams reaches this threshold, this connection + # will become marked and will be replaced with a new connection by the + # owning pool (currently, only HostConnection supports this) + orphaned_threshold = 3 * max_in_flight // 4 + is_defunct = False is_closed = False lock = None + user_type_map = None - def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None): - self.host = host - self.port = port - self.credentials = credentials + msg_received = False + + is_unsupported_proto_version = False + + is_control_connection = False + signaled_error = False # used for flagging at the pool level + + allow_beta_protocol_version = False + + _current_frame = None + + _socket = None + + _socket_impl = socket + + _check_hostname = False + _product_type = None + + _is_checksumming_enabled = False + + _on_orphaned_stream_released = None + + @property + def _iobuf(self): + # backward compatibility, to avoid any change in the reactors + return self._io_buffer.io_buffer + + def __init__(self, host='127.0.0.1', port=9042, authenticator=None, + ssl_options=None, sockopts=None, compression=True, + cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, + user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, + ssl_context=None, on_orphaned_stream_released=None): + + # TODO next major rename host to endpoint and remove port kwarg. + self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + + self.authenticator = authenticator + self.ssl_options = ssl_options.copy() if ssl_options else {} + self.ssl_context = ssl_context self.sockopts = sockopts self.compression = compression self.cql_version = cql_version - - self._id_queue = Queue(MAX_STREAM_PER_CONNECTION) - for i in range(MAX_STREAM_PER_CONNECTION): - self._id_queue.put_nowait(i) + self.protocol_version = protocol_version + self.is_control_connection = is_control_connection + self.user_type_map = user_type_map + self.connect_timeout = connect_timeout + self.allow_beta_protocol_version = allow_beta_protocol_version + self.no_compact = no_compact + self._push_watchers = defaultdict(set) + self._requests = {} + self._io_buffer = _ConnectionIOBuffer(self) + self._continuous_paging_sessions = {} + self._socket_writable = True + self.orphaned_request_ids = set() + self._on_orphaned_stream_released = on_orphaned_stream_released + + if ssl_options: + self.ssl_options.update(self.endpoint.ssl_options or {}) + elif self.endpoint.ssl_options: + self.ssl_options = self.endpoint.ssl_options + + # PYTHON-1331 + # + # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()... + # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if + # we need to do so. + # + # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this + # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call. + if not self.ssl_context and self.ssl_options: + self.ssl_context = self._build_ssl_context_from_options() + + if protocol_version >= 3: + self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1) + # Don't fill the deque with 2**15 items right away. Start with some and add + # more if needed. + initial_size = min(300, self.max_in_flight) + self.request_ids = deque(range(initial_size)) + self.highest_request_id = initial_size - 1 + else: + self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1) + self.request_ids = deque(range(self.max_request_id + 1)) + self.highest_request_id = self.max_request_id self.lock = RLock() - self.id_lock = Lock() + self.connected_event = Event() + + @property + def host(self): + return self.endpoint.address + + @property + def port(self): + return self.endpoint.port + + @classmethod + def initialize_reactor(cls): + """ + Called once by Cluster.connect(). This should be used by implementations + to set up any resources that will be shared across connections. + """ + pass + + @classmethod + def handle_fork(cls): + """ + Called after a forking. This should clean up any remaining reactor state + from the parent process. + """ + pass + + @classmethod + def create_timer(cls, timeout, callback): + raise NotImplementedError() + + @classmethod + def factory(cls, endpoint, timeout, *args, **kwargs): + """ + A factory function which returns connections which have + succeeded in connecting and are ready for service (or + raises an exception otherwise). + """ + start = time.time() + kwargs['connect_timeout'] = timeout + conn = cls(endpoint, *args, **kwargs) + elapsed = time.time() - start + conn.connected_event.wait(timeout - elapsed) + if conn.last_error: + if conn.is_unsupported_proto_version: + raise ProtocolVersionUnsupported(endpoint, conn.protocol_version) + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout) + else: + return conn + + def _build_ssl_context_from_options(self): + + # Extract a subset of names from self.ssl_options which apply to SSLContext creation + ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers'] + opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options} + + # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER, so we'll get ahead of things by always + # being explicit + ssl_version = opts.get('ssl_version', None) or ssl.PROTOCOL_TLS_CLIENT + cert_reqs = opts.get('cert_reqs', None) or ssl.CERT_REQUIRED + rv = ssl.SSLContext(protocol=int(ssl_version)) + rv.check_hostname = bool(opts.get('check_hostname', False)) + rv.options = int(cert_reqs) + + certfile = opts.get('certfile', None) + keyfile = opts.get('keyfile', None) + if certfile: + rv.load_cert_chain(certfile, keyfile) + ca_certs = opts.get('ca_certs', None) + if ca_certs: + rv.load_verify_locations(ca_certs) + ciphers = opts.get('ciphers', None) + if ciphers: + rv.set_ciphers(ciphers) + + return rv + + def _wrap_socket_from_context(self): + + # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts + # of it that don't involve building an SSLContext under the covers) + wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname'] + opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options} + + # PYTHON-1186: set the server_hostname only if the SSLContext has + # check_hostname enabled, and it is not already provided by the EndPoint ssl options + #opts['server_hostname'] = self.endpoint.address + if (self.ssl_context.check_hostname and 'server_hostname' not in opts): + server_hostname = self.endpoint.address + opts['server_hostname'] = server_hostname + + return self.ssl_context.wrap_socket(self._socket, **opts) + + def _initiate_connection(self, sockaddr): + self._socket.connect(sockaddr) + + # PYTHON-1331 + # + # Allow implementations specific to an event loop to add additional behaviours + def _validate_hostname(self): + pass + + def _get_socket_addresses(self): + address, port = self.endpoint.resolve() + + if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX: + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)] + + addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM) + if not addresses: + raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,)) + + return addresses + + def _connect_socket(self): + sockerr = None + addresses = self._get_socket_addresses() + for (af, socktype, proto, _, sockaddr) in addresses: + try: + self._socket = self._socket_impl.socket(af, socktype, proto) + if self.ssl_context: + self._socket = self._wrap_socket_from_context() + self._socket.settimeout(self.connect_timeout) + self._initiate_connection(sockaddr) + self._socket.settimeout(None) + + # PYTHON-1331 + # + # Most checking is done via the check_hostname param on the SSLContext. + # Subclasses can add additional behaviours via _validate_hostname() so + # run that here. + if self._check_hostname: + self._validate_hostname() + sockerr = None + break + except socket.error as err: + if self._socket: + self._socket.close() + self._socket = None + sockerr = err + + if sockerr: + raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % + ([a[4] for a in addresses], sockerr.strerror or sockerr)) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + def _enable_compression(self): + if self._compressor: + self.compressor = self._compressor + + def _enable_checksumming(self): + self._io_buffer.set_checksumming_buffer() + self._is_checksumming_enabled = True + self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression + log.debug("Enabling protocol checksumming on connection (%s).", id(self)) def close(self): raise NotImplementedError() def defunct(self, exc): - raise NotImplementedError() + with self.lock: + if self.is_defunct or self.is_closed: + return + self.is_defunct = True + + exc_info = sys.exc_info() + # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message + if any(exc_info): + log.debug("Defuncting connection (%s) to %s:", + id(self), self.endpoint, exc_info=exc_info) + else: + log.debug("Defuncting connection (%s) to %s: %s", + id(self), self.endpoint, exc) + + self.last_error = exc + self.close() + self.error_all_cp_sessions(exc) + self.error_all_requests(exc) + self.connected_event.set() + return exc + + def error_all_cp_sessions(self, exc): + stream_ids = list(self._continuous_paging_sessions.keys()) + for stream_id in stream_ids: + self._continuous_paging_sessions[stream_id].on_error(exc) + + def error_all_requests(self, exc): + with self.lock: + requests = self._requests + self._requests = {} - def send_msg(self, msg, cb): - raise NotImplementedError() + if not requests: + return - def wait_for_response(self, msg): - raise NotImplementedError() + new_exc = ConnectionShutdown(str(exc)) - def wait_for_responses(self, *msgs): - raise NotImplementedError() + def try_callback(cb): + try: + cb(new_exc) + except Exception: + log.warning("Ignoring unhandled exception while erroring requests for a " + "failed connection (%s) to host %s:", + id(self), self.endpoint, exc_info=True) - def register_watcher(self, event_type, callback): - raise NotImplementedError() + # run first callback from this thread to ensure pool state before leaving + cb, _, _ = requests.popitem()[1] + try_callback(cb) - def register_watchers(self, type_callback_dict): - raise NotImplementedError() + if not requests: + return - @defunct_on_error - def process_msg(self, msg, body_len): - version, flags, stream_id, opcode = map(int8_unpack, msg[:4]) - if stream_id < 0: - callback = None + # additional requests are optionally errored from a separate thread + # The default callback and retry logic is fairly expensive -- we don't + # want to tie up the event thread when there are many requests + def err_all_callbacks(): + for cb, _, _ in requests.values(): + try_callback(cb) + if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: + err_all_callbacks() else: - callback = self._callbacks.pop(stream_id, None) - self._id_queue.put_nowait(stream_id) + # daemon thread here because we want to stay decoupled from the cluster TPE + # TODO: would it make sense to just have a driver-global TPE? + t = Thread(target=err_all_callbacks) + t.daemon = True + t.start() + + def get_request_id(self): + """ + This must be called while self.lock is held. + """ + try: + return self.request_ids.popleft() + except IndexError: + new_request_id = self.highest_request_id + 1 + # in_flight checks should guarantee this + assert new_request_id <= self.max_request_id + self.highest_request_id = new_request_id + return self.highest_request_id + + def handle_pushed(self, response): + log.debug("Message pushed from server: %r", response) + for cb in self._push_watchers.get(response.event_type, []): + try: + cb(response.event_args) + except Exception: + log.exception("Pushed event handler errored, ignoring:") + + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): + if self.is_defunct: + raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint) + elif self.is_closed: + raise ConnectionShutdown("Connection to %s is closed" % self.endpoint) + elif not self._socket_writable: + raise ConnectionBusy("Connection %s is overloaded" % self.endpoint) + + # queue the decoder function with the request + # this allows us to inject custom functions per request to encode, decode messages + self._requests[request_id] = (cb, decoder, result_metadata) + msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version) + + if self._is_checksumming_enabled: + buffer = io.BytesIO() + self._segment_codec.encode(buffer, msg) + msg = buffer.getvalue() + + self.push(msg) + return len(msg) + + def wait_for_response(self, msg, timeout=None, **kwargs): + return self.wait_for_responses(msg, timeout=timeout, **kwargs)[0] + + def wait_for_responses(self, *msgs, **kwargs): + """ + Returns a list of (success, response) tuples. If success + is False, response will be an Exception. Otherwise, response + will be the normal query response. + + If fail_on_error was left as True and one of the requests + failed, the corresponding Exception will be raised. + """ + if self.is_closed or self.is_defunct: + raise ConnectionShutdown("Connection %s is already closed" % (self, )) + timeout = kwargs.get('timeout') + fail_on_error = kwargs.get('fail_on_error', True) + waiter = ResponseWaiter(self, len(msgs), fail_on_error) + + # busy wait for sufficient space on the connection + messages_sent = 0 + while True: + needed = len(msgs) - messages_sent + with self.lock: + available = min(needed, self.max_request_id - self.in_flight + 1) + request_ids = [self.get_request_id() for _ in range(available)] + self.in_flight += available + + for i, request_id in enumerate(request_ids): + self.send_msg(msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i)) + messages_sent += available + + if messages_sent == len(msgs): + break + else: + if timeout is not None: + timeout -= 0.01 + if timeout <= 0.0: + raise OperationTimedOut() + time.sleep(0.01) - body = None try: - # check that the protocol version is supported - given_version = version & PROTOCOL_VERSION_MASK - if given_version != PROTOCOL_VERSION: - raise ProtocolError("Unsupported CQL protocol version: %d" % given_version) + return waiter.deliver(timeout) + except OperationTimedOut: + raise + except Exception as exc: + self.defunct(exc) + raise + + def register_watcher(self, event_type, callback, register_timeout=None): + """ + Register a callback for a given event type. + """ + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + """ + Register multiple callback/event type pairs, expressed as a dict. + """ + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) + + def control_conn_disposed(self): + self.is_control_connection = False + self._push_watchers = {} - # check that the header direction is correct - if version & HEADER_DIRECTION_MASK != HEADER_DIRECTION_TO_CLIENT: - raise ProtocolError( - "Header direction in response is incorrect; opcode %04x, stream id %r" - % (opcode, stream_id)) + @defunct_on_error + def _read_frame_header(self): + buf = self._io_buffer.cql_frame_buffer.getvalue() + pos = len(buf) + if pos: + version = buf[0] & PROTOCOL_VERSION_MASK + if version not in ProtocolVersion.SUPPORTED_VERSIONS: + raise ProtocolError("This version of the driver does not support protocol version %d" % version) + frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2 + # this frame header struct is everything after the version byte + header_size = frame_header.size + 1 + if pos >= header_size: + flags, stream, op, body_len = frame_header.unpack_from(buf, 1) + if body_len < 0: + raise ProtocolError("Received negative body length: %r" % body_len) + self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) + return pos - if body_len > 0: - body = msg[8:] - elif body_len == 0: - body = "" + @defunct_on_error + def _process_segment_buffer(self): + readable_bytes = self._io_buffer.readable_io_bytes() + if readable_bytes >= self._segment_codec.header_length_with_crc: + try: + self._io_buffer.io_buffer.seek(0) + segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer) + + if readable_bytes >= segment_header.segment_length: + segment = self._segment_codec.decode(self._iobuf, segment_header) + self._io_buffer._segment_consumed = True + self._io_buffer.cql_frame_buffer.write(segment.payload) + else: + # not enough data to read the segment. reset the buffer pointer at the + # beginning to not lose what we previously read (header). + self._io_buffer._segment_consumed = False + self._io_buffer.io_buffer.seek(0) + except CrcException as exc: + # re-raise an exception that inherits from ConnectionException + raise CrcMismatchException(str(exc), self.endpoint) + else: + self._io_buffer._segment_consumed = False + + def process_io_buffer(self): + while True: + if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + self._process_segment_buffer() + self._io_buffer.reset_io_buffer() + + if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment: + # We couldn't read an entire segment from the io buffer, so return + # control to allow more bytes to be read off the wire + return + + if not self._current_frame: + pos = self._read_frame_header() + else: + pos = self._io_buffer.readable_cql_frame_bytes() + + if not self._current_frame or pos < self._current_frame.end_pos: + if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + # We have a multi-segments message, and we need to read more + # data to complete the current cql frame + continue + + # we don't have a complete header yet, or we + # already saw a header, but we don't have a + # complete message yet + return else: - raise ProtocolError("Got negative body length: %r" % body_len) + frame = self._current_frame + self._io_buffer.cql_frame_buffer.seek(frame.body_offset) + msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) + self.process_msg(frame, msg) + self._io_buffer.reset_cql_frame_buffer() + self._current_frame = None - response = decode_response(stream_id, flags, opcode, body, self.decompressor) + @defunct_on_error + def process_msg(self, header, body): + self.msg_received = True + stream_id = header.stream + if stream_id < 0: + callback = None + decoder = ProtocolHandler.decode_message + result_metadata = None + else: + if stream_id in self._continuous_paging_sessions: + paging_session = self._continuous_paging_sessions[stream_id] + callback = paging_session.on_message + decoder = paging_session.decoder + result_metadata = None + else: + need_notify_of_release = False + with self.lock: + if stream_id in self.orphaned_request_ids: + self.in_flight -= 1 + self.orphaned_request_ids.remove(stream_id) + need_notify_of_release = True + if need_notify_of_release and self._on_orphaned_stream_released: + self._on_orphaned_stream_released() + + try: + callback, decoder, result_metadata = self._requests.pop(stream_id) + # This can only happen if the stream_id was + # removed due to an OperationTimedOut + except KeyError: + with self.lock: + self.request_ids.append(stream_id) + return + + try: + response = decoder(header.version, self.user_type_map, stream_id, + header.flags, header.opcode, body, self.decompressor, result_metadata) except Exception as exc: log.exception("Error decoding response from Cassandra. " - "opcode: %04x; message contents: %r" % (opcode, body)) + "%s; buffer: %r", header, self._iobuf.getvalue()) if callback is not None: callback(exc) self.defunct(exc) return try: - if stream_id < 0: + if stream_id >= 0: + if isinstance(response, ProtocolException): + if 'unsupported protocol version' in response.message: + self.is_unsupported_proto_version = True + else: + log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) + self.defunct(response) + if callback is not None: + callback(response) + else: self.handle_pushed(response) - elif callback is not None: - callback(response) except Exception: log.exception("Callback handler errored, ignoring:") + # done after callback because the callback might signal this as a paging session + if stream_id >= 0: + if stream_id in self._continuous_paging_sessions: + if self._continuous_paging_sessions[stream_id].released: + self.remove_continuous_paging_session(stream_id) + else: + with self.lock: + self.request_ids.append(stream_id) + + def new_continuous_paging_session(self, stream_id, decoder, row_factory, state): + session = ContinuousPagingSession(stream_id, decoder, row_factory, self, state) + self._continuous_paging_sessions[stream_id] = session + return session + + def remove_continuous_paging_session(self, stream_id): + try: + self._continuous_paging_sessions.pop(stream_id) + with self.lock: + log.debug("Returning cp session stream id %s", stream_id) + self.request_ids.append(stream_id) + except KeyError: + pass + @defunct_on_error def _send_options_message(self): - log.debug("Sending initial options message for new Connection to %s", self.host) - self.send_msg(OptionsMessage(), self._handle_options_response) + log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint) + self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) @defunct_on_error def _handle_options_response(self, options_response): @@ -217,121 +1331,511 @@ def _handle_options_response(self, options_response): return if not isinstance(options_response, SupportedMessage): - log.error("Did not get expected SupportedMessage response; instead, got: %s", options_response) - return - - log.debug("Received options response on new Connection from %s" % self.host) - self.supported_cql_versions = options_response.cql_versions - self.remote_supported_compressions = options_response.options['COMPRESSION'] + if isinstance(options_response, ConnectionException): + raise options_response + else: + log.error("Did not get expected SupportedMessage response; " + "instead, got: %s", options_response) + raise ConnectionException("Did not get expected SupportedMessage " + "response; instead, got: %s" + % (options_response,)) + + log.debug("Received options response on new connection (%s) from %s", + id(self), self.endpoint) + supported_cql_versions = options_response.cql_versions + remote_supported_compressions = options_response.options['COMPRESSION'] + self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0] if self.cql_version: - if self.cql_version not in self.supported_cql_versions: + if self.cql_version not in supported_cql_versions: raise ProtocolError( "cql_version %r is not supported by remote (w/ native " "protocol). Supported versions: %r" - % (self.cql_version, self.supported_cql_versions)) + % (self.cql_version, supported_cql_versions)) else: - self.cql_version = self.supported_cql_versions[0] + self.cql_version = supported_cql_versions[0] - opts = {} self._compressor = None + compression_type = None if self.compression: overlap = (set(locally_supported_compressions.keys()) & - set(self.remote_supported_compressions)) + set(remote_supported_compressions)) if len(overlap) == 0: log.debug("No available compression types supported on both ends." - " locally supported: %r. remotely supported: %r" - % (locally_supported_compressions.keys(), - self.remote_supported_compressions)) + " locally supported: %r. remotely supported: %r", + locally_supported_compressions.keys(), + remote_supported_compressions) else: - compression_type = iter(overlap).next() # choose any - opts['COMPRESSION'] = compression_type - # set the decompressor here, but set the compressor only after - # a successful Ready message - self._compressor, self.decompressor = \ - locally_supported_compressions[compression_type] + compression_type = None + if isinstance(self.compression, str): + # the user picked a specific compression type ('snappy' or 'lz4') + if self.compression not in remote_supported_compressions: + raise ProtocolError( + "The requested compression type (%s) is not supported by the Cassandra server at %s" + % (self.compression, self.endpoint)) + compression_type = self.compression + else: + # our locally supported compressions are ordered to prefer + # lz4, if available + for k in locally_supported_compressions.keys(): + if k in overlap: + compression_type = k + break + + # If snappy compression is selected with v5+checksumming, the connection + # will fail with OTO. Only lz4 is supported + if (compression_type == 'snappy' and + ProtocolVersion.has_checksumming_support(self.protocol_version)): + log.debug("Snappy compression is not supported with protocol version %s and " + "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version) + compression_type = None + else: + # set the decompressor here, but set the compressor only after + # a successful Ready message + self._compression_type = compression_type + self._compressor, self.decompressor = \ + locally_supported_compressions[compression_type] + + self._send_startup_message(compression_type, no_compact=self.no_compact) + @defunct_on_error + def _send_startup_message(self, compression=None, no_compact=False): + log.debug("Sending StartupMessage on %s", self) + opts = {'DRIVER_NAME': DRIVER_NAME, + 'DRIVER_VERSION': DRIVER_VERSION} + if compression: + opts['COMPRESSION'] = compression + if no_compact: + opts['NO_COMPACT'] = 'true' sm = StartupMessage(cqlversion=self.cql_version, options=opts) - self.send_msg(sm, cb=self._handle_startup_response) + self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) + log.debug("Sent StartupMessage on %s", self) @defunct_on_error def _handle_startup_response(self, startup_response, did_authenticate=False): if self.is_defunct: return + if isinstance(startup_response, ReadyMessage): - log.debug("Got ReadyMessage on new Connection from %s" % self.host) - if self._compressor: - self.compressor = self._compressor - self.connected_event.set() - elif isinstance(startup_response, AuthenticateMessage): - log.debug("Got AuthenticateMessage on new Connection from %s" % self.host) + if self.authenticator: + log.warning("An authentication challenge was not sent, " + "this is suspicious because the driver expects " + "authentication (configured authenticator = %s)", + self.authenticator.__class__.__name__) + + log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) + self._enable_compression() - if self.credentials is None: - raise AuthenticationFailed('Remote end requires authentication.') + if ProtocolVersion.has_checksumming_support(self.protocol_version): + self._enable_checksumming() - self.authenticator = startup_response.authenticator - cm = CredentialsMessage(creds=self.credentials) - callback = partial(self._handle_startup_response, did_authenticate=True) - self.send_msg(cm, cb=callback) + self.connected_event.set() + elif isinstance(startup_response, AuthenticateMessage): + log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", + id(self), self.endpoint, startup_response.authenticator) + + if self.authenticator is None: + log.error("Failed to authenticate to %s. If you are trying to connect to a DSE cluster, " + "consider using TransitionalModePlainTextAuthProvider " + "if DSE authentication is configured with transitional mode" % (self.host,)) + raise AuthenticationFailed('Remote end requires authentication') + + self._enable_compression() + if ProtocolVersion.has_checksumming_support(self.protocol_version): + self._enable_checksumming() + + if isinstance(self.authenticator, dict): + log.debug("Sending credentials-based auth response on %s", self) + cm = CredentialsMessage(creds=self.authenticator) + callback = partial(self._handle_startup_response, did_authenticate=True) + self.send_msg(cm, self.get_request_id(), cb=callback) + else: + log.debug("Sending SASL-based auth response on %s", self) + self.authenticator.server_authenticator_class = startup_response.authenticator + initial_response = self.authenticator.initial_response() + initial_response = "" if initial_response is None else initial_response + self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), + self._handle_auth_response) elif isinstance(startup_response, ErrorMessage): - log.debug("Received ErrorMessage on new Connection from %s: %s" - % (self.host, startup_response.summary_msg())) + log.debug("Received ErrorMessage on new connection (%s) from %s: %s", + id(self), self.endpoint, startup_response.summary_msg()) if did_authenticate: raise AuthenticationFailed( "Failed to authenticate to %s: %s" % - (self.host, startup_response.summary_msg())) + (self.endpoint, startup_response.summary_msg())) else: raise ConnectionException( "Failed to initialize new connection to %s: %s" - % (self.host, startup_response.summary_msg())) + % (self.endpoint, startup_response.summary_msg())) + elif isinstance(startup_response, ConnectionShutdown): + log.debug("Connection to %s was closed during the startup handshake", (self.endpoint)) + raise startup_response else: - msg = "Unexpected response during Connection setup: %r" % (startup_response,) - log.error(msg) - raise ProtocolError(msg) + msg = "Unexpected response during Connection setup: %r" + log.error(msg, startup_response) + raise ProtocolError(msg % (startup_response,)) + + @defunct_on_error + def _handle_auth_response(self, auth_response): + if self.is_defunct: + return - def set_keyspace(self, keyspace): + if isinstance(auth_response, AuthSuccessMessage): + log.debug("Connection %s successfully authenticated", self) + self.authenticator.on_authentication_success(auth_response.token) + if self._compressor: + self.compressor = self._compressor + self.connected_event.set() + elif isinstance(auth_response, AuthChallengeMessage): + response = self.authenticator.evaluate_challenge(auth_response.challenge) + msg = AuthResponseMessage("" if response is None else response) + log.debug("Responding to auth challenge on %s", self) + self.send_msg(msg, self.get_request_id(), self._handle_auth_response) + elif isinstance(auth_response, ErrorMessage): + log.debug("Received ErrorMessage on new connection (%s) from %s: %s", + id(self), self.endpoint, auth_response.summary_msg()) + raise AuthenticationFailed( + "Failed to authenticate to %s: %s" % + (self.endpoint, auth_response.summary_msg())) + elif isinstance(auth_response, ConnectionShutdown): + log.debug("Connection to %s was closed during the authentication process", self.endpoint) + raise auth_response + else: + msg = "Unexpected response during Connection authentication to %s: %r" + log.error(msg, self.endpoint, auth_response) + raise ProtocolError(msg % (self.endpoint, auth_response)) + + def set_keyspace_blocking(self, keyspace): if not keyspace or keyspace == self.keyspace: return - with self.lock: - query = 'USE "%s"' % (keyspace,) - try: - result = self.wait_for_response( - QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)) - if isinstance(result, ResultMessage): - self.keyspace = keyspace - else: - raise self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.host)) - except InvalidRequestException as ire: - # the keyspace probably doesn't exist - raise ire.to_exception() - except Exception as exc: - raise self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.host)) + query = QueryMessage(query='USE "%s"' % (keyspace,), + consistency_level=ConsistencyLevel.ONE) + try: + result = self.wait_for_response(query) + except InvalidRequestException as ire: + # the keyspace probably doesn't exist + raise ire.to_exception() + except Exception as exc: + conn_exc = ConnectionException( + "Problem while setting keyspace: %r" % (exc,), self.endpoint) + self.defunct(conn_exc) + raise conn_exc + + if isinstance(result, ResultMessage): + self.keyspace = keyspace + else: + conn_exc = ConnectionException( + "Problem while setting keyspace: %r" % (result,), self.endpoint) + self.defunct(conn_exc) + raise conn_exc + + def set_keyspace_async(self, keyspace, callback): + """ + Use this in order to avoid deadlocking the event loop thread. + When the operation completes, `callback` will be called with + two arguments: this connection and an Exception if an error + occurred, otherwise :const:`None`. + + This method will always increment :attr:`.in_flight` attribute, even if + it doesn't need to make a request, just to maintain an + ":attr:`.in_flight` is incremented" invariant. + """ + # Here we increment in_flight unconditionally, whether we need to issue + # a request or not. This is bad, but allows callers -- specifically + # _set_keyspace_for_all_conns -- to assume that we increment + # self.in_flight during this call. This allows the passed callback to + # safely call HostConnection{Pool,}.return_connection on this + # Connection. + # + # We use a busy wait on the lock here because: + # - we'll only spin if the connection is at max capacity, which is very + # unlikely for a set_keyspace call + # - it allows us to avoid signaling a condition every time a request completes + while True: + with self.lock: + if self.in_flight < self.max_request_id: + self.in_flight += 1 + break + time.sleep(0.001) + + if not keyspace or keyspace == self.keyspace: + callback(self, None) + return + + query = QueryMessage(query='USE "%s"' % (keyspace,), + consistency_level=ConsistencyLevel.ONE) + + def process_result(result): + if isinstance(result, ResultMessage): + self.keyspace = keyspace + callback(self, None) + elif isinstance(result, InvalidRequestException): + callback(self, result.to_exception()) + else: + callback(self, self.defunct(ConnectionException( + "Problem while setting keyspace: %r" % (result,), self.endpoint))) + + # We've incremented self.in_flight above, so we "have permission" to + # acquire a new request id + request_id = self.get_request_id() + + self.send_msg(query, request_id, process_result) + + @property + def is_idle(self): + return not self.msg_received + + def reset_idle(self): + self.msg_received = False + + def __str__(self): + status = "" + if self.is_defunct: + status = " (defunct)" + elif self.is_closed: + status = " (closed)" + + return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status) + __repr__ = __str__ class ResponseWaiter(object): - def __init__(self, num_responses): + def __init__(self, connection, num_responses, fail_on_error): + self.connection = connection self.pending = num_responses + self.fail_on_error = fail_on_error self.error = None self.responses = [None] * num_responses self.event = Event() def got_response(self, response, index): + with self.connection.lock: + self.connection.in_flight -= 1 if isinstance(response, Exception): - self.error = response - self.event.set() - else: - self.responses[index] = response - self.pending -= 1 - if not self.pending: + if hasattr(response, 'to_exception'): + response = response.to_exception() + if self.fail_on_error: + self.error = response self.event.set() + else: + self.responses[index] = (False, response) + else: + if not self.fail_on_error: + self.responses[index] = (True, response) + else: + self.responses[index] = response - def deliver(self): - self.event.wait() + self.pending -= 1 + if not self.pending: + self.event.set() + + def deliver(self, timeout=None): + """ + If fail_on_error was set to False, a list of (success, response) + tuples will be returned. If success is False, response will be + an Exception. Otherwise, response will be the normal query response. + + If fail_on_error was left as True and one of the requests + failed, the corresponding Exception will be raised. Otherwise, + the normal response will be returned. + """ + self.event.wait(timeout) if self.error: raise self.error + elif not self.event.is_set(): + raise OperationTimedOut() else: return self.responses + + +class HeartbeatFuture(object): + def __init__(self, connection, owner): + self._exception = None + self._event = Event() + self.connection = connection + self.owner = owner + log.debug("Sending options message heartbeat on idle connection (%s) %s", + id(connection), connection.endpoint) + with connection.lock: + if connection.in_flight < connection.max_request_id: + connection.in_flight += 1 + connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + else: + self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") + self._event.set() + + def wait(self, timeout): + self._event.wait(timeout) + if self._event.is_set(): + if self._exception: + raise self._exception + else: + raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint) + + def _options_callback(self, response): + if isinstance(response, SupportedMessage): + log.debug("Received options response on connection (%s) from %s", + id(self.connection), self.connection.endpoint) + else: + if isinstance(response, ConnectionException): + self._exception = response + else: + self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" + % (response,)) + self._event.set() + + +class ConnectionHeartbeat(Thread): + + def __init__(self, interval_sec, get_connection_holders, timeout): + Thread.__init__(self, name="Connection heartbeat") + self._interval = interval_sec + self._timeout = timeout + self._get_connection_holders = get_connection_holders + self._shutdown_event = Event() + self.daemon = True + self.start() + + class ShutdownException(Exception): + pass + + def run(self): + self._shutdown_event.wait(self._interval) + while not self._shutdown_event.is_set(): + start_time = time.time() + + futures = [] + failed_connections = [] + try: + for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]: + for connection in connections: + self._raise_if_stopped() + if not (connection.is_defunct or connection.is_closed): + if connection.is_idle: + try: + futures.append(HeartbeatFuture(connection, owner)) + except Exception as e: + log.warning("Failed sending heartbeat message on connection (%s) to %s", + id(connection), connection.endpoint) + failed_connections.append((connection, owner, e)) + else: + connection.reset_idle() + else: + log.debug("Cannot send heartbeat message on connection (%s) to %s", + id(connection), connection.endpoint) + # make sure the owner sees this defunct/closed connection + owner.return_connection(connection) + self._raise_if_stopped() + + # Wait max `self._timeout` seconds for all HeartbeatFutures to complete + timeout = self._timeout + start_time = time.time() + for f in futures: + self._raise_if_stopped() + connection = f.connection + try: + f.wait(timeout) + # TODO: move this, along with connection locks in pool, down into Connection + with connection.lock: + connection.in_flight -= 1 + connection.reset_idle() + except Exception as e: + log.warning("Heartbeat failed for connection (%s) to %s", + id(connection), connection.endpoint) + failed_connections.append((f.connection, f.owner, e)) + + timeout = self._timeout - (time.time() - start_time) + + for connection, owner, exc in failed_connections: + self._raise_if_stopped() + if not connection.is_control_connection: + # Only HostConnection supports shutdown_on_error + owner.shutdown_on_error = True + connection.defunct(exc) + owner.return_connection(connection) + except self.ShutdownException: + pass + except Exception: + log.error("Failed connection heartbeat", exc_info=True) + + elapsed = time.time() - start_time + self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) + + def stop(self): + self._shutdown_event.set() + self.join() + + def _raise_if_stopped(self): + if self._shutdown_event.is_set(): + raise self.ShutdownException() + + +class Timer(object): + + canceled = False + + def __init__(self, timeout, callback): + self.end = time.time() + timeout + self.callback = callback + + def __lt__(self, other): + return self.end < other.end + + def cancel(self): + self.canceled = True + + def finish(self, time_now): + if self.canceled: + return True + + if time_now >= self.end: + self.callback() + return True + + return False + + +class TimerManager(object): + + def __init__(self): + self._queue = [] + self._new_timers = [] + + def add_timer(self, timer): + """ + called from client thread with a Timer object + """ + self._new_timers.append((timer.end, timer)) + + def service_timeouts(self): + """ + run callbacks on all expired timers + Called from the event thread + :return: next end time, or None + """ + queue = self._queue + if self._new_timers: + new_timers = self._new_timers + while new_timers: + heappush(queue, new_timers.pop()) + + if queue: + now = time.time() + while queue: + try: + timer = queue[0][1] + if timer.finish(now): + heappop(queue) + else: + return timer.end + except Exception: + log.exception("Exception while servicing timeout callback: ") + + @property + def next_timeout(self): + try: + return self._queue[0][0] + except IndexError: + pass diff --git a/cassandra/cqlengine/__init__.py b/cassandra/cqlengine/__init__.py new file mode 100644 index 0000000000..200d04b831 --- /dev/null +++ b/cassandra/cqlengine/__init__.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Caching constants. +CACHING_ALL = "ALL" +CACHING_KEYS_ONLY = "KEYS_ONLY" +CACHING_ROWS_ONLY = "ROWS_ONLY" +CACHING_NONE = "NONE" + + +class CQLEngineException(Exception): + pass + + +class ValidationError(CQLEngineException): + pass + + +class UnicodeMixin(object): + __str__ = lambda x: x.__unicode__() diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py new file mode 100644 index 0000000000..7d50687d95 --- /dev/null +++ b/cassandra/cqlengine/columns.py @@ -0,0 +1,1080 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy, copy +from datetime import date, datetime, timedelta, timezone +import logging +from uuid import UUID as _UUID + +from cassandra import util +from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType +from cassandra.cqlengine import ValidationError +from cassandra.cqlengine.functions import get_total_seconds +from cassandra.util import Duration as _Duration + +log = logging.getLogger(__name__) + + +class BaseValueManager(object): + + def __init__(self, instance, column, value): + self.instance = instance + self.column = column + self.value = value + self.previous_value = None + self.explicit = False + + @property + def deleted(self): + return self.column._val_is_null(self.value) and (self.explicit or not self.column._val_is_null(self.previous_value)) + + @property + def changed(self): + """ + Indicates whether or not this value has changed. + + :rtype: boolean + + """ + if self.explicit: + return self.value != self.previous_value + + if isinstance(self.column, BaseContainerColumn): + default_value = self.column.get_default() + if self.column._val_is_null(default_value): + return not self.column._val_is_null(self.value) and self.value != self.previous_value + elif self.previous_value is None: + return self.value != default_value + + return self.value != self.previous_value + + return False + + def reset_previous_value(self): + self.previous_value = deepcopy(self.value) + + def getval(self): + return self.value + + def setval(self, val): + self.value = val + self.explicit = True + + def delval(self): + self.value = None + + def get_property(self): + _get = lambda slf: self.getval() + _set = lambda slf, val: self.setval(val) + _del = lambda slf: self.delval() + + if self.column.can_delete: + return property(_get, _set, _del) + else: + return property(_get, _set) + + +class Column(object): + + # the cassandra type this column maps to + db_type = None + value_manager = BaseValueManager + + instance_counter = 0 + + _python_type_hashable = True + + primary_key = False + """ + bool flag, indicates this column is a primary key. The first primary key defined + on a model is the partition key (unless partition keys are set), all others are cluster keys + """ + + partition_key = False + + """ + indicates that this column should be the partition key, defining + more than one partition key column creates a compound partition key + """ + + index = False + """ + bool flag, indicates an index should be created for this column + """ + + custom_index = False + """ + bool flag, indicates an index is managed outside of cqlengine. This is + useful if you want to do filter queries on fields that have custom + indexes. + """ + + db_field = None + """ + the fieldname this field will map to in the database + """ + + default = None + """ + the default value, can be a value or a callable (no args) + """ + + required = False + """ + boolean, is the field required? Model validation will raise and + exception if required is set to True and there is a None value assigned + """ + + clustering_order = None + """ + only applicable on clustering keys (primary keys that are not partition keys) + determines the order that the clustering keys are sorted on disk + """ + + discriminator_column = False + """ + boolean, if set to True, this column will be used for discriminating records + of inherited models. + + Should only be set on a column of an abstract model being used for inheritance. + + There may only be one discriminator column per model. See :attr:`~.__discriminator_value__` + for how to specify the value of this column on specialized models. + """ + + static = False + """ + boolean, if set to True, this is a static column, with a single value per partition + """ + + def __init__(self, + primary_key=False, + partition_key=False, + index=False, + db_field=None, + default=None, + required=False, + clustering_order=None, + discriminator_column=False, + static=False, + custom_index=False): + self.partition_key = partition_key + self.primary_key = partition_key or primary_key + self.index = index + self.custom_index = custom_index + self.db_field = db_field + self.default = default + self.required = required + self.clustering_order = clustering_order + self.discriminator_column = discriminator_column + + # the column name in the model definition + self.column_name = None + self._partition_key_index = None + self.static = static + + self.value = None + + # keep track of instantiation order + self.position = Column.instance_counter + Column.instance_counter += 1 + + def __ne__(self, other): + if isinstance(other, Column): + return self.position != other.position + return NotImplemented + + def __eq__(self, other): + if isinstance(other, Column): + return self.position == other.position + return NotImplemented + + def __lt__(self, other): + if isinstance(other, Column): + return self.position < other.position + return NotImplemented + + def __le__(self, other): + if isinstance(other, Column): + return self.position <= other.position + return NotImplemented + + def __gt__(self, other): + if isinstance(other, Column): + return self.position > other.position + return NotImplemented + + def __ge__(self, other): + if isinstance(other, Column): + return self.position >= other.position + return NotImplemented + + def __hash__(self): + return id(self) + + def validate(self, value): + """ + Returns a cleaned and validated value. Raises a ValidationError + if there's a problem + """ + if value is None: + if self.required: + raise ValidationError('{0} - None values are not allowed'.format(self.column_name or self.db_field)) + return value + + def to_python(self, value): + """ + Converts data from the database into python values + raises a ValidationError if the value can't be converted + """ + return value + + def to_database(self, value): + """ + Converts python value into database value + """ + return value + + @property + def has_default(self): + return self.default is not None + + @property + def is_primary_key(self): + return self.primary_key + + @property + def can_delete(self): + return not self.primary_key + + def get_default(self): + if self.has_default: + if callable(self.default): + return self.default() + else: + return self.default + + def get_column_def(self): + """ + Returns a column definition for CQL table definition + """ + static = "static" if self.static else "" + return '{0} {1} {2}'.format(self.cql, self.db_type, static) + + # TODO: make columns use cqltypes under the hood + # until then, this bridges the gap in using types along with cassandra.metadata for CQL generation + def cql_parameterized_type(self): + return self.db_type + + def set_column_name(self, name): + """ + Sets the column name during document class construction + This value will be ignored if db_field is set in __init__ + """ + self.column_name = name + + @property + def db_field_name(self): + """ Returns the name of the cql name of this column """ + return self.db_field if self.db_field is not None else self.column_name + + @property + def db_index_name(self): + """ Returns the name of the cql index """ + return 'index_{0}'.format(self.db_field_name) + + @property + def has_index(self): + return self.index or self.custom_index + + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{0}"'.format(self.db_field_name) + + def _val_is_null(self, val): + """ determines if the given value equates to a null value for the given column type """ + return val is None + + @property + def sub_types(self): + return [] + + @property + def cql_type(self): + return _cqltypes[self.db_type] + + +class Blob(Column): + """ + Stores a raw binary value + """ + db_type = 'blob' + + def to_database(self, value): + + if not isinstance(value, (bytes, bytearray)): + raise Exception("expecting a binary, got a %s" % type(value)) + + val = super(Bytes, self).to_database(value) + return bytearray(val) + + +Bytes = Blob + + +class Inet(Column): + """ + Stores an IP address in IPv4 or IPv6 format + """ + db_type = 'inet' + + +class Text(Column): + """ + Stores a UTF-8 encoded string + """ + db_type = 'text' + + def __init__(self, min_length=None, max_length=None, **kwargs): + """ + :param int min_length: Sets the minimum length of this string, for validation purposes. + Defaults to 1 if this is a ``required`` column. Otherwise, None. + :param int max_length: Sets the maximum length of this string, for validation purposes. + """ + self.min_length = ( + 1 if min_length is None and kwargs.get('required', False) + else min_length) + self.max_length = max_length + + if self.min_length is not None: + if self.min_length < 0: + raise ValueError( + 'Minimum length is not allowed to be negative.') + + if self.max_length is not None: + if self.max_length < 0: + raise ValueError( + 'Maximum length is not allowed to be negative.') + + if self.min_length is not None and self.max_length is not None: + if self.max_length < self.min_length: + raise ValueError( + 'Maximum length must be greater or equal ' + 'to minimum length.') + + super(Text, self).__init__(**kwargs) + + def validate(self, value): + value = super(Text, self).validate(value) + if not isinstance(value, (str, bytearray)) and value is not None: + raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) + if self.max_length is not None: + if value and len(value) > self.max_length: + raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) + if self.min_length: + if (self.min_length and not value) or len(value) < self.min_length: + raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) + return value + + +class Ascii(Text): + """ + Stores a US-ASCII character string + """ + db_type = 'ascii' + + def validate(self, value): + """ Only allow ASCII and None values. + + Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. + the Basic Latin block of the Unicode character set. + + Source: https://github.com/apache/cassandra/blob + /3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra + /serializers/AsciiSerializer.java#L29 + """ + value = super(Ascii, self).validate(value) + if value: + charset = value if isinstance( + value, (bytearray, )) else map(ord, value) + if not set(range(128)).issuperset(charset): + raise ValidationError( + '{!r} is not an ASCII string.'.format(value)) + return value + + +class Integer(Column): + """ + Stores a 32-bit signed integer value + """ + + db_type = 'int' + + def validate(self, value): + val = super(Integer, self).validate(value) + if val is None: + return + try: + return int(val) + except (TypeError, ValueError): + raise ValidationError("{0} {1} can't be converted to integral value".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class TinyInt(Integer): + """ + Stores an 8-bit signed integer value + + .. versionadded:: 2.6.0 + + requires C* 2.2+ and protocol v4+ + """ + db_type = 'tinyint' + + +class SmallInt(Integer): + """ + Stores a 16-bit signed integer value + + .. versionadded:: 2.6.0 + + requires C* 2.2+ and protocol v4+ + """ + db_type = 'smallint' + + +class BigInt(Integer): + """ + Stores a 64-bit signed integer value + """ + db_type = 'bigint' + + +class VarInt(Column): + """ + Stores an arbitrary-precision integer + """ + db_type = 'varint' + + def validate(self, value): + val = super(VarInt, self).validate(value) + if val is None: + return + try: + return int(val) + except (TypeError, ValueError): + raise ValidationError( + "{0} {1} can't be converted to integral value".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class CounterValueManager(BaseValueManager): + def __init__(self, instance, column, value): + super(CounterValueManager, self).__init__(instance, column, value) + self.value = self.value or 0 + self.previous_value = self.previous_value or 0 + + +class Counter(Integer): + """ + Stores a counter that can be incremented and decremented + """ + db_type = 'counter' + + value_manager = CounterValueManager + + def __init__(self, + index=False, + db_field=None, + required=False): + super(Counter, self).__init__( + primary_key=False, + partition_key=False, + index=index, + db_field=db_field, + default=0, + required=required, + ) + + +class DateTime(Column): + """ + Stores a datetime value + """ + db_type = 'timestamp' + + truncate_microseconds = False + """ + Set this ``True`` to have model instances truncate the date, quantizing it in the same way it will be in the database. + This allows equality comparison between assigned values and values read back from the database:: + + DateTime.truncate_microseconds = True + assert Model.create(id=0, d=datetime.utcnow()) == Model.objects(id=0).first() + + Defaults to ``False`` to preserve legacy behavior. May change in the future. + """ + + def to_python(self, value): + if value is None: + return + if isinstance(value, datetime): + if DateTime.truncate_microseconds: + us = value.microsecond + truncated_us = us // 1000 * 1000 + return value - timedelta(microseconds=us - truncated_us) + else: + return value + elif isinstance(value, date): + return datetime(*(value.timetuple()[:6])) + + return datetime.fromtimestamp(value, tz=timezone.utc).replace(tzinfo=None) + + def to_database(self, value): + value = super(DateTime, self).to_database(value) + if value is None: + return + if not isinstance(value, datetime): + if isinstance(value, date): + value = datetime(value.year, value.month, value.day) + else: + raise ValidationError("{0} '{1}' is not a datetime object".format(self.column_name, value)) + epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) + offset = get_total_seconds(epoch.tzinfo.utcoffset(epoch)) if epoch.tzinfo else 0 + + return int((get_total_seconds(value - epoch) - offset) * 1000) + + +class Date(Column): + """ + Stores a simple date, with no time-of-day + + .. versionchanged:: 2.6.0 + + removed overload of Date and DateTime. DateTime is a drop-in replacement for legacy models + + requires C* 2.2+ and protocol v4+ + """ + db_type = 'date' + + def to_database(self, value): + if value is None: + return + + # need to translate to int version because some dates are not representable in + # string form (datetime limitation) + d = value if isinstance(value, util.Date) else util.Date(value) + return d.days_from_epoch + SimpleDateType.EPOCH_OFFSET_DAYS + + def to_python(self, value): + if value is None: + return + if isinstance(value, util.Date): + return value + if isinstance(value, datetime): + value = value.date() + return util.Date(value) + +class Time(Column): + """ + Stores a timezone-naive time-of-day, with nanosecond precision + + .. versionadded:: 2.6.0 + + requires C* 2.2+ and protocol v4+ + """ + db_type = 'time' + + def to_database(self, value): + value = super(Time, self).to_database(value) + if value is None: + return + # str(util.Time) yields desired CQL encoding + return value if isinstance(value, util.Time) else util.Time(value) + + def to_python(self, value): + value = super(Time, self).to_database(value) + if value is None: + return + if isinstance(value, util.Time): + return value + return util.Time(value) + +class Duration(Column): + """ + Stores a duration (months, days, nanoseconds) + + .. versionadded:: 3.10.0 + + requires C* 3.10+ and protocol v4+ + """ + db_type = 'duration' + + def validate(self, value): + val = super(Duration, self).validate(value) + if val is None: + return + if not isinstance(val, _Duration): + raise TypeError('{0} {1} is not a valid Duration.'.format(self.column_name, value)) + return val + + +class UUID(Column): + """ + Stores a type 1 or 4 UUID + """ + db_type = 'uuid' + + def validate(self, value): + val = super(UUID, self).validate(value) + if val is None: + return + if isinstance(val, _UUID): + return val + if isinstance(val, str): + try: + return _UUID(val) + except ValueError: + # fall-through to error + pass + raise ValidationError("{0} {1} is not a valid uuid".format( + self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class TimeUUID(UUID): + """ + UUID containing timestamp + """ + + db_type = 'timeuuid' + + +class Boolean(Column): + """ + Stores a boolean True or False value + """ + db_type = 'boolean' + + def validate(self, value): + """ Always returns a Python boolean. """ + value = super(Boolean, self).validate(value) + + if value is not None: + value = bool(value) + + return value + + def to_python(self, value): + return self.validate(value) + + +class BaseFloat(Column): + def validate(self, value): + value = super(BaseFloat, self).validate(value) + if value is None: + return + try: + return float(value) + except (TypeError, ValueError): + raise ValidationError("{0} {1} is not a valid float".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class Float(BaseFloat): + """ + Stores a single-precision floating-point value + """ + db_type = 'float' + + +class Double(BaseFloat): + """ + Stores a double-precision floating-point value + """ + db_type = 'double' + + +class Decimal(Column): + """ + Stores a variable precision decimal value + """ + db_type = 'decimal' + + def validate(self, value): + from decimal import Decimal as _Decimal + from decimal import InvalidOperation + val = super(Decimal, self).validate(value) + if val is None: + return + try: + return _Decimal(repr(val)) if isinstance(val, float) else _Decimal(val) + except InvalidOperation: + raise ValidationError("{0} '{1}' can't be coerced to decimal".format(self.column_name, val)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class BaseCollectionColumn(Column): + """ + Base Container type for collection-like columns. + + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections + """ + def __init__(self, types, **kwargs): + """ + :param types: a sequence of sub types in this collection + """ + instances = [] + for t in types: + inheritance_comparator = issubclass if isinstance(t, type) else isinstance + if not inheritance_comparator(t, Column): + raise ValidationError("%s is not a column class" % (t,)) + if t.db_type is None: + raise ValidationError("%s is an abstract type" % (t,)) + inst = t() if isinstance(t, type) else t + if isinstance(t, BaseCollectionColumn): + inst._freeze_db_type() + instances.append(inst) + + self.types = instances + super(BaseCollectionColumn, self).__init__(**kwargs) + + def validate(self, value): + value = super(BaseCollectionColumn, self).validate(value) + # It is dangerous to let collections have more than 65535. + # See: https://issues.apache.org/jira/browse/CASSANDRA-5428 + if value is not None and len(value) > 65535: + raise ValidationError("{0} Collection can't have more than 65535 elements.".format(self.column_name)) + return value + + def _val_is_null(self, val): + return not val + + def _freeze_db_type(self): + if not self.db_type.startswith('frozen'): + self.db_type = "frozen<%s>" % (self.db_type,) + + @property + def sub_types(self): + return self.types + + @property + def cql_type(self): + return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) + + +class Tuple(BaseCollectionColumn): + """ + Stores a fixed-length set of positional values + + http://docs.datastax.com/en/cql/3.1/cql/cql_reference/tupleType.html + """ + def __init__(self, *args, **kwargs): + """ + :param args: column types representing tuple composition + """ + if not args: + raise ValueError("Tuple must specify at least one inner type") + super(Tuple, self).__init__(args, **kwargs) + self.db_type = 'tuple<{0}>'.format(', '.join(typ.db_type for typ in self.types)) + + def validate(self, value): + val = super(Tuple, self).validate(value) + if val is None: + return + if len(val) > len(self.types): + raise ValidationError("Value %r has more fields than tuple definition (%s)" % + (val, ', '.join(t for t in self.types))) + return tuple(t.validate(v) for t, v in zip(self.types, val)) + + def to_python(self, value): + if value is None: + return tuple() + return tuple(t.to_python(v) for t, v in zip(self.types, value)) + + def to_database(self, value): + if value is None: + return + return tuple(t.to_database(v) for t, v in zip(self.types, value)) + + +class BaseContainerColumn(BaseCollectionColumn): + pass + + +class Set(BaseContainerColumn): + """ + Stores a set of unordered, unique values + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html + """ + + _python_type_hashable = False + + def __init__(self, value_type, strict=True, default=set, **kwargs): + """ + :param value_type: a column class indicating the types of the value + :param strict: sets whether non set values will be coerced to set + type on validation, or raise a validation error, defaults to True + """ + self.strict = strict + super(Set, self).__init__((value_type,), default=default, **kwargs) + self.value_col = self.types[0] + if not self.value_col._python_type_hashable: + raise ValidationError("Cannot create a Set with unhashable value type (see PYTHON-494)") + self.db_type = 'set<{0}>'.format(self.value_col.db_type) + + def validate(self, value): + val = super(Set, self).validate(value) + if val is None: + return + types = (set, util.SortedSet) if self.strict else (set, util.SortedSet, list, tuple) + if not isinstance(val, types): + if self.strict: + raise ValidationError('{0} {1} is not a set object'.format(self.column_name, val)) + else: + raise ValidationError('{0} {1} cannot be coerced to a set object'.format(self.column_name, val)) + + if None in val: + raise ValidationError("{0} None not allowed in a set".format(self.column_name)) + # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) + # will need to start using the cassandra.util types in the next major rev (PYTHON-494) + return set(self.value_col.validate(v) for v in val) + + def to_python(self, value): + if value is None: + return set() + return set(self.value_col.to_python(v) for v in value) + + def to_database(self, value): + if value is None: + return None + return set(self.value_col.to_database(v) for v in value) + + +class List(BaseContainerColumn): + """ + Stores a list of ordered values + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html + """ + + _python_type_hashable = False + + def __init__(self, value_type, default=list, **kwargs): + """ + :param value_type: a column class indicating the types of the value + """ + super(List, self).__init__((value_type,), default=default, **kwargs) + self.value_col = self.types[0] + self.db_type = 'list<{0}>'.format(self.value_col.db_type) + + def validate(self, value): + val = super(List, self).validate(value) + if val is None: + return + if not isinstance(val, (set, list, tuple)): + raise ValidationError('{0} {1} is not a list object'.format(self.column_name, val)) + if None in val: + raise ValidationError("{0} None is not allowed in a list".format(self.column_name)) + return [self.value_col.validate(v) for v in val] + + def to_python(self, value): + if value is None: + return [] + return [self.value_col.to_python(v) for v in value] + + def to_database(self, value): + if value is None: + return None + return [self.value_col.to_database(v) for v in value] + + +class Map(BaseContainerColumn): + """ + Stores a key -> value map (dictionary) + + https://docs.datastax.com/en/dse/6.7/cql/cql/cql_using/useMap.html + """ + + _python_type_hashable = False + + def __init__(self, key_type, value_type, default=dict, **kwargs): + """ + :param key_type: a column class indicating the types of the key + :param value_type: a column class indicating the types of the value + """ + super(Map, self).__init__((key_type, value_type), default=default, **kwargs) + self.key_col = self.types[0] + self.value_col = self.types[1] + + if not self.key_col._python_type_hashable: + raise ValidationError("Cannot create a Map with unhashable key type (see PYTHON-494)") + + self.db_type = 'map<{0}, {1}>'.format(self.key_col.db_type, self.value_col.db_type) + + def validate(self, value): + val = super(Map, self).validate(value) + if val is None: + return + if not isinstance(val, (dict, util.OrderedMap)): + raise ValidationError('{0} {1} is not a dict object'.format(self.column_name, val)) + if None in val: + raise ValidationError("{0} None is not allowed in a map".format(self.column_name)) + # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) + # will need to start using the cassandra.util types in the next major rev (PYTHON-494) + return dict((self.key_col.validate(k), self.value_col.validate(v)) for k, v in val.items()) + + def to_python(self, value): + if value is None: + return {} + if value is not None: + return dict((self.key_col.to_python(k), self.value_col.to_python(v)) for k, v in value.items()) + + def to_database(self, value): + if value is None: + return None + return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items()) + + +class UDTValueManager(BaseValueManager): + @property + def changed(self): + if self.explicit: + return self.value != self.previous_value + + default_value = self.column.get_default() + if not self.column._val_is_null(default_value): + return self.value != default_value + elif self.previous_value is None: + return not self.column._val_is_null(self.value) and self.value.has_changed_fields() + + return False + + def reset_previous_value(self): + if self.value is not None: + self.value.reset_changed_fields() + self.previous_value = copy(self.value) + + +class UserDefinedType(Column): + """ + User Defined Type column + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/cqlUseUDT.html + + These columns are represented by a specialization of :class:`cassandra.cqlengine.usertype.UserType`. + + Please see :ref:`user_types` for examples and discussion. + """ + + value_manager = UDTValueManager + + def __init__(self, user_type, **kwargs): + """ + :param type user_type: specifies the :class:`~.cqlengine.usertype.UserType` model of the column + """ + self.user_type = user_type + self.db_type = "frozen<%s>" % user_type.type_name() + super(UserDefinedType, self).__init__(**kwargs) + + @property + def sub_types(self): + return list(self.user_type._fields.values()) + + @property + def cql_type(self): + return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(), + field_names=[c.db_field_name for c in self.user_type._fields.values()], + field_types=[c.cql_type for c in self.user_type._fields.values()]) + + def validate(self, value): + val = super(UserDefinedType, self).validate(value) + if val is None: + return + val.validate() + return val + + def to_python(self, value): + if value is None: + return + + copied_value = deepcopy(value) + for name, field in self.user_type._fields.items(): + if copied_value[name] is not None or isinstance(field, BaseContainerColumn): + copied_value[name] = field.to_python(copied_value[name]) + + return copied_value + + def to_database(self, value): + if value is None: + return + + copied_value = deepcopy(value) + for name, field in self.user_type._fields.items(): + if copied_value[name] is not None or isinstance(field, BaseContainerColumn): + copied_value[name] = field.to_database(copied_value[name]) + + return copied_value + + +def resolve_udts(col_def, out_list): + for col in col_def.sub_types: + resolve_udts(col, out_list) + if isinstance(col_def, UserDefinedType): + out_list.append(col_def.user_type) + + +class _PartitionKeysToken(Column): + """ + virtual column representing token of partition columns. + Used by filter(pk__token=Token(...)) filters + """ + + def __init__(self, model): + self.partition_columns = list(model._partition_keys.values()) + super(_PartitionKeysToken, self).__init__(partition_key=True) + + @property + def db_field_name(self): + return 'token({0})'.format(', '.join(['"{0}"'.format(c.db_field_name) for c in self.partition_columns])) diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py new file mode 100644 index 0000000000..55437d7b7f --- /dev/null +++ b/cassandra/cqlengine/connection.py @@ -0,0 +1,393 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import logging +import threading + +from cassandra.cluster import Cluster, _ConfigMode, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel +from cassandra.query import SimpleStatement, dict_factory + +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine.statements import BaseCQLStatement + + +log = logging.getLogger(__name__) + +NOT_SET = _NOT_SET # required for passing timeout to Session.execute + +cluster = None +session = None + +# connections registry +DEFAULT_CONNECTION = object() +_connections = {} + +# Because type models may be registered before a connection is present, +# and because sessions may be replaced, we must register UDTs here, in order +# to have them registered when a new session is established. +udt_by_keyspace = defaultdict(dict) + + +def format_log_context(msg, connection=None, keyspace=None): + """Format log message to add keyspace and connection context""" + connection_info = connection or 'DEFAULT_CONNECTION' + + if keyspace: + msg = '[Connection: {0}, Keyspace: {1}] {2}'.format(connection_info, keyspace, msg) + else: + msg = '[Connection: {0}] {1}'.format(connection_info, msg) + return msg + + +class UndefinedKeyspaceException(CQLEngineException): + pass + + +class Connection(object): + """CQLEngine Connection""" + + name = None + hosts = None + + consistency = None + retry_connect = False + lazy_connect = False + lazy_connect_lock = None + cluster_options = None + + cluster = None + session = None + + def __init__(self, name, hosts, consistency=None, + lazy_connect=False, retry_connect=False, cluster_options=None): + self.hosts = hosts + self.name = name + self.consistency = consistency + self.lazy_connect = lazy_connect + self.retry_connect = retry_connect + self.cluster_options = cluster_options if cluster_options else {} + self.lazy_connect_lock = threading.RLock() + + @classmethod + def from_session(cls, name, session): + instance = cls(name=name, hosts=session.hosts) + instance.cluster, instance.session = session.cluster, session + instance.setup_session() + return instance + + def setup(self): + """Set up the connection""" + global cluster, session + + if 'username' in self.cluster_options or 'password' in self.cluster_options: + raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") + + if self.lazy_connect: + return + + if 'cloud' in self.cluster_options: + if self.hosts: + log.warning("Ignoring hosts %s because a cloud config was provided.", self.hosts) + self.cluster = Cluster(**self.cluster_options) + else: + self.cluster = Cluster(self.hosts, **self.cluster_options) + + try: + self.session = self.cluster.connect() + log.debug(format_log_context("connection initialized with internally created session", connection=self.name)) + except NoHostAvailable: + if self.retry_connect: + log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name)) + self.lazy_connect = True + raise + + if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self: + cluster = _connections[DEFAULT_CONNECTION].cluster + session = _connections[DEFAULT_CONNECTION].session + + self.setup_session() + + def setup_session(self): + if self.cluster._config_mode == _ConfigMode.PROFILES: + self.cluster.profile_manager.default.row_factory = dict_factory + if self.consistency is not None: + self.cluster.profile_manager.default.consistency_level = self.consistency + else: + self.session.row_factory = dict_factory + if self.consistency is not None: + self.session.default_consistency_level = self.consistency + enc = self.session.encoder + enc.mapping[tuple] = enc.cql_encode_tuple + _register_known_types(self.session.cluster) + + def handle_lazy_connect(self): + + # if lazy_connect is False, it means the cluster is set up and ready + # No need to acquire the lock + if not self.lazy_connect: + return + + with self.lazy_connect_lock: + # lazy_connect might have been set to False by another thread while waiting the lock + # In this case, do nothing. + if self.lazy_connect: + log.debug(format_log_context("Lazy connect enabled", connection=self.name)) + self.lazy_connect = False + self.setup() + + +def register_connection(name, hosts=None, consistency=None, lazy_connect=False, + retry_connect=False, cluster_options=None, default=False, + session=None): + """ + Add a connection to the connection registry. ``hosts`` and ``session`` are + mutually exclusive, and ``consistency``, ``lazy_connect``, + ``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using + ``hosts`` will create a new :class:`cassandra.cluster.Cluster` and + :class:`cassandra.cluster.Session`. + + :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`). + :param int consistency: The default :class:`~.ConsistencyLevel` for the + registered connection's new session. Default is the same as + :attr:`.Session.default_consistency_level`. For use with ``hosts`` only; + will fail when used with ``session``. + :param bool lazy_connect: True if should not connect until first use. For + use with ``hosts`` only; will fail when used with ``session``. + :param bool retry_connect: True if we should retry to connect even if there + was a connection failure initially. For use with ``hosts`` only; will + fail when used with ``session``. + :param dict cluster_options: A dict of options to be used as keyword + arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts`` + only; will fail when used with ``session``. + :param bool default: If True, set the new connection as the cqlengine + default + :param Session session: A :class:`cassandra.cluster.Session` to be used in + the created connection. + """ + + if name in _connections: + log.warning("Registering connection '{0}' when it already exists.".format(name)) + + if session is not None: + invalid_config_args = (hosts is not None or + consistency is not None or + lazy_connect is not False or + retry_connect is not False or + cluster_options is not None) + if invalid_config_args: + raise CQLEngineException( + "Session configuration arguments and 'session' argument are mutually exclusive" + ) + conn = Connection.from_session(name, session=session) + else: # use hosts argument + conn = Connection( + name, hosts=hosts, + consistency=consistency, lazy_connect=lazy_connect, + retry_connect=retry_connect, cluster_options=cluster_options + ) + conn.setup() + + _connections[name] = conn + + if default: + set_default_connection(name) + + return conn + + +def unregister_connection(name): + global cluster, session + + if name not in _connections: + return + + if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]: + del _connections[DEFAULT_CONNECTION] + cluster = None + session = None + + conn = _connections[name] + if conn.cluster: + conn.cluster.shutdown() + del _connections[name] + log.debug("Connection '{0}' has been removed from the registry.".format(name)) + + +def set_default_connection(name): + global cluster, session + + if name not in _connections: + raise CQLEngineException("Connection '{0}' doesn't exist.".format(name)) + + log.debug("Connection '{0}' has been set as default.".format(name)) + _connections[DEFAULT_CONNECTION] = _connections[name] + cluster = _connections[name].cluster + session = _connections[name].session + + +def get_connection(name=None): + + if not name: + name = DEFAULT_CONNECTION + + if name not in _connections: + raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name)) + + conn = _connections[name] + conn.handle_lazy_connect() + + return conn + + +def default(): + """ + Configures the default connection to localhost, using the driver defaults + (except for row_factory) + """ + + try: + conn = get_connection() + if conn.session: + log.warning("configuring new default connection for cqlengine when one was already set") + except: + pass + + register_connection('default', hosts=None, default=True) + + log.debug("cqlengine connection initialized with default session to localhost") + + +def set_session(s): + """ + Configures the default connection with a preexisting :class:`cassandra.cluster.Session` + + Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``. + This may be relaxed in the future + """ + + try: + conn = get_connection() + except CQLEngineException: + # no default connection set; initialize one + register_connection('default', session=s, default=True) + conn = get_connection() + else: + if conn.session: + log.warning("configuring new default session for cqlengine when one was already set") + + if not any([ + s.cluster.profile_manager.default.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.PROFILES, _ConfigMode.UNCOMMITTED], + s.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.LEGACY, _ConfigMode.UNCOMMITTED], + ]): + raise CQLEngineException("Failed to initialize: row_factory must be 'dict_factory'") + + conn.session = s + conn.cluster = s.cluster + + # Set default keyspace from given session's keyspace + if conn.session.keyspace: + from cassandra.cqlengine import models + models.DEFAULT_KEYSPACE = conn.session.keyspace + + conn.setup_session() + + log.debug("cqlengine default connection initialized with %s", s) + + +# TODO next major: if a cloud config is specified in kwargs, hosts will be ignored. +# This function should be refactored to reflect this change. PYTHON-1265 +def setup( + hosts, + default_keyspace, + consistency=None, + lazy_connect=False, + retry_connect=False, + **kwargs): + """ + Set up the driver connection used by the mapper + + :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`) + :param str default_keyspace: The default keyspace to use + :param int consistency: The global default :class:`~.ConsistencyLevel` - default is the same as :attr:`.Session.default_consistency_level` + :param bool lazy_connect: True if should not connect until first use + :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially + :param kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` + """ + + from cassandra.cqlengine import models + models.DEFAULT_KEYSPACE = default_keyspace + + register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect, + retry_connect=retry_connect, cluster_options=kwargs, default=True) + + +def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None): + + conn = get_connection(connection) + + if not conn.session: + raise CQLEngineException("It is required to setup() cqlengine before executing queries") + + if isinstance(query, SimpleStatement): + pass # + elif isinstance(query, BaseCQLStatement): + params = query.get_context() + query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size) + elif isinstance(query, str): + query = SimpleStatement(query, consistency_level=consistency_level) + log.debug(format_log_context('Query: {}, Params: {}'.format(query.query_string, params), connection=connection)) + + result = conn.session.execute(query, params, timeout=timeout) + + return result + + +def get_session(connection=None): + conn = get_connection(connection) + return conn.session + + +def get_cluster(connection=None): + conn = get_connection(connection) + if not conn.cluster: + raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__) + return conn.cluster + + +def register_udt(keyspace, type_name, klass, connection=None): + udt_by_keyspace[keyspace][type_name] = klass + + try: + cluster = get_cluster(connection) + except CQLEngineException: + cluster = None + + if cluster: + try: + cluster.register_user_type(keyspace, type_name, klass) + except UserTypeDoesNotExist: + pass # new types are covered in management sync functions + + +def _register_known_types(cluster): + from cassandra.cqlengine import models + for ks_name, name_type_map in udt_by_keyspace.items(): + for type_name, klass in name_type_map.items(): + try: + cluster.register_user_type(ks_name or models.DEFAULT_KEYSPACE, type_name, klass) + except UserTypeDoesNotExist: + pass # new types are covered in management sync functions diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py new file mode 100644 index 0000000000..69bdc3feb4 --- /dev/null +++ b/cassandra/cqlengine/functions.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from cassandra.cqlengine import UnicodeMixin, ValidationError + +def get_total_seconds(td): + return td.total_seconds() + +class QueryValue(UnicodeMixin): + """ + Base class for query filter values. Subclasses of these classes can + be passed into .filter() keyword args + """ + + format_string = '%({0})s' + + def __init__(self, value): + self.value = value + self.context_id = None + + def __unicode__(self): + return self.format_string.format(self.context_id) + + def set_context_id(self, ctx_id): + self.context_id = ctx_id + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.value + + +class BaseQueryFunction(QueryValue): + """ + Base class for filtering functions. Subclasses of these classes can + be passed into .filter() and will be translated into CQL functions in + the resulting query + """ + pass + + +class TimeUUIDQueryFunction(BaseQueryFunction): + + def __init__(self, value): + """ + :param value: the time to create bounding time uuid from + :type value: datetime + """ + if not isinstance(value, datetime): + raise ValidationError('datetime instance is required') + super(TimeUUIDQueryFunction, self).__init__(value) + + def to_database(self, val): + epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo) + offset = get_total_seconds(epoch.tzinfo.utcoffset(epoch)) if epoch.tzinfo else 0 + return int((get_total_seconds(val - epoch) - offset) * 1000) + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.to_database(self.value) + + +class MinTimeUUID(TimeUUIDQueryFunction): + """ + return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp + + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun + """ + format_string = 'MinTimeUUID(%({0})s)' + + +class MaxTimeUUID(TimeUUIDQueryFunction): + """ + return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp + + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun + """ + format_string = 'MaxTimeUUID(%({0})s)' + + +class Token(BaseQueryFunction): + """ + compute the token for a given partition key + + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun + """ + def __init__(self, *values): + if len(values) == 1 and isinstance(values[0], (list, tuple)): + values = values[0] + super(Token, self).__init__(values) + self._columns = None + + def set_columns(self, columns): + self._columns = columns + + def get_context_size(self): + return len(self.value) + + def __unicode__(self): + token_args = ', '.join('%({0})s'.format(self.context_id + i) for i in range(self.get_context_size())) + return "token({0})".format(token_args) + + def update_context(self, ctx): + for i, (col, val) in enumerate(zip(self._columns, self.value)): + ctx[str(self.context_id + i)] = col.to_database(val) diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py new file mode 100644 index 0000000000..66b391b714 --- /dev/null +++ b/cassandra/cqlengine/management.py @@ -0,0 +1,549 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import json +import logging +import os +import warnings +from itertools import product + +from cassandra import metadata +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine import columns, query +from cassandra.cqlengine.connection import execute, get_cluster, format_log_context +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.named import NamedTable +from cassandra.cqlengine.usertype import UserType + +CQLENG_ALLOW_SCHEMA_MANAGEMENT = 'CQLENG_ALLOW_SCHEMA_MANAGEMENT' + +Field = namedtuple('Field', ['name', 'type']) + +log = logging.getLogger(__name__) + +# system keyspaces +schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') + + +def _get_context(keyspaces, connections): + """Return all the execution contexts""" + + if keyspaces: + if not isinstance(keyspaces, (list, tuple)): + raise ValueError('keyspaces must be a list or a tuple.') + + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('connections must be a list or a tuple.') + + keyspaces = keyspaces if keyspaces else [None] + connections = connections if connections else [None] + + return product(connections, keyspaces) + + +def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None): + """ + Creates a keyspace with SimpleStrategy for replica placement + + If the keyspace already exists, it will not be modified. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + :param str name: name of keyspace to create + :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` + :param bool durable_writes: Write log is bypassed if set to False + :param list connections: List of connection names + """ + _create_keyspace(name, durable_writes, 'SimpleStrategy', + {'replication_factor': replication_factor}, connections=connections) + + +def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None): + """ + Creates a keyspace with NetworkTopologyStrategy for replica placement + + If the keyspace already exists, it will not be modified. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + :param str name: name of keyspace to create + :param dict dc_replication_map: map of dc_names: replication_factor + :param bool durable_writes: Write log is bypassed if set to False + :param list connections: List of connection names + """ + _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections) + + +def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None): + if not _allow_schema_modification(): + return + + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('Connections must be a list or a tuple.') + + def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None): + cluster = get_cluster(connection) + + if name not in cluster.metadata.keyspaces: + log.info(format_log_context("Creating keyspace %s", connection=connection), name) + ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + execute(ks_meta.as_cql_query(), connection=connection) + else: + log.info(format_log_context("Not creating keyspace %s because it already exists", connection=connection), name) + + if connections: + for connection in connections: + __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection) + else: + __create_keyspace(name, durable_writes, strategy_class, strategy_options) + + +def drop_keyspace(name, connections=None): + """ + Drops a keyspace, if it exists. + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + :param str name: name of keyspace to drop + :param list connections: List of connection names + """ + if not _allow_schema_modification(): + return + + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('Connections must be a list or a tuple.') + + def _drop_keyspace(name, connection=None): + cluster = get_cluster(connection) + if name in cluster.metadata.keyspaces: + execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection) + + if connections: + for connection in connections: + _drop_keyspace(name, connection) + else: + _drop_keyspace(name) + +def _get_index_name_by_column(table, column_name): + """ + Find the index name for a given table and column. + """ + protected_name = metadata.protect_name(column_name) + possible_index_values = [protected_name, "values(%s)" % protected_name] + for index_metadata in table.indexes.values(): + options = dict(index_metadata.index_options) + if options.get('target') in possible_index_values: + return index_metadata.name + + +def sync_table(model, keyspaces=None, connections=None): + """ + Inspects the model and creates / updates the corresponding table and columns. + + If `keyspaces` is specified, the table will be synched for all specified keyspaces. + Note that the `Model.__keyspace__` is ignored in that case. + + If `connections` is specified, the table will be synched for all specified connections. Note that the `Model.__connection__` is ignored in that case. + If not specified, it will try to get the connection from the Model. + + Any User Defined Types used in the table are implicitly synchronized. + + This function can only add fields that are not part of the primary key. + + Note that the attributes removed from the model are not deleted on the database. + They become effectively ignored by (will not show up on) the model. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + + context = _get_context(keyspaces, connections) + for connection, keyspace in context: + with query.ContextQuery(model, keyspace=keyspace) as m: + _sync_table(m, connection=connection) + + +def _sync_table(model, connection=None): + if not _allow_schema_modification(): + return + + if not issubclass(model, Model): + raise CQLEngineException("Models must be derived from base Model.") + + if model.__abstract__: + raise CQLEngineException("cannot create table from abstract model") + + cf_name = model.column_family_name() + raw_cf_name = model._raw_column_family_name() + + ks_name = model._get_keyspace() + connection = connection or model._get_connection() + + cluster = get_cluster(connection) + + try: + keyspace = cluster.metadata.keyspaces[ks_name] + except KeyError: + msg = format_log_context("Keyspace '{0}' for model {1} does not exist.", connection=connection) + raise CQLEngineException(msg.format(ks_name, model)) + + tables = keyspace.tables + + syncd_types = set() + for col in model._columns.values(): + udts = [] + columns.resolve_udts(col, udts) + for udt in [u for u in udts if u not in syncd_types]: + _sync_type(ks_name, udt, syncd_types, connection=connection) + + if raw_cf_name not in tables: + log.debug(format_log_context("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name) + qs = _get_create_table(model) + + try: + execute(qs, connection=connection) + except CQLEngineException as ex: + # 1.2 doesn't return cf names, so we have to examine the exception + # and ignore if it says the column family already exists + if "Cannot add already existing column family" not in str(ex): + raise + else: + log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name) + table_meta = tables[raw_cf_name] + + _validate_pk(model, table_meta) + + table_columns = table_meta.columns + model_fields = set() + + for model_name, col in model._columns.items(): + db_name = col.db_field_name + model_fields.add(db_name) + if db_name in table_columns: + col_meta = table_columns[db_name] + if col_meta.cql_type != col.db_type: + msg = format_log_context('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' + ' Model should be updated.', keyspace=ks_name, connection=connection) + msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type) + warnings.warn(msg) + log.warning(msg) + + continue + + if col.primary_key: + msg = format_log_context("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection) + raise CQLEngineException(msg.format(model_name, db_name, cf_name)) + + query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def()) + execute(query, connection=connection) + + db_fields_not_in_model = model_fields.symmetric_difference(table_columns) + if db_fields_not_in_model: + msg = format_log_context("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection) + log.info(msg.format(cf_name, db_fields_not_in_model)) + + _update_options(model, connection=connection) + + table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] + + indexes = [c for n, c in model._columns.items() if c.index] + + # TODO: support multiple indexes in C* 3.0+ + for column in indexes: + index_name = _get_index_name_by_column(table, column.db_field_name) + if index_name: + continue + + qs = ['CREATE INDEX'] + qs += ['ON {0}'.format(cf_name)] + qs += ['("{0}")'.format(column.db_field_name)] + qs = ' '.join(qs) + execute(qs, connection=connection) + + +def _validate_pk(model, table_meta): + model_partition = [c.db_field_name for c in model._partition_keys.values()] + meta_partition = [c.name for c in table_meta.partition_key] + model_clustering = [c.db_field_name for c in model._clustering_keys.values()] + meta_clustering = [c.name for c in table_meta.clustering_key] + + if model_partition != meta_partition or model_clustering != meta_clustering: + def _pk_string(partition, clustering): + return "PRIMARY KEY (({0}){1})".format(', '.join(partition), ', ' + ', '.join(clustering) if clustering else '') + raise CQLEngineException("Model {0} PRIMARY KEY composition does not match existing table {1}. " + "Model: {2}; Table: {3}. " + "Update model or drop the table.".format(model, model.column_family_name(), + _pk_string(model_partition, model_clustering), + _pk_string(meta_partition, meta_clustering))) + + +def sync_type(ks_name, type_model, connection=None): + """ + Inspects the type_model and creates / updates the corresponding type. + + Note that the attributes removed from the type_model are not deleted on the database (this operation is not supported). + They become effectively ignored by (will not show up on) the type_model. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + if not _allow_schema_modification(): + return + + if not issubclass(type_model, UserType): + raise CQLEngineException("Types must be derived from base UserType.") + + _sync_type(ks_name, type_model, connection=connection) + + +def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): + + syncd_sub_types = omit_subtypes or set() + for field in type_model._fields.values(): + udts = [] + columns.resolve_udts(field, udts) + for udt in [u for u in udts if u not in syncd_sub_types]: + _sync_type(ks_name, udt, syncd_sub_types, connection=connection) + syncd_sub_types.add(udt) + + type_name = type_model.type_name() + type_name_qualified = "%s.%s" % (ks_name, type_name) + + cluster = get_cluster(connection) + + keyspace = cluster.metadata.keyspaces[ks_name] + defined_types = keyspace.user_types + + if type_name not in defined_types: + log.debug(format_log_context("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified) + cql = get_create_type(type_model, ks_name) + execute(cql, connection=connection) + cluster.refresh_user_type_metadata(ks_name, type_name) + type_model.register_for_keyspace(ks_name, connection=connection) + else: + type_meta = defined_types[type_name] + defined_fields = type_meta.field_names + model_fields = set() + for field in type_model._fields.values(): + model_fields.add(field.db_field_name) + if field.db_field_name not in defined_fields: + execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection) + else: + field_type = type_meta.field_types[defined_fields.index(field.db_field_name)] + if field_type != field.db_type: + msg = format_log_context('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' + ' UserType should be updated.', keyspace=ks_name, connection=connection) + msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type) + warnings.warn(msg) + log.warning(msg) + + type_model.register_for_keyspace(ks_name, connection=connection) + + if len(defined_fields) == len(model_fields): + log.info(format_log_context("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified) + return + + db_fields_not_in_model = model_fields.symmetric_difference(defined_fields) + if db_fields_not_in_model: + msg = format_log_context("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection) + log.info(msg, type_name_qualified, db_fields_not_in_model) + + +def get_create_type(type_model, keyspace): + type_meta = metadata.UserType(keyspace, + type_model.type_name(), + (f.db_field_name for f in type_model._fields.values()), + (v.db_type for v in type_model._fields.values())) + return type_meta.as_cql_query() + + +def _get_create_table(model): + ks_table_name = model.column_family_name() + query_strings = ['CREATE TABLE {0}'.format(ks_table_name)] + + # add column types + pkeys = [] # primary keys + ckeys = [] # clustering keys + qtypes = [] # field types + + def add_column(col): + s = col.get_column_def() + if col.primary_key: + keys = (pkeys if col.partition_key else ckeys) + keys.append('"{0}"'.format(col.db_field_name)) + qtypes.append(s) + + for name, col in model._columns.items(): + add_column(col) + + qtypes.append('PRIMARY KEY (({0}){1})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) + + query_strings += ['({0})'.format(', '.join(qtypes))] + + property_strings = [] + + _order = ['"{0}" {1}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] + if _order: + property_strings.append('CLUSTERING ORDER BY ({0})'.format(', '.join(_order))) + + # options strings use the V3 format, which matches CQL more closely and does not require mapping + property_strings += metadata.TableMetadataV3._make_option_strings(model.__options__ or {}) + + if property_strings: + query_strings += ['WITH {0}'.format(' AND '.join(property_strings))] + + return ' '.join(query_strings) + + +def _get_table_metadata(model, connection=None): + # returns the table as provided by the native driver for a given model + cluster = get_cluster(connection) + ks = model._get_keyspace() + table = model._raw_column_family_name() + table = cluster.metadata.keyspaces[ks].tables[table] + return table + + +def _options_map_from_strings(option_strings): + # converts options strings to a mapping to strings or dict + options = {} + for option in option_strings: + name, value = option.split('=') + i = value.find('{') + if i >= 0: + value = value[i:value.rfind('}') + 1].replace("'", '"') # from cql single quotes to json double; not aware of any values that would be escaped right now + value = json.loads(value) + else: + value = value.strip() + options[name.strip()] = value + return options + + +def _update_options(model, connection=None): + """Updates the table options for the given model if necessary. + + :param model: The model to update. + :param connection: Name of the connection to use + + :return: `True`, if the options were modified in Cassandra, + `False` otherwise. + :rtype: bool + """ + ks_name = model._get_keyspace() + msg = format_log_context("Checking %s for option differences", keyspace=ks_name, connection=connection) + log.debug(msg, model) + model_options = model.__options__ or {} + + table_meta = _get_table_metadata(model, connection=connection) + # go to CQL string first to normalize meta from different versions + existing_option_strings = set(table_meta._make_option_strings(table_meta.options)) + existing_options = _options_map_from_strings(existing_option_strings) + model_option_strings = metadata.TableMetadataV3._make_option_strings(model_options) + model_options = _options_map_from_strings(model_option_strings) + + update_options = {} + for name, value in model_options.items(): + try: + existing_value = existing_options[name] + except KeyError: + msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection) + raise KeyError(msg % (name, existing_options.keys())) + if isinstance(existing_value, str): + if value != existing_value: + update_options[name] = value + else: + try: + for k, v in value.items(): + if existing_value[k] != v: + update_options[name] = value + break + except KeyError: + update_options[name] = value + + if update_options: + options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options)) + query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options) + execute(query, connection=connection) + return True + + return False + + +def drop_table(model, keyspaces=None, connections=None): + """ + Drops the table indicated by the model, if it exists. + + If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. + + If `connections` is specified, the table will be synched for all specified connections. Note that the `Model.__connection__` is ignored in that case. + If not specified, it will try to get the connection from the Model. + + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + + context = _get_context(keyspaces, connections) + for connection, keyspace in context: + with query.ContextQuery(model, keyspace=keyspace) as m: + _drop_table(m, connection=connection) + + +def _drop_table(model, connection=None): + if not _allow_schema_modification(): + return + + connection = connection or model._get_connection() + + # don't try to delete non existent tables + meta = get_cluster(connection).metadata + + ks_name = model._get_keyspace() + raw_cf_name = model._raw_column_family_name() + + try: + meta.keyspaces[ks_name].tables[raw_cf_name] + execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection) + except KeyError: + pass + + +def _allow_schema_modification(): + if not os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT): + msg = CQLENG_ALLOW_SCHEMA_MANAGEMENT + " environment variable is not set. Future versions of this package will require this variable to enable management functions." + warnings.warn(msg) + log.warning(msg) + + return True diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py new file mode 100644 index 0000000000..f0f5a207ec --- /dev/null +++ b/cassandra/cqlengine/models.py @@ -0,0 +1,1088 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from warnings import warn + +from cassandra.cqlengine import CQLEngineException, ValidationError +from cassandra.cqlengine import columns +from cassandra.cqlengine import connection +from cassandra.cqlengine import query +from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist +from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned +from cassandra.metadata import protect_name +from cassandra.util import OrderedDict + +log = logging.getLogger(__name__) + + +def _clone_model_class(model, attrs): + new_type = type(model.__name__, (model,), attrs) + try: + new_type.__abstract__ = model.__abstract__ + new_type.__discriminator_value__ = model.__discriminator_value__ + new_type.__default_ttl__ = model.__default_ttl__ + except AttributeError: + pass + return new_type + + +class ModelException(CQLEngineException): + pass + + +class ModelDefinitionException(ModelException): + pass + + +class PolymorphicModelException(ModelException): + pass + + +class UndefinedKeyspaceWarning(Warning): + pass + +DEFAULT_KEYSPACE = None + + +class hybrid_classmethod(object): + """ + Allows a method to behave as both a class method and + normal instance method depending on how it's called + """ + def __init__(self, clsmethod, instmethod): + self.clsmethod = clsmethod + self.instmethod = instmethod + + def __get__(self, instance, owner): + if instance is None: + return self.clsmethod.__get__(owner, owner) + else: + return self.instmethod.__get__(instance, owner) + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + """ + raise NotImplementedError + + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on everytime it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + queryset = model.__queryset__(model) + + # if this is a concrete polymorphic model, and the discriminator + # key is an indexed column, add a filter clause to only return + # logical rows of the proper type + if model._is_polymorphic and not model._is_polymorphic_base: + name, column = model._discriminator_column_name, model._discriminator_column + if column.partition_key or column.index: + # look for existing poly types + return queryset.filter(**{name: model.__discriminator_value__}) + + return queryset + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class ConditionalDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + def conditional_setter(*prepared_conditional, **unprepared_conditionals): + if len(prepared_conditional) > 0: + conditionals = prepared_conditional[0] + else: + conditionals = instance.objects.iff(**unprepared_conditionals)._conditional + instance._conditional = conditionals + return instance + + return conditional_setter + qs = model.__queryset__(model) + + def conditional_setter(**unprepared_conditionals): + conditionals = model.objects.iff(**unprepared_conditionals)._conditional + qs._conditional = conditionals + return qs + return conditional_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class TTLDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + # instance = copy.deepcopy(instance) + # instance method + def ttl_setter(ts): + instance._ttl = ts + return instance + return ttl_setter + + qs = model.__queryset__(model) + + def ttl_setter(ts): + qs._ttl = ts + return qs + + return ttl_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class TimestampDescriptor(object): + """ + returns a query set descriptor with a timestamp specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def timestamp_setter(ts): + instance._timestamp = ts + return instance + return timestamp_setter + + return model.objects.timestamp + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class IfNotExistsDescriptor(object): + """ + return a query set descriptor with an if_not_exists flag specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def ifnotexists_setter(ife=True): + instance._if_not_exists = ife + return instance + return ifnotexists_setter + + return model.objects.if_not_exists + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class IfExistsDescriptor(object): + """ + return a query set descriptor with an if_exists flag specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def ifexists_setter(ife=True): + instance._if_exists = ife + return instance + return ifexists_setter + + return model.objects.if_exists + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class ConsistencyDescriptor(object): + """ + returns a query set descriptor if called on Class, instance if it was an instance call + """ + def __get__(self, instance, model): + if instance: + # instance = copy.deepcopy(instance) + def consistency_setter(consistency): + instance.__consistency__ = consistency + return instance + return consistency_setter + + qs = model.__queryset__(model) + + def consistency_setter(consistency): + qs._consistency = consistency + return qs + + return consistency_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class UsingDescriptor(object): + """ + return a query set descriptor with a connection context specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def using_setter(connection=None): + if connection: + instance._connection = connection + return instance + return using_setter + + return model.objects.using + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class ColumnQueryEvaluator(query.AbstractQueryableColumn): + """ + Wraps a column and allows it to be used in comparator + expressions, returning query operators + + ie: + Model.column == 5 + """ + + def __init__(self, column): + self.column = column + + def __unicode__(self): + return self.column.db_field_name + + def _get_column(self): + return self.column + + +class ColumnDescriptor(object): + """ + Handles the reading and writing of column values to and from + a model instance's value manager, as well as creating + comparator queries + """ + + def __init__(self, column): + """ + :param column: + :type column: columns.Column + :return: + """ + self.column = column + self.query_evaluator = ColumnQueryEvaluator(self.column) + + def __get__(self, instance, owner): + """ + Returns either the value or column, depending + on if an instance is provided or not + + :param instance: the model instance + :type instance: Model + """ + try: + return instance._values[self.column.column_name].getval() + except AttributeError: + return self.query_evaluator + + def __set__(self, instance, value): + """ + Sets the value on an instance, raises an exception with classes + TODO: use None instance to create update statements + """ + if instance: + return instance._values[self.column.column_name].setval(value) + else: + raise AttributeError('cannot reassign column values') + + def __delete__(self, instance): + """ + Sets the column value to None, if possible + """ + if instance: + if self.column.can_delete: + instance._values[self.column.column_name].delval() + else: + raise AttributeError('cannot delete {0} columns'.format(self.column.column_name)) + + +class BaseModel(object): + """ + The base model class, don't inherit from this, inherit from Model, defined below + """ + + class DoesNotExist(_DoesNotExist): + pass + + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass + + objects = QuerySetDescriptor() + ttl = TTLDescriptor() + consistency = ConsistencyDescriptor() + iff = ConditionalDescriptor() + + # custom timestamps, see USING TIMESTAMP X + timestamp = TimestampDescriptor() + + if_not_exists = IfNotExistsDescriptor() + + if_exists = IfExistsDescriptor() + + using = UsingDescriptor() + + # _len is lazily created by __len__ + + __table_name__ = None + + __table_name_case_sensitive__ = False + + __keyspace__ = None + + __connection__ = None + + __discriminator_value__ = None + + __options__ = None + + __compute_routing_key__ = True + + # the queryset class used for this class + __queryset__ = query.ModelQuerySet + __dmlquery__ = query.DMLQuery + + __consistency__ = None # can be set per query + + _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + + _if_not_exists = False # optional if_not_exists flag to check existence before insertion + + _if_exists = False # optional if_exists flag to check existence before update + + _table_name = None # used internally to cache a derived table name + + _connection = None + + def __init__(self, **values): + self._ttl = None + self._timestamp = None + self._conditional = None + self._batch = None + self._timeout = connection.NOT_SET + self._is_persisted = False + self._connection = None + + self._values = {} + for name, column in self._columns.items(): + # Set default values on instantiation. Thanks to this, we don't have + # to wait any longer for a call to validate() to have CQLengine set + # default columns values. + column_default = column.get_default() if column.has_default else None + value = values.get(name, column_default) + if value is not None or isinstance(column, columns.BaseContainerColumn): + value = column.to_python(value) + value_mngr = column.value_manager(self, column, value) + value_mngr.explicit = name in values + self._values[name] = value_mngr + + def __repr__(self): + return '{0}({1})'.format(self.__class__.__name__, + ', '.join('{0}={1!r}'.format(k, getattr(self, k)) + for k in self._defined_columns.keys() + if k != self._discriminator_column_name)) + + def __str__(self): + """ + Pretty printing of models by their primary key + """ + return '{0} <{1}>'.format(self.__class__.__name__, + ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + + @classmethod + def _routing_key_from_values(cls, pk_values, protocol_version): + return cls._key_serializer(pk_values, protocol_version) + + @classmethod + def _discover_polymorphic_submodels(cls): + if not cls._is_polymorphic_base: + raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') + + def _discover(klass): + if not klass._is_polymorphic_base and klass.__discriminator_value__ is not None: + cls._discriminator_map[klass.__discriminator_value__] = klass + for subklass in klass.__subclasses__(): + _discover(subklass) + _discover(cls) + + @classmethod + def _get_model_by_discriminator_value(cls, key): + if not cls._is_polymorphic_base: + raise ModelException('_get_model_by_discriminator_value can only be called on polymorphic base classes') + return cls._discriminator_map.get(key) + + @classmethod + def _construct_instance(cls, values): + """ + method used to construct instances from query results + this is where polymorphic deserialization occurs + """ + # we're going to take the values, which is from the DB as a dict + # and translate that into our local fields + # the db_map is a db_field -> model field map + if cls._db_map: + values = dict((cls._db_map.get(k, k), v) for k, v in values.items()) + + if cls._is_polymorphic: + disc_key = values.get(cls._discriminator_column_name) + + if disc_key is None: + raise PolymorphicModelException('discriminator value was not found in values') + + poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base + + klass = poly_base._get_model_by_discriminator_value(disc_key) + if klass is None: + poly_base._discover_polymorphic_submodels() + klass = poly_base._get_model_by_discriminator_value(disc_key) + if klass is None: + raise PolymorphicModelException( + 'unrecognized discriminator column {0} for class {1}'.format(disc_key, poly_base.__name__) + ) + + if not issubclass(klass, cls): + raise PolymorphicModelException( + '{0} is not a subclass of {1}'.format(klass.__name__, cls.__name__) + ) + + values = dict((k, v) for k, v in values.items() if k in klass._columns.keys()) + + else: + klass = cls + + instance = klass(**values) + instance._set_persisted(force=True) + return instance + + def _set_persisted(self, force=False): + # ensure we don't modify to any values not affected by the last save/update + for v in [v for v in self._values.values() if v.changed or force]: + v.reset_previous_value() + v.explicit = False + self._is_persisted = True + + def _can_update(self): + """ + Called by the save function to check if this should be + persisted with update or insert + + :return: + """ + if not self._is_persisted: + return False + + return all([not self._values[k].changed for k in self._primary_keys]) + + @classmethod + def _get_keyspace(cls): + """ + Returns the manual keyspace, if set, otherwise the default keyspace + """ + return cls.__keyspace__ or DEFAULT_KEYSPACE + + @classmethod + def _get_column(cls, name): + """ + Returns the column matching the given name, raising a key error if + it doesn't exist + + :param name: the name of the column to return + :rtype: Column + """ + return cls._columns[name] + + @classmethod + def _get_column_by_db_name(cls, name): + """ + Returns the column, mapped by db_field name + """ + return cls._columns.get(cls._db_map.get(name, name)) + + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + + # check attribute keys + keys = set(self._columns.keys()) + other_keys = set(other._columns.keys()) + if keys != other_keys: + return False + + return all(getattr(self, key, None) == getattr(other, key, None) for key in other_keys) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def column_family_name(cls, include_keyspace=True): + """ + Returns the column family name if it's been defined + otherwise, it creates it from the module and class name + """ + cf_name = protect_name(cls._raw_column_family_name()) + if include_keyspace: + keyspace = cls._get_keyspace() + if not keyspace: + raise CQLEngineException("Model keyspace is not set and no default is available. Set model keyspace or setup connection before attempting to generate a query.") + return '{0}.{1}'.format(protect_name(keyspace), cf_name) + + return cf_name + + + @classmethod + def _raw_column_family_name(cls): + if not cls._table_name: + if cls.__table_name__: + if cls.__table_name_case_sensitive__: + warn("Model __table_name_case_sensitive__ will be removed in 4.0.", PendingDeprecationWarning) + cls._table_name = cls.__table_name__ + else: + table_name = cls.__table_name__.lower() + if cls.__table_name__ != table_name: + warn(("Model __table_name__ will be case sensitive by default in 4.0. " + "You should fix the __table_name__ value of the '{0}' model.").format(cls.__name__)) + cls._table_name = table_name + else: + if cls._is_polymorphic and not cls._is_polymorphic_base: + cls._table_name = cls._polymorphic_base._raw_column_family_name() + else: + camelcase = re.compile(r'([a-z])([A-Z])') + ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2).lower()), s) + + cf_name = ccase(cls.__name__) + # trim to less than 48 characters or cassandra will complain + cf_name = cf_name[-48:] + cf_name = cf_name.lower() + cf_name = re.sub(r'^_+', '', cf_name) + cls._table_name = cf_name + + return cls._table_name + + def _set_column_value(self, name, value): + """Function to change a column value without changing the value manager states""" + self._values[name].value = value # internal assignement, skip the main setter + + def validate(self): + """ + Cleans and validates the field values + """ + for name, col in self._columns.items(): + v = getattr(self, name) + if v is None and not self._values[name].explicit and col.has_default: + v = col.get_default() + val = col.validate(v) + self._set_column_value(name, val) + + # Let an instance be used like a dict of its columns keys/values + def __iter__(self): + """ Iterate over column ids. """ + for column_id in self._columns.keys(): + yield column_id + + def __getitem__(self, key): + """ Returns column's value. """ + if not isinstance(key, str): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return getattr(self, key) + + def __setitem__(self, key, val): + """ Sets a column's value. """ + if not isinstance(key, str): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return setattr(self, key, val) + + def __len__(self): + """ + Returns the number of columns defined on that model. + """ + try: + return self._len + except: + self._len = len(self._columns.keys()) + return self._len + + def keys(self): + """ Returns a list of column IDs. """ + return [k for k in self] + + def values(self): + """ Returns list of column values. """ + return [self[k] for k in self] + + def items(self): + """ Returns a list of column ID/value tuples. """ + return [(k, self[k]) for k in self] + + def _as_dict(self): + """ Returns a map of column names to cleaned values """ + values = self._dynamic_columns or {} + for name, col in self._columns.items(): + values[name] = col.to_database(getattr(self, name, None)) + return values + + @classmethod + def create(cls, **kwargs): + """ + Create an instance of this model in the database. + + Takes the model column values as keyword arguments. Setting a value to + `None` is equivalent to running a CQL `DELETE` on that column. + + Returns the instance. + """ + extra_columns = set(kwargs.keys()) - set(cls._columns.keys()) + if extra_columns: + raise ValidationError("Incorrect columns passed: {0}".format(extra_columns)) + return cls.objects.create(**kwargs) + + @classmethod + def all(cls): + """ + Returns a queryset representing all stored objects + + This is a pass-through to the model objects().all() + """ + return cls.objects.all() + + @classmethod + def filter(cls, *args, **kwargs): + """ + Returns a queryset based on filter parameters. + + This is a pass-through to the model objects().:method:`~cqlengine.queries.filter`. + """ + return cls.objects.filter(*args, **kwargs) + + @classmethod + def get(cls, *args, **kwargs): + """ + Returns a single object based on the passed filter constraints. + + This is a pass-through to the model objects().:method:`~cqlengine.queries.get`. + """ + return cls.objects.get(*args, **kwargs) + + def timeout(self, timeout): + """ + Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete` + operations + """ + assert self._batch is None, 'Setting both timeout and batch is not supported' + self._timeout = timeout + return self + + def save(self): + """ + Saves an object to the database. + + .. code-block:: python + + #create a person instance + person = Person(first_name='Kimberly', last_name='Eggleston') + #saves it to Cassandra + person.save() + """ + + # handle polymorphic models + if self._is_polymorphic: + if self._is_polymorphic_base: + raise PolymorphicModelException('cannot save polymorphic base model') + else: + setattr(self, self._discriminator_column_name, self.__discriminator_value__) + + self.validate() + self.__dmlquery__(self.__class__, self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__, + if_not_exists=self._if_not_exists, + conditional=self._conditional, + timeout=self._timeout, + if_exists=self._if_exists).save() + + self._set_persisted() + + self._timestamp = None + + return self + + def update(self, **values): + """ + Performs an update on the model instance. You can pass in values to set on the model + for updating, or you can call without values to execute an update against any modified + fields. If no fields on the model have been modified since loading, no query will be + performed. Model validation is performed normally. Setting a value to `None` is + equivalent to running a CQL `DELETE` on that column. + + It is possible to do a blind update, that is, to update a field without having first selected the object out of the database. + See :ref:`Blind Updates ` + """ + for column_id, v in values.items(): + col = self._columns.get(column_id) + + # check for nonexistant columns + if col is None: + raise ValidationError( + "{0}.{1} has no column named: {2}".format( + self.__module__, self.__class__.__name__, column_id)) + + # check for primary key update attempts + if col.is_primary_key: + current_value = getattr(self, column_id) + if v != current_value: + raise ValidationError( + "Cannot apply update to primary key '{0}' for {1}.{2}".format( + column_id, self.__module__, self.__class__.__name__)) + + setattr(self, column_id, v) + + # handle polymorphic models + if self._is_polymorphic: + if self._is_polymorphic_base: + raise PolymorphicModelException('cannot update polymorphic base model') + else: + setattr(self, self._discriminator_column_name, self.__discriminator_value__) + + self.validate() + self.__dmlquery__(self.__class__, self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__, + conditional=self._conditional, + timeout=self._timeout, + if_exists=self._if_exists).update() + + self._set_persisted() + + self._timestamp = None + + return self + + def delete(self): + """ + Deletes the object from the database + """ + self.__dmlquery__(self.__class__, self, + batch=self._batch, + timestamp=self._timestamp, + consistency=self.__consistency__, + timeout=self._timeout, + conditional=self._conditional, + if_exists=self._if_exists).delete() + + def get_changed_columns(self): + """ + Returns a list of the columns that have been updated since instantiation or save + """ + return [k for k, v in self._values.items() if v.changed] + + @classmethod + def _class_batch(cls, batch): + return cls.objects.batch(batch) + + def _inst_batch(self, batch): + assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported' + if self._connection: + raise CQLEngineException("Cannot specify a connection on model in batch mode.") + self._batch = batch + return self + + batch = hybrid_classmethod(_class_batch, _inst_batch) + + @classmethod + def _class_get_connection(cls): + return cls.__connection__ + + def _inst_get_connection(self): + return self._connection or self.__connection__ + + _get_connection = hybrid_classmethod(_class_get_connection, _inst_get_connection) + + +class ModelMetaClass(type): + + def __new__(cls, name, bases, attrs): + # move column definitions into columns dict + # and set default column names + column_dict = OrderedDict() + primary_keys = OrderedDict() + pk_name = None + + # get inherited properties + inherited_columns = OrderedDict() + for base in bases: + for k, v in getattr(base, '_defined_columns', {}).items(): + inherited_columns.setdefault(k, v) + + # short circuit __abstract__ inheritance + is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) + + # short circuit __discriminator_value__ inheritance + attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') + + # TODO __default__ttl__ should be removed in the next major release + options = attrs.get('__options__') or {} + attrs['__default_ttl__'] = options.get('default_time_to_live') + + column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + column_definitions = sorted(column_definitions, key=lambda x: x[1].position) + + is_polymorphic_base = any([c[1].discriminator_column for c in column_definitions]) + + column_definitions = [x for x in inherited_columns.items()] + column_definitions + discriminator_columns = [c for c in column_definitions if c[1].discriminator_column] + is_polymorphic = len(discriminator_columns) > 0 + if len(discriminator_columns) > 1: + raise ModelDefinitionException('only one discriminator_column can be defined in a model, {0} found'.format(len(discriminator_columns))) + + if attrs['__discriminator_value__'] and not is_polymorphic: + raise ModelDefinitionException('__discriminator_value__ specified, but no base columns defined with discriminator_column=True') + + discriminator_column_name, discriminator_column = discriminator_columns[0] if discriminator_columns else (None, None) + + if isinstance(discriminator_column, (columns.BaseContainerColumn, columns.Counter)): + raise ModelDefinitionException('counter and container columns cannot be used as discriminator columns') + + # find polymorphic base class + polymorphic_base = None + if is_polymorphic and not is_polymorphic_base: + def _get_polymorphic_base(bases): + for base in bases: + if getattr(base, '_is_polymorphic_base', False): + return base + klass = _get_polymorphic_base(base.__bases__) + if klass: + return klass + polymorphic_base = _get_polymorphic_base(bases) + + defined_columns = OrderedDict(column_definitions) + + # check for primary key + if not is_abstract and not any([v.primary_key for k, v in column_definitions]): + raise ModelDefinitionException("At least 1 primary key is required.") + + counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] + data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)] + if counter_columns and data_columns: + raise ModelDefinitionException('counter models may not have data columns') + + has_partition_keys = any(v.partition_key for (k, v) in column_definitions) + + def _transform_column(col_name, col_obj): + column_dict[col_name] = col_obj + if col_obj.primary_key: + primary_keys[col_name] = col_obj + col_obj.set_column_name(col_name) + # set properties + attrs[col_name] = ColumnDescriptor(col_obj) + + partition_key_index = 0 + # transform column definitions + for k, v in column_definitions: + # don't allow a column with the same name as a built-in attribute or method + if k in BaseModel.__dict__: + raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k)) + + # counter column primary keys are not allowed + if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter): + raise ModelDefinitionException('counter columns cannot be used as primary keys') + + # this will mark the first primary key column as a partition + # key, if one hasn't been set already + if not has_partition_keys and v.primary_key: + v.partition_key = True + has_partition_keys = True + if v.partition_key: + v._partition_key_index = partition_key_index + partition_key_index += 1 + + overriding = column_dict.get(k) + if overriding: + v.position = overriding.position + v.partition_key = overriding.partition_key + v._partition_key_index = overriding._partition_key_index + _transform_column(k, v) + + partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) + clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + + if attrs.get('__compute_routing_key__', True): + key_cols = [c for c in partition_keys.values()] + partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) + key_cql_types = [c.cql_type for c in key_cols] + key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + else: + partition_key_index = {} + key_serializer = staticmethod(lambda parts, proto_version: None) + + # setup partition key shortcut + if len(partition_keys) == 0: + if not is_abstract: + raise ModelException("at least one partition key must be defined") + if len(partition_keys) == 1: + pk_name = [x for x in partition_keys.keys()][0] + attrs['pk'] = attrs[pk_name] + else: + # composite partition key case, get/set a tuple of values + _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys()) + _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val)) + attrs['pk'] = property(_get, _set) + + # some validation + col_names = set() + for v in column_dict.values(): + # check for duplicate column names + if v.db_field_name in col_names: + raise ModelException("{0} defines the column '{1}' more than once".format(name, v.db_field_name)) + if v.clustering_order and not (v.primary_key and not v.partition_key): + raise ModelException("clustering_order may be specified only for clustering primary keys") + if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'): + raise ModelException("invalid clustering order '{0}' for column '{1}'".format(repr(v.clustering_order), v.db_field_name)) + col_names.add(v.db_field_name) + + # create db_name -> model name map for loading + db_map = {} + for col_name, field in column_dict.items(): + db_field = field.db_field_name + if db_field != col_name: + db_map[db_field] = col_name + + # add management members to the class + attrs['_columns'] = column_dict + attrs['_primary_keys'] = primary_keys + attrs['_defined_columns'] = defined_columns + + # maps the database field to the models key + attrs['_db_map'] = db_map + attrs['_pk_name'] = pk_name + attrs['_dynamic_columns'] = {} + + attrs['_partition_keys'] = partition_keys + attrs['_partition_key_index'] = partition_key_index + attrs['_key_serializer'] = key_serializer + attrs['_clustering_keys'] = clustering_keys + attrs['_has_counter'] = len(counter_columns) > 0 + + # add polymorphic management attributes + attrs['_is_polymorphic_base'] = is_polymorphic_base + attrs['_is_polymorphic'] = is_polymorphic + attrs['_polymorphic_base'] = polymorphic_base + attrs['_discriminator_column'] = discriminator_column + attrs['_discriminator_column_name'] = discriminator_column_name + attrs['_discriminator_map'] = {} if is_polymorphic_base else None + + # setup class exceptions + DoesNotExistBase = None + for base in bases: + DoesNotExistBase = getattr(base, 'DoesNotExist', None) + if DoesNotExistBase is not None: + break + + DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist) + attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {}) + + MultipleObjectsReturnedBase = None + for base in bases: + MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None) + if MultipleObjectsReturnedBase is not None: + break + + MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned) + attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {}) + + # create the class and add a QuerySet to it + klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) + + udts = [] + for col in column_dict.values(): + columns.resolve_udts(col, udts) + + for user_type in set(udts): + user_type.register_for_keyspace(klass._get_keyspace()) + + return klass + + +class Model(BaseModel, metaclass=ModelMetaClass): + __abstract__ = True + """ + *Optional.* Indicates that this model is only intended to be used as a base class for other models. + You can't create tables for abstract models, but checks around schema validity are skipped during class construction. + """ + + __table_name__ = None + """ + *Optional.* Sets the name of the CQL table for this model. If left blank, the table name will be the name of the model, with it's module name as it's prefix. Manually defined table names are not inherited. + """ + + __table_name_case_sensitive__ = False + """ + *Optional.* By default, __table_name__ is case insensitive. Set this to True if you want to preserve the case sensitivity. + """ + + __keyspace__ = None + """ + Sets the name of the keyspace used by this model. + """ + + __connection__ = None + """ + Sets the name of the default connection used by this model. + """ + + __options__ = None + """ + *Optional* Table options applied with this model + + (e.g. compaction, default ttl, cache settings, tec.) + """ + + __discriminator_value__ = None + """ + *Optional* Specifies a value for the discriminator column when using model inheritance. + """ + + __compute_routing_key__ = True + """ + *Optional* Setting False disables computing the routing key for TokenAwareRouting + """ diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py new file mode 100644 index 0000000000..219155818c --- /dev/null +++ b/cassandra/cqlengine/named.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.util import OrderedDict + +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.connection import get_cluster +from cassandra.cqlengine.models import UsingDescriptor, BaseModel +from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet +from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist +from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned + + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on everytime it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + return SimpleQuerySet(obj) + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class NamedColumn(AbstractQueryableColumn): + """ + A column that is not coupled to a model class, or type + """ + + def __init__(self, name): + self.name = name + + def __unicode__(self): + return self.name + + def _get_column(self): + """ :rtype: NamedColumn """ + return self + + @property + def db_field_name(self): + return self.name + + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{0}"'.format(self.name) + + def to_database(self, val): + return val + + +class NamedTable(object): + """ + A Table that is not coupled to a model class + """ + + __abstract__ = False + + objects = QuerySetDescriptor() + + __partition_keys = None + + _partition_key_index = None + + __connection__ = None + _connection = None + + using = UsingDescriptor() + + _get_connection = BaseModel._get_connection + + class DoesNotExist(_DoesNotExist): + pass + + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass + + def __init__(self, keyspace, name): + self.keyspace = keyspace + self.name = name + self._connection = None + + @property + def _partition_keys(self): + if not self.__partition_keys: + self._get_partition_keys() + return self.__partition_keys + + def _get_partition_keys(self): + try: + table_meta = get_cluster(self._get_connection()).metadata.keyspaces[self.keyspace].tables[self.name] + self.__partition_keys = OrderedDict((pk.name, Column(primary_key=True, partition_key=True, db_field=pk.name)) for pk in table_meta.partition_key) + except Exception as e: + raise CQLEngineException("Failed inspecting partition keys for {0}." + "Ensure cqlengine is connected before attempting this with NamedTable.".format(self.column_family_name())) + + def column(self, name): + return NamedColumn(name) + + def column_family_name(self, include_keyspace=True): + """ + Returns the column family name if it's been defined + otherwise, it creates it from the module and class name + """ + if include_keyspace: + return '{0}.{1}'.format(self.keyspace, self.name) + else: + return self.name + + def _get_column(self, name): + """ + Returns the column matching the given name + + :rtype: Column + """ + return self.column(name) + + # def create(self, **kwargs): + # return self.objects.create(**kwargs) + + def all(self): + return self.objects.all() + + def filter(self, *args, **kwargs): + return self.objects.filter(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.objects.get(*args, **kwargs) + + +class NamedKeyspace(object): + """ + A keyspace + """ + + def __init__(self, name): + self.name = name + + def table(self, name): + """ + returns a table descriptor with the given + name that belongs to this keyspace + """ + return NamedTable(self.name, name) diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py new file mode 100644 index 0000000000..a9e7db2545 --- /dev/null +++ b/cassandra/cqlengine/operators.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cqlengine import UnicodeMixin + + +class QueryOperatorException(Exception): + pass + + +class BaseQueryOperator(UnicodeMixin): + # The symbol that identifies this operator in kwargs + # ie: colname__ + symbol = None + + # The comparator symbol this operator uses in cql + cql_symbol = None + + def __unicode__(self): + if self.cql_symbol is None: + raise QueryOperatorException("cql symbol is None") + return self.cql_symbol + + +class OpMapMeta(type): + + def __init__(cls, name, bases, dct): + if not hasattr(cls, 'opmap'): + cls.opmap = {} + else: + cls.opmap[cls.symbol] = cls + super(OpMapMeta, cls).__init__(name, bases, dct) + + +class BaseWhereOperator(BaseQueryOperator, metaclass=OpMapMeta): + """ base operator used for where clauses """ + @classmethod + def get_operator(cls, symbol): + try: + return cls.opmap[symbol.upper()] + except KeyError: + raise QueryOperatorException("{0} doesn't map to a QueryOperator".format(symbol)) + + +class EqualsOperator(BaseWhereOperator): + symbol = 'EQ' + cql_symbol = '=' + + +class NotEqualsOperator(BaseWhereOperator): + symbol = 'NE' + cql_symbol = '!=' + + +class InOperator(EqualsOperator): + symbol = 'IN' + cql_symbol = 'IN' + + +class GreaterThanOperator(BaseWhereOperator): + symbol = "GT" + cql_symbol = '>' + + +class GreaterThanOrEqualOperator(BaseWhereOperator): + symbol = "GTE" + cql_symbol = '>=' + + +class LessThanOperator(BaseWhereOperator): + symbol = "LT" + cql_symbol = '<' + + +class LessThanOrEqualOperator(BaseWhereOperator): + symbol = "LTE" + cql_symbol = '<=' + + +class ContainsOperator(EqualsOperator): + symbol = "CONTAINS" + cql_symbol = 'CONTAINS' + + +class LikeOperator(EqualsOperator): + symbol = "LIKE" + cql_symbol = 'LIKE' + + +class IsNotNullOperator(EqualsOperator): + symbol = "IS NOT NULL" + cql_symbol = 'IS NOT NULL' diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py new file mode 100644 index 0000000000..329bc7fade --- /dev/null +++ b/cassandra/cqlengine/query.py @@ -0,0 +1,1532 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from datetime import datetime, timedelta +from functools import partial +import time +from warnings import warn + +from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement +from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin +from cassandra.cqlengine import connection as conn +from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue +from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator, + GreaterThanOrEqualOperator, LessThanOperator, + LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) +from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, + UpdateStatement, InsertStatement, + BaseCQLStatement, MapDeleteClause, ConditionalClause) + + +class QueryException(CQLEngineException): + pass + + +class IfNotExistsWithCounterColumn(CQLEngineException): + pass + + +class IfExistsWithCounterColumn(CQLEngineException): + pass + + +class LWTException(CQLEngineException): + """Lightweight conditional exception. + + This exception will be raised when a write using an `IF` clause could not be + applied due to existing data violating the condition. The existing data is + available through the ``existing`` attribute. + + :param existing: The current state of the data which prevented the write. + """ + def __init__(self, existing): + super(LWTException, self).__init__("LWT Query was not applied") + self.existing = existing + + +class DoesNotExist(QueryException): + pass + + +class MultipleObjectsReturned(QueryException): + pass + + +def check_applied(result): + """ + Raises LWTException if it looks like a failed LWT request. A LWTException + won't be raised in the special case in which there are several failed LWT + in a :class:`~cqlengine.query.BatchQuery`. + """ + try: + applied = result.was_applied + except Exception: + applied = True # result was not LWT form + if not applied: + raise LWTException(result.one()) + + +class AbstractQueryableColumn(UnicodeMixin): + """ + exposes cql query operators through pythons + builtin comparator symbols + """ + + def _get_column(self): + raise NotImplementedError + + def __unicode__(self): + raise NotImplementedError + + def _to_database(self, val): + if isinstance(val, QueryValue): + return val + else: + return self._get_column().to_database(val) + + def in_(self, item): + """ + Returns an in operator + + used where you'd typically want to use python's `in` operator + """ + return WhereClause(str(self), InOperator(), item) + + def contains_(self, item): + """ + Returns a CONTAINS operator + """ + return WhereClause(str(self), ContainsOperator(), item) + + + def __eq__(self, other): + return WhereClause(str(self), EqualsOperator(), self._to_database(other)) + + def __gt__(self, other): + return WhereClause(str(self), GreaterThanOperator(), self._to_database(other)) + + def __ge__(self, other): + return WhereClause(str(self), GreaterThanOrEqualOperator(), self._to_database(other)) + + def __lt__(self, other): + return WhereClause(str(self), LessThanOperator(), self._to_database(other)) + + def __le__(self, other): + return WhereClause(str(self), LessThanOrEqualOperator(), self._to_database(other)) + + +class BatchType(object): + Unlogged = 'UNLOGGED' + Counter = 'COUNTER' + + +class BatchQuery(object): + """ + Handles the batching of queries + + http://docs.datastax.com/en/cql/3.0/cql/cql_reference/batch_r.html + + See :doc:`/cqlengine/batches` for more details. + """ + warn_multiple_exec = True + + _consistency = None + + _connection = None + _connection_explicit = False + + + def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False, + timeout=conn.NOT_SET, connection=None): + """ + :param batch_type: (optional) One of batch type values available through BatchType enum + :type batch_type: BatchType, str or None + :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied + to the batch conditional. + :type timestamp: datetime or timedelta or None + :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc) + :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None. + :param execute_on_exception: (Defaults to False) Indicates that when the BatchQuery instance is used + as a context manager the queries accumulated within the context must be executed despite + encountering an error within the context. By default, any exception raised from within + the context scope will cause the batched queries not to be executed. + :type execute_on_exception: bool + :param timeout: (optional) Timeout for the entire batch (in seconds), if not specified fallback + to default session timeout + :type timeout: float or None + :param str connection: Connection name to use for the batch execution + """ + self.queries = [] + self.batch_type = batch_type + if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)): + raise CQLEngineException('timestamp object must be an instance of datetime') + self.timestamp = timestamp + self._consistency = consistency + self._execute_on_exception = execute_on_exception + self._timeout = timeout + self._callbacks = [] + self._executed = False + self._context_entered = False + self._connection = connection + if connection: + self._connection_explicit = True + + def add_query(self, query): + if not isinstance(query, BaseCQLStatement): + raise CQLEngineException('only BaseCQLStatements can be added to a batch query') + self.queries.append(query) + + def consistency(self, consistency): + self._consistency = consistency + + def _execute_callbacks(self): + for callback, args, kwargs in self._callbacks: + callback(*args, **kwargs) + + def add_callback(self, fn, *args, **kwargs): + """Add a function and arguments to be passed to it to be executed after the batch executes. + + A batch can support multiple callbacks. + + Note, that if the batch does not execute, the callbacks are not executed. + A callback, thus, is an "on batch success" handler. + + :param fn: Callable object + :type fn: callable + :param args: Positional arguments to be passed to the callback at the time of execution + :param kwargs: Named arguments to be passed to the callback at the time of execution + """ + if not callable(fn): + raise ValueError("Value for argument 'fn' is {0} and is not a callable object.".format(type(fn))) + self._callbacks.append((fn, args, kwargs)) + + def execute(self): + if self._executed and self.warn_multiple_exec: + msg = "Batch executed multiple times." + if self._context_entered: + msg += " If using the batch as a context manager, there is no need to call execute directly." + warn(msg) + self._executed = True + + if len(self.queries) == 0: + # Empty batch is a no-op + # except for callbacks + self._execute_callbacks() + return + + batch_type = None if self.batch_type is CBatchType.LOGGED else self.batch_type + opener = 'BEGIN ' + (str(batch_type) + ' ' if batch_type else '') + ' BATCH' + if self.timestamp: + + if isinstance(self.timestamp, int): + ts = self.timestamp + elif isinstance(self.timestamp, (datetime, timedelta)): + ts = self.timestamp + if isinstance(self.timestamp, timedelta): + ts += datetime.now() # Apply timedelta + ts = int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond) + else: + raise ValueError("Batch expects a long, a timedelta, or a datetime") + + opener += ' USING TIMESTAMP {0}'.format(ts) + + query_list = [opener] + parameters = {} + ctx_counter = 0 + for query in self.queries: + query.update_context_id(ctx_counter) + ctx = query.get_context() + ctx_counter += len(ctx) + query_list.append(' ' + str(query)) + parameters.update(ctx) + + query_list.append('APPLY BATCH;') + + tmp = conn.execute('\n'.join(query_list), parameters, self._consistency, self._timeout, connection=self._connection) + check_applied(tmp) + + self.queries = [] + self._execute_callbacks() + + def __enter__(self): + self._context_entered = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # don't execute if there was an exception by default + if exc_type is not None and not self._execute_on_exception: + return + self.execute() + + +class ContextQuery(object): + """ + A Context manager to allow a Model to switch context easily. Presently, the context only + specifies a keyspace for model IO. + + :param args: One or more models. A model should be a class type, not an instance. + :param kwargs: (optional) Context parameters: can be *keyspace* or *connection* + + For example: + + .. code-block:: python + + with ContextQuery(Automobile, keyspace='test2') as A: + A.objects.create(manufacturer='honda', year=2008, model='civic') + print(len(A.objects.all())) # 1 result + + with ContextQuery(Automobile, keyspace='test4') as A: + print(len(A.objects.all())) # 0 result + + # Multiple models + with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2): + print(len(A.objects.all())) + print(len(A2.objects.all())) + + """ + + def __init__(self, *args, **kwargs): + from cassandra.cqlengine import models + + self.models = [] + + if len(args) < 1: + raise ValueError("No model provided.") + + keyspace = kwargs.pop('keyspace', None) + connection = kwargs.pop('connection', None) + + if kwargs: + raise ValueError("Unknown keyword argument(s): {0}".format( + ','.join(kwargs.keys()))) + + for model in args: + try: + issubclass(model, models.Model) + except TypeError: + raise ValueError("Models must be derived from base Model.") + + m = models._clone_model_class(model, {}) + + if keyspace: + m.__keyspace__ = keyspace + if connection: + m.__connection__ = connection + + self.models.append(m) + + def __enter__(self): + if len(self.models) > 1: + return tuple(self.models) + return self.models[0] + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + +class AbstractQuerySet(object): + + def __init__(self, model): + super(AbstractQuerySet, self).__init__() + self.model = model + + # Where clause filters + self._where = [] + + # Conditional clause filters + self._conditional = [] + + # ordering arguments + self._order = [] + + self._allow_filtering = False + + # CQL has a default limit of 10000, it's defined here + # because explicit is better than implicit + self._limit = 10000 + + # We store the fields for which we use the Equal operator + # in a query, so we don't select it from the DB. _defer_fields + # will contain the names of the fields in the DB, not the names + # of the variables used by the mapper + self._defer_fields = set() + self._deferred_values = {} + + # This variable will hold the names in the database of the fields + # for which we want to query + self._only_fields = [] + + self._values_list = False + self._flat_values_list = False + + # results cache + self._result_cache = None + self._result_idx = None + self._result_generator = None + self._materialize_results = True + + self._distinct_fields = None + + self._count = None + + self._batch = None + self._ttl = None + self._consistency = None + self._timestamp = None + self._if_not_exists = False + self._timeout = conn.NOT_SET + self._if_exists = False + self._fetch_size = None + self._connection = None + + @property + def column_family_name(self): + return self.model.column_family_name() + + def _execute(self, statement): + if self._batch: + return self._batch.add_query(statement) + else: + connection = self._connection or self.model._get_connection() + result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) + if self._if_not_exists or self._if_exists or self._conditional: + check_applied(result) + return result + + def __unicode__(self): + return str(self._select_query()) + + def __str__(self): + return str(self.__unicode__()) + + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) + + def __deepcopy__(self, memo): + clone = self.__class__(self.model) + for k, v in self.__dict__.items(): + if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution + clone.__dict__[k] = None + elif k == '_batch': + # we need to keep the same batch instance across + # all queryset clones, otherwise the batched queries + # fly off into other batch instances which are never + # executed, thx @dokai + clone.__dict__[k] = self._batch + elif k == '_timeout': + clone.__dict__[k] = self._timeout + else: + clone.__dict__[k] = copy.deepcopy(v, memo) + + return clone + + def __len__(self): + self._execute_query() + return self.count() + + # ----query generation / execution---- + + def _select_fields(self): + """ returns the fields to select """ + return [] + + def _validate_select_where(self): + """ put select query validation here """ + + def _select_query(self): + """ + Returns a select clause based on the given filter args + """ + if self._where: + self._validate_select_where() + return SelectStatement( + self.column_family_name, + fields=self._select_fields(), + where=self._where, + order_by=self._order, + limit=self._limit, + allow_filtering=self._allow_filtering, + distinct_fields=self._distinct_fields, + fetch_size=self._fetch_size + ) + + # ----Reads------ + + def _execute_query(self): + if self._batch: + raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + if self._result_cache is None: + self._result_generator = (i for i in self._execute(self._select_query())) + self._result_cache = [] + self._construct_result = self._maybe_inject_deferred(self._get_result_constructor()) + + # "DISTINCT COUNT()" is not supported in C* < 2.2, so we need to materialize all results to get + # len() and count() working with DISTINCT queries + if self._materialize_results or self._distinct_fields: + self._fill_result_cache() + + def _fill_result_cache(self): + """ + Fill the result cache with all results. + """ + + idx = 0 + try: + while True: + idx += 1000 + self._fill_result_cache_to_idx(idx) + except StopIteration: + pass + + self._count = len(self._result_cache) + + def _fill_result_cache_to_idx(self, idx): + self._execute_query() + if self._result_idx is None: + self._result_idx = -1 + + qty = idx - self._result_idx + if qty < 1: + return + else: + for idx in range(qty): + self._result_idx += 1 + while True: + try: + self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) + break + except IndexError: + self._result_cache.append(next(self._result_generator)) + + def __iter__(self): + self._execute_query() + + idx = 0 + while True: + if len(self._result_cache) <= idx: + try: + self._result_cache.append(next(self._result_generator)) + except StopIteration: + break + + instance = self._result_cache[idx] + if isinstance(instance, dict): + self._fill_result_cache_to_idx(idx) + yield self._result_cache[idx] + + idx += 1 + + def __getitem__(self, s): + self._execute_query() + + if isinstance(s, slice): + start = s.start if s.start else 0 + + if start < 0 or (s.stop is not None and s.stop < 0): + warn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", + DeprecationWarning) + + # calculate the amount of results that need to be loaded + end = s.stop + if start < 0 or s.stop is None or s.stop < 0: + end = self.count() + + try: + self._fill_result_cache_to_idx(end) + except StopIteration: + pass + + return self._result_cache[start:s.stop:s.step] + else: + try: + s = int(s) + except (ValueError, TypeError): + raise TypeError('QuerySet indices must be integers') + + if s < 0: + warn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", + DeprecationWarning) + + # Using negative indexing is costly since we have to execute a count() + if s < 0: + num_results = self.count() + s += num_results + + try: + self._fill_result_cache_to_idx(s) + except StopIteration: + raise IndexError + + return self._result_cache[s] + + def _get_result_constructor(self): + """ + Returns a function that will be used to instantiate query results + """ + raise NotImplementedError + + @staticmethod + def _construct_with_deferred(f, deferred, row): + row.update(deferred) + return f(row) + + def _maybe_inject_deferred(self, constructor): + return partial(self._construct_with_deferred, constructor, self._deferred_values)\ + if self._deferred_values else constructor + + def batch(self, batch_obj): + """ + Set a batch object to run the query on. + + Note: running a select query with a batch object will raise an exception + """ + if self._connection: + raise CQLEngineException("Cannot specify the connection on model in batch mode.") + + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + clone = copy.deepcopy(self) + clone._batch = batch_obj + return clone + + def first(self): + try: + return next(iter(self)) + except StopIteration: + return None + + def all(self): + """ + Returns a queryset matching all rows + + .. code-block:: python + + for user in User.objects().all(): + print(user) + """ + return copy.deepcopy(self) + + def consistency(self, consistency): + """ + Sets the consistency level for the operation. See :class:`.ConsistencyLevel`. + + .. code-block:: python + + for user in User.objects(id=3).consistency(CL.ONE): + print(user) + """ + clone = copy.deepcopy(self) + clone._consistency = consistency + return clone + + def _parse_filter_arg(self, arg): + """ + Parses a filter arg in the format: + __ + :returns: colname, op tuple + """ + statement = arg.rsplit('__', 1) + if len(statement) == 1: + return arg, None + elif len(statement) == 2: + return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) + else: + raise QueryException("Can't parse '{0}'".format(arg)) + + def iff(self, *args, **kwargs): + """Adds IF statements to queryset""" + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on iff are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, ConditionalClause): + raise QueryException('{0} is not a valid query operator'.format(operator)) + clone._conditional.append(operator) + + for arg, val in kwargs.items(): + if isinstance(val, Token): + raise QueryException("Token() values are not valid in conditionals") + + col_name, col_op = self._parse_filter_arg(arg) + try: + column = self.model._get_column(col_name) + except KeyError: + raise QueryException("Can't resolve column name: '{0}'".format(col_name)) + + if isinstance(val, BaseQueryFunction): + query_val = val + else: + query_val = column.to_database(val) + + operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator = operator_class() + clone._conditional.append(WhereClause(column.db_field_name, operator, query_val)) + + return clone + + def filter(self, *args, **kwargs): + """ + Adds WHERE arguments to the queryset, returning a new queryset + + See :ref:`retrieving-objects-with-filters` + + Returns a QuerySet filtered on the keyword arguments + """ + # add arguments to the where clause filters + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on filter are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, WhereClause): + raise QueryException('{0} is not a valid query operator'.format(operator)) + clone._where.append(operator) + + for arg, val in kwargs.items(): + col_name, col_op = self._parse_filter_arg(arg) + quote_field = True + + if not isinstance(val, Token): + try: + column = self.model._get_column(col_name) + except KeyError: + raise QueryException("Can't resolve column name: '{0}'".format(col_name)) + else: + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + + column = columns._PartitionKeysToken(self.model) + quote_field = False + + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {0} arguments but model has {1} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) + + # get query operator, or use equals if not supplied + operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator = operator_class() + + if isinstance(operator, InOperator): + if not isinstance(val, (list, tuple)): + raise QueryException('IN queries must use a list/tuple value') + query_val = [column.to_database(v) for v in val] + elif isinstance(val, BaseQueryFunction): + query_val = val + elif (isinstance(operator, ContainsOperator) and + isinstance(column, (columns.List, columns.Set, columns.Map))): + # For ContainsOperator and collections, we query using the value, not the container + query_val = val + else: + query_val = column.to_database(val) + if not col_op: # only equal values should be deferred + clone._defer_fields.add(column.db_field_name) + clone._deferred_values[column.db_field_name] = val # map by db field name for substitution in results + + clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) + + return clone + + def get(self, *args, **kwargs): + """ + Returns a single instance matching this query, optionally with additional filter kwargs. + + See :ref:`retrieving-objects-with-filters` + + Returns a single object matching the QuerySet. + + .. code-block:: python + + user = User.get(id=1) + + If no objects are matched, a :class:`~.DoesNotExist` exception is raised. + + If more than one object is found, a :class:`~.MultipleObjectsReturned` exception is raised. + """ + if args or kwargs: + return self.filter(*args, **kwargs).get() + + self._execute_query() + + # Check that the resultset only contains one element, avoiding sending a COUNT query + try: + self[1] + raise self.model.MultipleObjectsReturned('Multiple objects found') + except IndexError: + pass + + try: + obj = self[0] + except IndexError: + raise self.model.DoesNotExist + + return obj + + def _get_ordering_condition(self, colname): + order_type = 'DESC' if colname.startswith('-') else 'ASC' + colname = colname.replace('-', '') + + return colname, order_type + + def order_by(self, *colnames): + """ + Sets the column(s) to be used for ordering + + Default order is ascending, prepend a '-' to any column name for descending + + *Note: column names must be a clustering key* + + .. code-block:: python + + from uuid import uuid1,uuid4 + + class Comment(Model): + photo_id = UUID(primary_key=True) + comment_id = TimeUUID(primary_key=True, default=uuid1) # second primary key component is a clustering key + comment = Text() + + sync_table(Comment) + + u = uuid4() + for x in range(5): + Comment.create(photo_id=u, comment="test %d" % x) + + print("Normal") + for comment in Comment.objects(photo_id=u): + print(comment.comment_id) + + print("Reversed") + for comment in Comment.objects(photo_id=u).order_by("-comment_id"): + print(comment.comment_id) + """ + if len(colnames) == 0: + clone = copy.deepcopy(self) + clone._order = [] + return clone + + conditions = [] + for colname in colnames: + conditions.append('"{0}" {1}'.format(*self._get_ordering_condition(colname))) + + clone = copy.deepcopy(self) + clone._order.extend(conditions) + return clone + + def count(self): + """ + Returns the number of rows matched by this query. + + *Note: This function executes a SELECT COUNT() and has a performance cost on large datasets* + """ + if self._batch: + raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + + if self._count is None: + query = self._select_query() + query.count = True + result = self._execute(query) + count_row = result.one().popitem() + self._count = count_row[1] + return self._count + + def distinct(self, distinct_fields=None): + """ + Returns the DISTINCT rows matched by this query. + + distinct_fields default to the partition key fields if not specified. + + *Note: distinct_fields must be a partition key or a static column* + + .. code-block:: python + + class Automobile(Model): + manufacturer = columns.Text(partition_key=True) + year = columns.Integer(primary_key=True) + model = columns.Text(primary_key=True) + price = columns.Decimal() + + sync_table(Automobile) + + # create rows + + Automobile.objects.distinct() + + # or + + Automobile.objects.distinct(['manufacturer']) + + """ + + clone = copy.deepcopy(self) + if distinct_fields: + clone._distinct_fields = distinct_fields + else: + clone._distinct_fields = [x.column_name for x in self.model._partition_keys.values()] + + return clone + + def limit(self, v): + """ + Limits the number of results returned by Cassandra. Use *0* or *None* to disable. + + *Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000* + + .. code-block:: python + + # Fetch 100 users + for user in User.objects().limit(100): + print(user) + + # Fetch all users + for user in User.objects().limit(None): + print(user) + """ + + if v is None: + v = 0 + + if not isinstance(v, int): + raise TypeError + if v == self._limit: + return self + + if v < 0: + raise QueryException("Negative limit is not allowed") + + clone = copy.deepcopy(self) + clone._limit = v + return clone + + def fetch_size(self, v): + """ + Sets the number of rows that are fetched at a time. + + *Note that driver's default fetch size is 5000.* + + .. code-block:: python + + for user in User.objects().fetch_size(500): + print(user) + """ + + if not isinstance(v, int): + raise TypeError + if v == self._fetch_size: + return self + + if v < 1: + raise QueryException("fetch size less than 1 is not allowed") + + clone = copy.deepcopy(self) + clone._fetch_size = v + return clone + + def allow_filtering(self): + """ + Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key + """ + clone = copy.deepcopy(self) + clone._allow_filtering = True + return clone + + def _only_or_defer(self, action, fields): + if action == 'only' and self._only_fields: + raise QueryException("QuerySet already has 'only' fields defined") + + clone = copy.deepcopy(self) + + # check for strange fields + missing_fields = [f for f in fields if f not in self.model._columns.keys()] + if missing_fields: + raise QueryException( + "Can't resolve fields {0} in {1}".format( + ', '.join(missing_fields), self.model.__name__)) + + fields = [self.model._columns[field].db_field_name for field in fields] + + if action == 'defer': + clone._defer_fields.update(fields) + elif action == 'only': + clone._only_fields = fields + else: + raise ValueError + + return clone + + def only(self, fields): + """ Load only these fields for the returned query """ + return self._only_or_defer('only', fields) + + def defer(self, fields): + """ Don't load these fields for the returned query """ + return self._only_or_defer('defer', fields) + + def create(self, **kwargs): + return self.model(**kwargs) \ + .batch(self._batch) \ + .ttl(self._ttl) \ + .consistency(self._consistency) \ + .if_not_exists(self._if_not_exists) \ + .timestamp(self._timestamp) \ + .if_exists(self._if_exists) \ + .using(connection=self._connection) \ + .save() + + def delete(self): + """ + Deletes the contents of a query + """ + # validate where clause + partition_keys = set(x.db_field_name for x in self.model._partition_keys.values()) + if partition_keys - set(c.field for c in self._where): + raise QueryException("The partition key must be defined on delete queries") + + dq = DeleteStatement( + self.column_family_name, + where=self._where, + timestamp=self._timestamp, + conditionals=self._conditional, + if_exists=self._if_exists + ) + self._execute(dq) + + def __eq__(self, q): + if len(self._where) == len(q._where): + return all([w in q._where for w in self._where]) + return False + + def __ne__(self, q): + return not (self != q) + + def timeout(self, timeout): + """ + :param timeout: Timeout for the query (in seconds) + :type timeout: float or None + """ + clone = copy.deepcopy(self) + clone._timeout = timeout + return clone + + def using(self, keyspace=None, connection=None): + """ + Change the context on-the-fly of the Model class (keyspace, connection) + """ + + if connection and self._batch: + raise CQLEngineException("Cannot specify a connection on model in batch mode.") + + clone = copy.deepcopy(self) + if keyspace: + from cassandra.cqlengine.models import _clone_model_class + clone.model = _clone_model_class(self.model, {'__keyspace__': keyspace}) + + if connection: + clone._connection = connection + + return clone + + +class ResultObject(dict): + """ + adds attribute access to a dictionary + """ + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError + + +class SimpleQuerySet(AbstractQuerySet): + """ + Overrides _get_result_constructor for querysets that do not define a model (e.g. NamedTable queries) + """ + + def _get_result_constructor(self): + """ + Returns a function that will be used to instantiate query results + """ + return ResultObject + + +class ModelQuerySet(AbstractQuerySet): + """ + """ + def _validate_select_where(self): + """ Checks that a filterset will not create invalid select statement """ + # check that there's either a =, a IN or a CONTAINS (collection) + # relationship with a primary key or indexed field. We also allow + # custom indexes to be queried with any operator (a difference + # between a secondary index) + equal_ops = [self.model._get_column_by_db_name(w.field) \ + for w in self._where if not isinstance(w.value, Token) + and (isinstance(w.operator, EqualsOperator) + or self.model._get_column_by_db_name(w.field).custom_index)] + token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) + if not any(w.primary_key or w.has_index for w in equal_ops) and not token_comparison and not self._allow_filtering: + raise QueryException( + ('Where clauses require either =, a IN or a CONTAINS ' + '(collection) comparison with either a primary key or ' + 'indexed field. You might want to consider setting ' + 'custom_index on fields that you manage index outside ' + 'cqlengine.')) + + if not self._allow_filtering: + # if the query is not on an indexed field + if not any(w.has_index for w in equal_ops): + if not any([w.partition_key for w in equal_ops]) and not token_comparison: + raise QueryException( + ('Filtering on a clustering key without a partition ' + 'key is not allowed unless allow_filtering() is ' + 'called on the queryset. You might want to consider ' + 'setting custom_index on fields that you manage ' + 'index outside cqlengine.')) + + def _select_fields(self): + if self._defer_fields or self._only_fields: + fields = [columns.db_field_name for columns in self.model._columns.values()] + if self._defer_fields: + fields = [f for f in fields if f not in self._defer_fields] + # select the partition keys if all model fields are set defer + if not fields: + fields = [columns.db_field_name for columns in self.model._partition_keys.values()] + if self._only_fields: + fields = [f for f in fields if f in self._only_fields] + if not fields: + raise QueryException('No fields in select query. Only fields: "{0}", defer fields: "{1}"'.format( + ','.join(self._only_fields), ','.join(self._defer_fields))) + return fields + return super(ModelQuerySet, self)._select_fields() + + def _get_result_constructor(self): + """ Returns a function that will be used to instantiate query results """ + if not self._values_list: # we want models + return self.model._construct_instance + elif self._flat_values_list: # the user has requested flattened list (1 value per row) + key = self._only_fields[0] + return lambda row: row[key] + else: + return lambda row: [row[f] for f in self._only_fields] + + def _get_ordering_condition(self, colname): + colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname) + + column = self.model._columns.get(colname) + if column is None: + raise QueryException("Can't resolve the column name: '{0}'".format(colname)) + + # validate the column selection + if not column.primary_key: + raise QueryException( + "Can't order on '{0}', can only order on (clustered) primary keys".format(colname)) + + pks = [v for k, v in self.model._columns.items() if v.primary_key] + if column == pks[0]: + raise QueryException( + "Can't order by the first primary key (partition key), clustering (secondary) keys only") + + return column.db_field_name, order_type + + def values_list(self, *fields, **kwargs): + """ Instructs the query set to return tuples, not model instance """ + flat = kwargs.pop('flat', False) + if kwargs: + raise TypeError('Unexpected keyword arguments to values_list: %s' + % (kwargs.keys(),)) + if flat and len(fields) > 1: + raise TypeError("'flat' is not valid when values_list is called with more than one field.") + clone = self.only(fields) + clone._values_list = True + clone._flat_values_list = flat + return clone + + def ttl(self, ttl): + """ + Sets the ttl (in seconds) for modified data. + + *Note that running a select query with a ttl value will raise an exception* + """ + clone = copy.deepcopy(self) + clone._ttl = ttl + return clone + + def timestamp(self, timestamp): + """ + Allows for custom timestamps to be saved with the record. + """ + clone = copy.deepcopy(self) + clone._timestamp = timestamp + return clone + + def if_not_exists(self): + """ + Check the existence of an object before insertion. + + If the insertion isn't applied, a LWTException is raised. + """ + if self.model._has_counter: + raise IfNotExistsWithCounterColumn('if_not_exists cannot be used with tables containing counter columns') + clone = copy.deepcopy(self) + clone._if_not_exists = True + return clone + + def if_exists(self): + """ + Check the existence of an object before an update or delete. + + If the update or delete isn't applied, a LWTException is raised. + """ + if self.model._has_counter: + raise IfExistsWithCounterColumn('if_exists cannot be used with tables containing counter columns') + clone = copy.deepcopy(self) + clone._if_exists = True + return clone + + def update(self, **values): + """ + Performs an update on the row selected by the queryset. Include values to update in the + update like so: + + .. code-block:: python + + Model.objects(key=n).update(value='x') + + Passing in updates for columns which are not part of the model will raise a ValidationError. + + Per column validation will be performed, but instance level validation will not + (i.e., `Model.validate` is not called). This is sometimes referred to as a blind update. + + For example: + + .. code-block:: python + + class User(Model): + id = Integer(primary_key=True) + name = Text() + + setup(["localhost"], "test") + sync_table(User) + + u = User.create(id=1, name="jon") + + User.objects(id=1).update(name="Steve") + + # sets name to null + User.objects(id=1).update(name=None) + + + Also supported is blindly adding and removing elements from container columns, + without loading a model instance from Cassandra. + + Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a + non container column. However, adding `__` to the end of the keyword arg, makes the update call add + or remove items from the collection, without overwriting then entire column. + + Given the model below, here are the operations that can be performed on the different container columns: + + .. code-block:: python + + class Row(Model): + row_id = columns.Integer(primary_key=True) + set_column = columns.Set(Integer) + list_column = columns.List(Integer) + map_column = columns.Map(Integer, Integer) + + :class:`~cqlengine.columns.Set` + + - `add`: adds the elements of the given set to the column + - `remove`: removes the elements of the given set to the column + + + .. code-block:: python + + # add elements to a set + Row.objects(row_id=5).update(set_column__add={6}) + + # remove elements to a set + Row.objects(row_id=5).update(set_column__remove={4}) + + :class:`~cqlengine.columns.List` + + - `append`: appends the elements of the given list to the end of the column + - `prepend`: prepends the elements of the given list to the beginning of the column + + .. code-block:: python + + # append items to a list + Row.objects(row_id=5).update(list_column__append=[6, 7]) + + # prepend items to a list + Row.objects(row_id=5).update(list_column__prepend=[1, 2]) + + + :class:`~cqlengine.columns.Map` + + - `update`: adds the given keys/values to the columns, creating new entries if they didn't exist, and overwriting old ones if they did + + .. code-block:: python + + # add items to a map + Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4}) + + # remove items from a map + Row.objects(row_id=5).update(map_column__remove={1, 2}) + """ + if not values: + return + + nulled_columns = set() + updated_columns = set() + us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, + timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) + for name, val in values.items(): + col_name, col_op = self._parse_filter_arg(name) + col = self.model._columns.get(col_name) + # check for nonexistant columns + if col is None: + raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.model.__name__, col_name)) + # check for primary key update attempts + if col.is_primary_key: + raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(col_name, self.__module__, self.model.__name__)) + + if col_op == 'remove' and isinstance(col, columns.Map): + if not isinstance(val, set): + raise ValidationError( + "Cannot apply update operation '{0}' on column '{1}' with value '{2}'. A set is required.".format(col_op, col_name, val)) + val = {v: None for v in val} + else: + # we should not provide default values in this use case. + val = col.validate(val) + + if val is None: + nulled_columns.add(col_name) + continue + + us.add_update(col, val, operation=col_op) + updated_columns.add(col_name) + + if us.assignments: + self._execute(us) + + if nulled_columns: + delete_conditional = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None + ds = DeleteStatement(self.column_family_name, fields=nulled_columns, + where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) + self._execute(ds) + + +class DMLQuery(object): + """ + A query object used for queries performing inserts, updates, or deletes + + this is usually instantiated by the model instance to be modified + + unlike the read query object, this is mutable + """ + _ttl = None + _consistency = None + _timestamp = None + _if_not_exists = False + _if_exists = False + + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, + if_not_exists=False, conditional=None, timeout=conn.NOT_SET, if_exists=False): + self.model = model + self.column_family_name = self.model.column_family_name() + self.instance = instance + self._batch = batch + self._ttl = ttl + self._consistency = consistency + self._timestamp = timestamp + self._if_not_exists = if_not_exists + self._if_exists = if_exists + self._conditional = conditional + self._timeout = timeout + + def _execute(self, statement): + connection = self.instance._get_connection() if self.instance else self.model._get_connection() + if self._batch: + if self._batch._connection: + if not self._batch._connection_explicit and connection and \ + connection != self._batch._connection: + raise CQLEngineException('BatchQuery queries must be executed on the same connection') + else: + # set the BatchQuery connection from the model + self._batch._connection = connection + return self._batch.add_query(statement) + else: + results = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) + if self._if_not_exists or self._if_exists or self._conditional: + check_applied(results) + return results + + def batch(self, batch_obj): + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + self._batch = batch_obj + return self + + def _delete_null_columns(self, conditionals=None): + """ + executes a delete query to remove columns that have changed to null + """ + ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) + deleted_fields = False + static_only = True + for _, v in self.instance._values.items(): + col = v.column + if v.deleted: + ds.add_field(col.db_field_name) + deleted_fields = True + static_only &= col.static + elif isinstance(col, columns.Map): + uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value) + if uc.get_context_size() > 0: + ds.add_field(uc) + deleted_fields = True + static_only |= col.static + + if deleted_fields: + keys = self.model._partition_keys if static_only else self.model._primary_keys + for name, col in keys.items(): + ds.add_where(col, EqualsOperator(), getattr(self.instance, name)) + self._execute(ds) + + def update(self): + """ + updates a row. + This is a blind update call. + All validation and cleaning needs to happen + prior to calling this. + """ + if self.instance is None: + raise CQLEngineException("DML Query instance attribute is None") + assert type(self.instance) == self.model + null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True + static_changed_only = True + statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, + conditionals=self._conditional, if_exists=self._if_exists) + for name, col in self.instance._clustering_keys.items(): + null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + + updated_columns = set() + # get defined fields and their column names + for name, col in self.model._columns.items(): + # if clustering key is null, don't include non-static columns + if null_clustering_key and not col.static and not col.partition_key: + continue + if not col.is_primary_key: + val = getattr(self.instance, name, None) + val_mgr = self.instance._values[name] + + if val is None: + continue + + if not val_mgr.changed and not isinstance(col, columns.Counter): + continue + + static_changed_only = static_changed_only and col.static + statement.add_update(col, val, previous=val_mgr.previous_value) + updated_columns.add(col.db_field_name) + + if statement.assignments: + for name, col in self.model._primary_keys.items(): + # only include clustering key if clustering key is not null, and non-static columns are changed to avoid cql error + if (null_clustering_key or static_changed_only) and (not col.partition_key): + continue + statement.add_where(col, EqualsOperator(), getattr(self.instance, name)) + self._execute(statement) + + if not null_clustering_key: + # remove conditions on fields that have been updated + delete_conditionals = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None + self._delete_null_columns(delete_conditionals) + + def save(self): + """ + Creates / updates a row. + This is a blind insert call. + All validation and cleaning needs to happen + prior to calling this. + """ + if self.instance is None: + raise CQLEngineException("DML Query instance attribute is None") + assert type(self.instance) == self.model + + nulled_fields = set() + if self.instance._has_counter or self.instance._can_update(): + if self.instance._has_counter: + warn("'create' and 'save' actions on Counters are deprecated. It will be disallowed in 4.0. " + "Use the 'update' mechanism instead.", DeprecationWarning) + return self.update() + else: + insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) + static_save_only = False if len(self.instance._clustering_keys) == 0 else True + for name, col in self.instance._clustering_keys.items(): + static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None)) + for name, col in self.instance._columns.items(): + if static_save_only and not col.static and not col.partition_key: + continue + val = getattr(self.instance, name, None) + if col._val_is_null(val): + if self.instance._values[name].changed: + nulled_fields.add(col.db_field_name) + continue + if col.has_default and not self.instance._values[name].changed: + # Ensure default columns included in a save() are marked as explicit, to get them *persisted* properly + self.instance._values[name].explicit = True + insert.add_assignment(col, getattr(self.instance, name, None)) + + # skip query execution if it's empty + # caused by pointless update queries + if not insert.is_empty: + self._execute(insert) + # delete any nulled columns + if not static_save_only: + self._delete_null_columns() + + def delete(self): + """ Deletes one instance """ + if self.instance is None: + raise CQLEngineException("DML Query instance attribute is None") + + ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) + for name, col in self.model._primary_keys.items(): + val = getattr(self.instance, name) + if val is None and not col.partition_key: + continue + ds.add_where(col, EqualsOperator(), val) + self._execute(ds) + + +def _execute_statement(model, statement, consistency_level, timeout, connection=None): + params = statement.get_context() + s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=statement.fetch_size) + if model._partition_key_index: + key_values = statement.partition_key_values(model._partition_key_index) + if not any(v is None for v in key_values): + parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version) + s.routing_key = parts + s.keyspace = model._get_keyspace() + connection = connection or model._get_connection() + return conn.execute(s, params, timeout=timeout, connection=connection) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py new file mode 100644 index 0000000000..b20b07ef56 --- /dev/null +++ b/cassandra/cqlengine/statements.py @@ -0,0 +1,907 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +import time + +from cassandra.query import FETCH_SIZE_UNSET +from cassandra.cqlengine import columns +from cassandra.cqlengine import UnicodeMixin +from cassandra.cqlengine.functions import QueryValue +from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator + + +class StatementException(Exception): + pass + + +class ValueQuoter(UnicodeMixin): + + def __init__(self, value): + self.value = value + + def __unicode__(self): + from cassandra.encoder import cql_quote + if isinstance(self.value, (list, tuple)): + return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' + elif isinstance(self.value, dict): + return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' + elif isinstance(self.value, set): + return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' + return cql_quote(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return False + + +class InQuoter(ValueQuoter): + + def __unicode__(self): + from cassandra.encoder import cql_quote + return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')' + + +class BaseClause(UnicodeMixin): + + def __init__(self, field, value): + self.field = field + self.value = value + self.context_id = None + + def __unicode__(self): + raise NotImplementedError + + def __hash__(self): + return hash(self.field) ^ hash(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.field == other.field and self.value == other.value + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_context_size(self): + """ returns the number of entries this clause will add to the query context """ + return 1 + + def set_context_id(self, i): + """ sets the value placeholder that will be used in the query """ + self.context_id = i + + def update_context(self, ctx): + """ updates the query context with this clauses values """ + assert isinstance(ctx, dict) + ctx[str(self.context_id)] = self.value + + +class WhereClause(BaseClause): + """ a single where statement used in queries """ + + def __init__(self, field, operator, value, quote_field=True): + """ + + :param field: + :param operator: + :param value: + :param quote_field: hack to get the token function rendering properly + :return: + """ + if not isinstance(operator, BaseWhereOperator): + raise StatementException( + "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator)) + ) + super(WhereClause, self).__init__(field, value) + self.operator = operator + self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + self.quote_field = quote_field + + def __unicode__(self): + field = ('"{0}"' if self.quote_field else '{0}').format(self.field) + return u'{0} {1} {2}'.format(field, self.operator, str(self.query_value)) + + def __hash__(self): + return super(WhereClause, self).__hash__() ^ hash(self.operator) + + def __eq__(self, other): + if super(WhereClause, self).__eq__(other): + return self.operator.__class__ == other.operator.__class__ + return False + + def get_context_size(self): + return self.query_value.get_context_size() + + def set_context_id(self, i): + super(WhereClause, self).set_context_id(i) + self.query_value.set_context_id(i) + + def update_context(self, ctx): + if isinstance(self.operator, InOperator): + ctx[str(self.context_id)] = InQuoter(self.value) + else: + self.query_value.update_context(ctx) + + +class IsNotNullClause(WhereClause): + def __init__(self, field): + super(IsNotNullClause, self).__init__(field, IsNotNullOperator(), '') + + def __unicode__(self): + field = ('"{0}"' if self.quote_field else '{0}').format(self.field) + return u'{0} {1}'.format(field, self.operator) + + def update_context(self, ctx): + pass + + def get_context_size(self): + return 0 + +# alias for convenience +IsNotNull = IsNotNullClause + + +class AssignmentClause(BaseClause): + """ a single variable st statement """ + + def __unicode__(self): + return u'"{0}" = %({1})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + +class ConditionalClause(BaseClause): + """ A single variable iff statement """ + + def __unicode__(self): + return u'"{0}" = %({1})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + +class ContainerUpdateTypeMapMeta(type): + + def __init__(cls, name, bases, dct): + if not hasattr(cls, 'type_map'): + cls.type_map = {} + else: + cls.type_map[cls.col_type] = cls + super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct) + + +class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta): + + def __init__(self, field, value, operation=None, previous=None): + super(ContainerUpdateClause, self).__init__(field, value) + self.previous = previous + self._assignments = None + self._operation = operation + self._analyzed = False + + def _analyze(self): + raise NotImplementedError + + def get_context_size(self): + raise NotImplementedError + + def update_context(self, ctx): + raise NotImplementedError + + +class SetUpdateClause(ContainerUpdateClause): + """ updates a set collection """ + + col_type = columns.Set + + _additions = None + _removals = None + + def __unicode__(self): + qs = [] + ctx_id = self.context_id + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): + qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] + if self._assignments is not None: + qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._additions is not None: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._removals is not None: + qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + elif self._operation == "add": + self._additions = self.value + elif self._operation == "remove": + self._removals = self.value + elif self.previous is None: + self._assignments = self.value + else: + # partial update time + self._additions = (self.value - self.previous) or None + self._removals = (self.previous - self.value) or None + self._analyzed = True + + def get_context_size(self): + if not self._analyzed: + self._analyze() + if (self.previous is None and + not self._assignments and + self._additions is None and + self._removals is None): + return 1 + return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + ctx_id = self.context_id + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): + ctx[str(ctx_id)] = set() + if self._assignments is not None: + ctx[str(ctx_id)] = self._assignments + ctx_id += 1 + if self._additions is not None: + ctx[str(ctx_id)] = self._additions + ctx_id += 1 + if self._removals is not None: + ctx[str(ctx_id)] = self._removals + + +class ListUpdateClause(ContainerUpdateClause): + """ updates a list collection """ + + col_type = columns.List + + _append = None + _prepend = None + + def __unicode__(self): + if not self._analyzed: + self._analyze() + qs = [] + ctx_id = self.context_id + if self._assignments is not None: + qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + + if self._prepend is not None: + qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)] + ctx_id += 1 + + if self._append is not None: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def get_context_size(self): + if not self._analyzed: + self._analyze() + return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + ctx_id = self.context_id + if self._assignments is not None: + ctx[str(ctx_id)] = self._assignments + ctx_id += 1 + if self._prepend is not None: + ctx[str(ctx_id)] = self._prepend + ctx_id += 1 + if self._append is not None: + ctx[str(ctx_id)] = self._append + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + + elif self._operation == "append": + self._append = self.value + + elif self._operation == "prepend": + self._prepend = self.value + + elif self.previous is None: + self._assignments = self.value + + elif len(self.value) < len(self.previous): + # if elements have been removed, + # rewrite the whole list + self._assignments = self.value + + elif len(self.previous) == 0: + # if we're updating from an empty + # list, do a complete insert + self._assignments = self.value + else: + + # the max start idx we want to compare + search_space = len(self.value) - max(0, len(self.previous) - 1) + + # the size of the sub lists we want to look at + search_size = len(self.previous) + + for i in range(search_space): + # slice boundary + j = i + search_size + sub = self.value[i:j] + idx_cmp = lambda idx: self.previous[idx] == sub[idx] + if idx_cmp(0) and idx_cmp(-1) and self.previous == sub: + self._prepend = self.value[:i] or None + self._append = self.value[j:] or None + break + + # if both append and prepend are still None after looking + # at both lists, an insert statement will be created + if self._prepend is self._append is None: + self._assignments = self.value + + self._analyzed = True + + +class MapUpdateClause(ContainerUpdateClause): + """ updates a map collection """ + + col_type = columns.Map + + _updates = None + _removals = None + + def _analyze(self): + if self._operation == "update": + self._updates = self.value.keys() + elif self._operation == "remove": + self._removals = {v for v in self.value.keys()} + else: + if self.previous is None: + self._updates = sorted([k for k, v in self.value.items()]) + else: + self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None + self._analyzed = True + + def get_context_size(self): + if self.is_assignment: + return 1 + return int((len(self._updates or []) * 2) + int(bool(self._removals))) + + def update_context(self, ctx): + ctx_id = self.context_id + if self.is_assignment: + ctx[str(ctx_id)] = {} + elif self._removals is not None: + ctx[str(ctx_id)] = self._removals + else: + for key in self._updates or []: + val = self.value.get(key) + ctx[str(ctx_id)] = key + ctx[str(ctx_id + 1)] = val + ctx_id += 2 + + @property + def is_assignment(self): + if not self._analyzed: + self._analyze() + return self.previous is None and not self._updates and not self._removals + + def __unicode__(self): + qs = [] + + ctx_id = self.context_id + if self.is_assignment: + qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] + elif self._removals is not None: + qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + else: + for _ in self._updates or []: + qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)] + ctx_id += 2 + + return ', '.join(qs) + + +class CounterUpdateClause(AssignmentClause): + + col_type = columns.Counter + + def __init__(self, field, value, previous=None): + super(CounterUpdateClause, self).__init__(field, value) + self.previous = previous or 0 + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = abs(self.value - self.previous) + + def __unicode__(self): + delta = self.value - self.previous + sign = '-' if delta < 0 else '+' + return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id) + + +class BaseDeleteClause(BaseClause): + pass + + +class FieldDeleteClause(BaseDeleteClause): + """ deletes a field from a row """ + + def __init__(self, field): + super(FieldDeleteClause, self).__init__(field, None) + + def __unicode__(self): + return '"{0}"'.format(self.field) + + def update_context(self, ctx): + pass + + def get_context_size(self): + return 0 + + +class MapDeleteClause(BaseDeleteClause): + """ removes keys from a map """ + + def __init__(self, field, value, previous=None): + super(MapDeleteClause, self).__init__(field, value) + self.value = self.value or {} + self.previous = previous or {} + self._analyzed = False + self._removals = None + + def _analyze(self): + self._removals = sorted([k for k in self.previous if k not in self.value]) + self._analyzed = True + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + for idx, key in enumerate(self._removals): + ctx[str(self.context_id + idx)] = key + + def get_context_size(self): + if not self._analyzed: + self._analyze() + return len(self._removals) + + def __unicode__(self): + if not self._analyzed: + self._analyze() + return ', '.join(['"{0}"[%({1})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) + + +class BaseCQLStatement(UnicodeMixin): + """ The base cql statement class """ + + def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None): + super(BaseCQLStatement, self).__init__() + self.table = table + self.context_id = 0 + self.context_counter = self.context_id + self.timestamp = timestamp + self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET + + self.where_clauses = [] + for clause in where or []: + self._add_where_clause(clause) + + self.conditionals = [] + for conditional in conditionals or []: + self.add_conditional_clause(conditional) + + def _update_part_key_values(self, field_index_map, clauses, parts): + for clause in filter(lambda c: c.field in field_index_map, clauses): + parts[field_index_map[clause.field]] = clause.value + + def partition_key_values(self, field_index_map): + parts = [None] * len(field_index_map) + self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts) + return parts + + def add_where(self, column, operator, value, quote_field=True): + value = column.to_database(value) + clause = WhereClause(column.db_field_name, operator, value, quote_field) + self._add_where_clause(clause) + + def _add_where_clause(self, clause): + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.where_clauses.append(clause) + + def get_context(self): + """ + returns the context dict for this statement + :rtype: dict + """ + ctx = {} + for clause in self.where_clauses or []: + clause.update_context(ctx) + return ctx + + def add_conditional_clause(self, clause): + """ + Adds an iff clause to this statement + + :param clause: The clause that will be added to the iff statement + :type clause: ConditionalClause + """ + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.conditionals.append(clause) + + def _get_conditionals(self): + return 'IF {0}'.format(' AND '.join([str(c) for c in self.conditionals])) + + def get_context_size(self): + return len(self.get_context()) + + def update_context_id(self, i): + self.context_id = i + self.context_counter = self.context_id + for clause in self.where_clauses: + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + + @property + def timestamp_normalized(self): + """ + We're expecting self.timestamp to be either a long, int, a datetime, or a timedelta + :return: + """ + if not self.timestamp: + return None + + if isinstance(self.timestamp, int): + return self.timestamp + + if isinstance(self.timestamp, timedelta): + tmp = datetime.now() + self.timestamp + else: + tmp = self.timestamp + + return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond) + + def __unicode__(self): + raise NotImplementedError + + def __repr__(self): + return self.__unicode__() + + @property + def _where(self): + return 'WHERE {0}'.format(' AND '.join([str(c) for c in self.where_clauses])) + + +class SelectStatement(BaseCQLStatement): + """ a cql select statement """ + + def __init__(self, + table, + fields=None, + count=False, + where=None, + order_by=None, + limit=None, + allow_filtering=False, + distinct_fields=None, + fetch_size=None): + + """ + :param where + :type where list of cqlengine.statements.WhereClause + """ + super(SelectStatement, self).__init__( + table, + where=where, + fetch_size=fetch_size + ) + + self.fields = [fields] if isinstance(fields, str) else (fields or []) + self.distinct_fields = distinct_fields + self.count = count + self.order_by = [order_by] if isinstance(order_by, str) else order_by + self.limit = limit + self.allow_filtering = allow_filtering + + def __unicode__(self): + qs = ['SELECT'] + if self.distinct_fields: + if self.count: + qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] + else: + qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] + elif self.count: + qs += ['COUNT(*)'] + else: + qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*'] + qs += ['FROM', self.table] + + if self.where_clauses: + qs += [self._where] + + if self.order_by and not self.count: + qs += ['ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by))] + + if self.limit: + qs += ['LIMIT {0}'.format(self.limit)] + + if self.allow_filtering: + qs += ['ALLOW FILTERING'] + + return ' '.join(qs) + + +class AssignmentStatement(BaseCQLStatement): + """ value assignment statements """ + + def __init__(self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + conditionals=None): + super(AssignmentStatement, self).__init__( + table, + where=where, + conditionals=conditionals + ) + self.ttl = ttl + self.timestamp = timestamp + + # add assignments + self.assignments = [] + for assignment in assignments or []: + self._add_assignment_clause(assignment) + + def update_context_id(self, i): + super(AssignmentStatement, self).update_context_id(i) + for assignment in self.assignments: + assignment.set_context_id(self.context_counter) + self.context_counter += assignment.get_context_size() + + def partition_key_values(self, field_index_map): + parts = super(AssignmentStatement, self).partition_key_values(field_index_map) + self._update_part_key_values(field_index_map, self.assignments, parts) + return parts + + def add_assignment(self, column, value): + value = column.to_database(value) + clause = AssignmentClause(column.db_field_name, value) + self._add_assignment_clause(clause) + + def _add_assignment_clause(self, clause): + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.assignments.append(clause) + + @property + def is_empty(self): + return len(self.assignments) == 0 + + def get_context(self): + ctx = super(AssignmentStatement, self).get_context() + for clause in self.assignments: + clause.update_context(ctx) + return ctx + + +class InsertStatement(AssignmentStatement): + """ an cql insert statement """ + + def __init__(self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + if_not_exists=False): + super(InsertStatement, self).__init__(table, + assignments=assignments, + where=where, + ttl=ttl, + timestamp=timestamp) + + self.if_not_exists = if_not_exists + + def __unicode__(self): + qs = ['INSERT INTO {0}'.format(self.table)] + + # get column names and context placeholders + fields = [a.insert_tuple() for a in self.assignments] + columns, values = zip(*fields) + + qs += ["({0})".format(', '.join(['"{0}"'.format(c) for c in columns]))] + qs += ['VALUES'] + qs += ["({0})".format(', '.join(['%({0})s'.format(v) for v in values]))] + + if self.if_not_exists: + qs += ["IF NOT EXISTS"] + + using_options = [] + if self.ttl: + using_options += ["TTL {}".format(self.ttl)] + + if self.timestamp: + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] + return ' '.join(qs) + + +class UpdateStatement(AssignmentStatement): + """ an cql update select statement """ + + def __init__(self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + conditionals=None, + if_exists=False): + super(UpdateStatement, self). __init__(table, + assignments=assignments, + where=where, + ttl=ttl, + timestamp=timestamp, + conditionals=conditionals) + + self.if_exists = if_exists + + def __unicode__(self): + qs = ['UPDATE', self.table] + + using_options = [] + + if self.ttl: + using_options += ["TTL {0}".format(self.ttl)] + + if self.timestamp: + using_options += ["TIMESTAMP {0}".format(self.timestamp_normalized)] + + if using_options: + qs += ["USING {0}".format(" AND ".join(using_options))] + + qs += ['SET'] + qs += [', '.join([str(c) for c in self.assignments])] + + if self.where_clauses: + qs += [self._where] + + if len(self.conditionals) > 0: + qs += [self._get_conditionals()] + + if self.if_exists: + qs += ["IF EXISTS"] + + return ' '.join(qs) + + def get_context(self): + ctx = super(UpdateStatement, self).get_context() + for clause in self.conditionals: + clause.update_context(ctx) + return ctx + + def update_context_id(self, i): + super(UpdateStatement, self).update_context_id(i) + for conditional in self.conditionals: + conditional.set_context_id(self.context_counter) + self.context_counter += conditional.get_context_size() + + def add_update(self, column, value, operation=None, previous=None): + value = column.to_database(value) + col_type = type(column) + container_update_type = ContainerUpdateClause.type_map.get(col_type) + if container_update_type: + previous = column.to_database(previous) + clause = container_update_type(column.db_field_name, value, operation, previous) + elif col_type == columns.Counter: + clause = CounterUpdateClause(column.db_field_name, value, previous) + else: + clause = AssignmentClause(column.db_field_name, value) + if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates + self._add_assignment_clause(clause) + + +class DeleteStatement(BaseCQLStatement): + """ a cql delete statement """ + + def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False): + super(DeleteStatement, self).__init__( + table, + where=where, + timestamp=timestamp, + conditionals=conditionals + ) + self.fields = [] + if isinstance(fields, str): + fields = [fields] + for field in fields or []: + self.add_field(field) + + self.if_exists = if_exists + + def update_context_id(self, i): + super(DeleteStatement, self).update_context_id(i) + for field in self.fields: + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + for t in self.conditionals: + t.set_context_id(self.context_counter) + self.context_counter += t.get_context_size() + + def get_context(self): + ctx = super(DeleteStatement, self).get_context() + for field in self.fields: + field.update_context(ctx) + for clause in self.conditionals: + clause.update_context(ctx) + return ctx + + def add_field(self, field): + if isinstance(field, str): + field = FieldDeleteClause(field) + if not isinstance(field, BaseClause): + raise StatementException("only instances of AssignmentClause can be added to statements") + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + self.fields.append(field) + + def __unicode__(self): + qs = ['DELETE'] + if self.fields: + qs += [', '.join(['{0}'.format(f) for f in self.fields])] + qs += ['FROM', self.table] + + delete_option = [] + + if self.timestamp: + delete_option += ["TIMESTAMP {0}".format(self.timestamp_normalized)] + + if delete_option: + qs += [" USING {0} ".format(" AND ".join(delete_option))] + + if self.where_clauses: + qs += [self._where] + + if self.conditionals: + qs += [self._get_conditionals()] + + if self.if_exists: + qs += ["IF EXISTS"] + + return ' '.join(qs) diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py new file mode 100644 index 0000000000..e96534f9c6 --- /dev/null +++ b/cassandra/cqlengine/usertype.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from cassandra.util import OrderedDict +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine import columns +from cassandra.cqlengine import connection as conn +from cassandra.cqlengine import models + + +class UserTypeException(CQLEngineException): + pass + + +class UserTypeDefinitionException(UserTypeException): + pass + + +class BaseUserType(object): + """ + The base type class; don't inherit from this, inherit from UserType, defined below + """ + __type_name__ = None + + _fields = None + _db_map = None + + def __init__(self, **values): + self._values = {} + if self._db_map: + values = dict((self._db_map.get(k, k), v) for k, v in values.items()) + + for name, field in self._fields.items(): + field_default = field.get_default() if field.has_default else None + value = values.get(name, field_default) + if value is not None or isinstance(field, columns.BaseContainerColumn): + value = field.to_python(value) + value_mngr = field.value_manager(self, field, value) + value_mngr.explicit = name in values + self._values[name] = value_mngr + + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + + keys = set(self._fields.keys()) + other_keys = set(other._fields.keys()) + if keys != other_keys: + return False + + for key in other_keys: + if getattr(self, key, None) != getattr(other, key, None): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in self._values.items())) + + def has_changed_fields(self): + return any(v.changed for v in self._values.values()) + + def reset_changed_fields(self): + for v in self._values.values(): + v.reset_previous_value() + + def __iter__(self): + for field in self._fields.keys(): + yield field + + def __getattr__(self, attr): + # provides the mapping from db_field to fields + try: + return getattr(self, self._db_map[attr]) + except KeyError: + raise AttributeError(attr) + + def __getitem__(self, key): + if not isinstance(key, str): + raise TypeError + if key not in self._fields.keys(): + raise KeyError + return getattr(self, key) + + def __setitem__(self, key, val): + if not isinstance(key, str): + raise TypeError + if key not in self._fields.keys(): + raise KeyError + return setattr(self, key, val) + + def __len__(self): + try: + return self._len + except: + self._len = len(self._fields.keys()) + return self._len + + def keys(self): + """ Returns a list of column IDs. """ + return [k for k in self] + + def values(self): + """ Returns list of column values. """ + return [self[k] for k in self] + + def items(self): + """ Returns a list of column ID/value tuples. """ + return [(k, self[k]) for k in self] + + @classmethod + def register_for_keyspace(cls, keyspace, connection=None): + conn.register_udt(keyspace, cls.type_name(), cls, connection=connection) + + @classmethod + def type_name(cls): + """ + Returns the type name if it's been defined + otherwise, it creates it from the class name + """ + if cls.__type_name__: + type_name = cls.__type_name__.lower() + else: + camelcase = re.compile(r'([a-z])([A-Z])') + ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2)), s) + + type_name = ccase(cls.__name__) + # trim to less than 48 characters or cassandra will complain + type_name = type_name[-48:] + type_name = type_name.lower() + type_name = re.sub(r'^_+', '', type_name) + cls.__type_name__ = type_name + + return type_name + + def validate(self): + """ + Cleans and validates the field values + """ + for name, field in self._fields.items(): + v = getattr(self, name) + if v is None and not self._values[name].explicit and field.has_default: + v = field.get_default() + val = field.validate(v) + setattr(self, name, val) + + +class UserTypeMetaClass(type): + + def __new__(cls, name, bases, attrs): + field_dict = OrderedDict() + + field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + field_defs = sorted(field_defs, key=lambda x: x[1].position) + + def _transform_column(field_name, field_obj): + field_dict[field_name] = field_obj + field_obj.set_column_name(field_name) + attrs[field_name] = models.ColumnDescriptor(field_obj) + + # transform field definitions + for k, v in field_defs: + # don't allow a field with the same name as a built-in attribute or method + if k in BaseUserType.__dict__: + raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(k)) + _transform_column(k, v) + + attrs['_fields'] = field_dict + + db_map = {} + for field_name, field in field_dict.items(): + db_field = field.db_field_name + if db_field != field_name: + if db_field in field_dict: + raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_field, field_name)) + db_map[db_field] = field_name + attrs['_db_map'] = db_map + + klass = super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs) + + return klass + + +class UserType(BaseUserType, metaclass=UserTypeMetaClass): + """ + This class is used to model User Defined Types. To define a type, declare a class inheriting from this, + and assign field types as class attributes: + + .. code-block:: python + + # connect with default keyspace ... + + from cassandra.cqlengine.columns import Text, Integer + from cassandra.cqlengine.usertype import UserType + + class address(UserType): + street = Text() + zipcode = Integer() + + from cassandra.cqlengine import management + management.sync_type(address) + + Please see :ref:`user_types` for a complete example and discussion. + """ + + __type_name__ = None + """ + *Optional.* Sets the name of the CQL type for this type. + + If not specified, the type name will be the name of the class, with it's module name as it's prefix. + """ diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 7fe5e4c527..7cde6765c0 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Representation of Cassandra data types. These classes should make it simple for the library (and caller software) to deal with Cassandra-style Java class type @@ -13,44 +29,67 @@ # for example), these classes would be a good place to tack on # .from_cql_literal() and .as_cql_literal() classmethods (or whatever). +import ast +from binascii import unhexlify import calendar +from collections import namedtuple from decimal import Decimal +import io +from itertools import chain +import logging import re import socket import time -from datetime import datetime +import struct +import sys from uuid import UUID -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO # NOQA - -from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack, +from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, + uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, - varint_pack, varint_unpack) + varint_pack, varint_unpack, point_be, point_le, + vints_pack, vints_unpack, uvint_unpack, uvint_pack) +from cassandra import util + +_little_endian_flag = 1 # we always serialize LE +import ipaddress apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' -_number_types = frozenset((int, long, float)) +cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType' +cql_empty_type = 'empty' + +log = logging.getLogger(__name__) + +_number_types = frozenset((int, float)) -from blist import sortedset -try: - from collections import OrderedDict -except ImportError: # Python <2.7 - from cassandra.util import OrderedDict # NOQA +def _name_from_hex_string(encoded_name): + bin_str = unhexlify(encoded_name) + return bin_str.decode('ascii') + def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s -def unix_time_from_uuid1(u): - return (u.get_time() - 0x01B21DD213814000) / 10000000.0 _casstypes = {} +_cqltypes = {} + + +cql_type_scanner = re.Scanner(( + ('frozen', None), + (r'[a-zA-Z0-9_]+', lambda s, t: t), + (r'[\s,<>]', None), +)) + + +def cql_types_from_string(cql_type): + return cql_type_scanner.scan(cql_type)[0] + class CassandraTypeType(type): """ @@ -68,6 +107,8 @@ def __new__(metacls, name, bases, dct): cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls + if not cls.typename.startswith(apache_cassandra_type_prefix): + _cqltypes[cls.typename] = cls return cls @@ -78,6 +119,73 @@ def __new__(metacls, name, bases, dct): )) +def cqltype_to_python(cql_string): + """ + Given a cql type string, creates a list that can be manipulated in python + Example: + int -> ['int'] + frozen> -> ['frozen', ['tuple', ['text', 'int']]] + """ + scanner = re.Scanner(( + (r'[a-zA-Z0-9_]+', lambda s, t: "'{}'".format(t)), + (r'<', lambda s, t: ', ['), + (r'>', lambda s, t: ']'), + (r'[, ]', lambda s, t: t), + (r'".*?"', lambda s, t: "'{}'".format(t)), + )) + + scanned_tokens = scanner.scan(cql_string)[0] + hierarchy = ast.literal_eval(''.join(scanned_tokens)) + return [hierarchy] if isinstance(hierarchy, str) else list(hierarchy) + + +def python_to_cqltype(types): + """ + Opposite of the `cql_to_python` function. Given a python list, creates a cql type string from the representation + Example: + ['int'] -> int + ['frozen', ['tuple', ['text', 'int']]] -> frozen> + """ + scanner = re.Scanner(( + (r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]), + (r'^\[', lambda s, t: None), + (r'\]$', lambda s, t: None), + (r',\s*\[', lambda s, t: '<'), + (r'\]', lambda s, t: '>'), + (r'[, ]', lambda s, t: t), + (r'\'".*?"\'', lambda s, t: t[1:-1]), + )) + + scanned_tokens = scanner.scan(repr(types))[0] + cql = ''.join(scanned_tokens).replace('\\\\', '\\') + return cql + + +def _strip_frozen_from_python(types): + """ + Given a python list representing a cql type, removes 'frozen' + Example: + ['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']] + """ + while 'frozen' in types: + index = types.index('frozen') + types = types[:index] + types[index + 1] + types[index + 2:] + new_types = [_strip_frozen_from_python(item) if isinstance(item, list) else item for item in types] + return new_types + + +def strip_frozen(cql): + """ + Given a cql type string, and removes frozen + Example: + frozen> -> tuple + """ + types = cqltype_to_python(cql) + types_without_frozen = _strip_frozen_from_python(types) + cql = python_to_cqltype(types_without_frozen) + return cql + + def lookup_casstype_simple(casstype): """ Given a Cassandra type name (either fully distinguished or not), hand @@ -100,25 +208,33 @@ def parse_casstype_args(typestring): tokens, remainder = casstype_scanner.scan(typestring) if remainder: raise ValueError("weird characters %r at end" % remainder) - args = [[]] + + # use a stack of (types, names) lists + args = [([], [])] for tok in tokens: if tok == '(': - args.append([]) + args.append(([], [])) elif tok == ')': - arglist = args.pop() - ctype = args[-1].pop() - paramized = ctype.apply_parameters(*arglist) - args[-1].append(paramized) + types, names = args.pop() + prev_types, prev_names = args[-1] + prev_types[-1] = prev_types[-1].apply_parameters(types, names) else: - if ':' in tok: - # ignore those column name hex encoding bit; we have the - # proper column name from elsewhere - tok = tok.rsplit(':', 1)[-1] - ctype = lookup_casstype_simple(tok) - args[-1].append(ctype) + types, names = args[-1] + parts = re.split(':|=>', tok) + tok = parts.pop() + if parts: + names.append(parts[0]) + else: + names.append(None) - return args[0][0] + try: + ctype = int(tok) + except ValueError: + ctype = lookup_casstype_simple(tok) + types.append(ctype) + # return the first (outer) type, which will have all parameters applied + return args[0][0][0] def lookup_casstype(casstype): """ @@ -129,7 +245,7 @@ def lookup_casstype(casstype): Example: >>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)') - + """ if isinstance(casstype, (CassandraType, CassandraTypeType)): @@ -140,50 +256,65 @@ def lookup_casstype(casstype): raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) -class _CassandraType(object): - __metaclass__ = CassandraTypeType - subtypes = () - num_subtypes = 0 - empty_binary_ok = False +def is_reversed_casstype(data_type): + return issubclass(data_type, ReversedType) - def __init__(self, val): - self.val = self.validate(val) + +class EmptyValue(object): + """ See _CassandraType.support_empty_values """ def __str__(self): - return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) + return "EMPTY" __repr__ = __str__ - @staticmethod - def validate(val): - """ - Called to transform an input value into one of a suitable type - for this class. As an example, the BooleanType class uses this - to convert an incoming value to True or False. - """ - return val +EMPTY = EmptyValue() + + +class _CassandraType(object, metaclass=CassandraTypeType): + subtypes = () + num_subtypes = 0 + empty_binary_ok = False + + support_empty_values = False + """ + Back in the Thrift days, empty strings were used for "null" values of + all types, including non-string types. For most users, an empty + string value in an int column is the same as being null/not present, + so the driver normally returns None in this case. (For string-like + types, it *will* return an empty string by default instead of None.) + + To avoid this behavior, set this to :const:`True`. Instead of returning + None for empty string values, the EMPTY singleton (an instance + of EmptyValue) will be returned. + """ + + def __repr__(self): + return '<%s>' % (self.cql_parameterized_type()) @classmethod - def from_binary(cls, byts): + def from_binary(cls, byts, protocol_version): """ Deserialize a bytestring into a value. See the deserialize() method for more information. This method differs in that if None or the empty string is passed in, None may be returned. """ - if byts is None or (byts == '' and not cls.empty_binary_ok): + if byts is None: return None - return cls.deserialize(byts) + elif len(byts) == 0 and not cls.empty_binary_ok: + return EMPTY if cls.support_empty_values else None + return cls.deserialize(byts, protocol_version) @classmethod - def to_binary(cls, val): + def to_binary(cls, val, protocol_version): """ Serialize a value into a bytestring. See the serialize() method for more information. This method differs in that if None is passed in, the result is the empty string. """ - return '' if val is None else cls.serialize(val) + return b'' if val is None else cls.serialize(val, protocol_version) @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): """ Given a bytestring, deserialize into a value according to the protocol for this type. Note that this does not create a new instance of this @@ -193,7 +324,7 @@ def deserialize(byts): return byts @staticmethod - def serialize(val): + def serialize(val, protocol_version): """ Given a value appropriate for this class, serialize it according to the protocol for this type and return the corresponding bytestring. @@ -227,19 +358,22 @@ def cass_parameterized_type_with(cls, subtypes, full=False): return '%s(%s)' % (cname, sublist) @classmethod - def apply_parameters(cls, *subtypes): + def apply_parameters(cls, subtypes, names=None): """ Given a set of other CassandraTypes, create a new subtype of this type using them as parameters. This is how composite types are constructed. - >>> MapType.apply_parameters(DateType, BooleanType) - + >>> MapType.apply_parameters([DateType, BooleanType]) + + + `subtypes` will be a sequence of CassandraTypes. If provided, `names` + will be an equally long sequence of column names or Nones. """ if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) - newname = cls.cass_parameterized_type_with(subtypes).encode('utf8') - return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname}) + newname = cls.cass_parameterized_type_with(subtypes) + return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) @classmethod def cql_parameterized_type(cls): @@ -259,6 +393,9 @@ def cass_parameterized_type(cls, full=False): """ return cls.cass_parameterized_type_with(cls.subtypes, full=full) + @classmethod + def serial_size(cls): + return None # it's initially named with a _ to avoid registering it as a real type, but # client programs may want to use the name still for isinstance(), etc @@ -270,7 +407,7 @@ class _UnrecognizedType(_CassandraType): def mkUnrecognizedType(casstypename): - return CassandraTypeType(casstypename.encode('utf8'), + return CassandraTypeType(casstypename, (_UnrecognizedType,), {'typename': "'%s'" % casstypename}) @@ -280,30 +417,28 @@ class BytesType(_CassandraType): empty_binary_ok = True @staticmethod - def validate(val): - return buffer(val) - - @staticmethod - def serialize(val): - return str(val) + def serialize(val, protocol_version): + return bytes(val) class DecimalType(_CassandraType): typename = 'decimal' @staticmethod - def validate(val): - return Decimal(val) - - @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) return Decimal('%de%d' % (unscaled, -scale)) @staticmethod - def serialize(dec): - sign, digits, exponent = dec.as_tuple() + def serialize(dec, protocol_version): + try: + sign, digits, exponent = dec.as_tuple() + except AttributeError: + try: + sign, digits, exponent = Decimal(dec).as_tuple() + except Exception: + raise TypeError("Invalid type for Decimal value: %r", dec) unscaled = int(''.join([str(digit) for digit in digits])) if sign: unscaled *= -1 @@ -316,109 +451,166 @@ class UUIDType(_CassandraType): typename = 'uuid' @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod - def serialize(uuid): - return uuid.bytes + def serialize(uuid, protocol_version): + try: + return uuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") + @classmethod + def serial_size(cls): + return 16 class BooleanType(_CassandraType): typename = 'boolean' @staticmethod - def validate(val): - return bool(val) + def deserialize(byts, protocol_version): + return bool(int8_unpack(byts)) + + @staticmethod + def serialize(truth, protocol_version): + return int8_pack(truth) + + @classmethod + def serial_size(cls): + return 1 + +class ByteType(_CassandraType): + typename = 'tinyint' @staticmethod - def deserialize(byts): - return bool(int8_unpack(byts)) + def deserialize(byts, protocol_version): + return int8_unpack(byts) @staticmethod - def serialize(truth): - return int8_pack(bool(truth)) + def serialize(byts, protocol_version): + return int8_pack(byts) class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True + @staticmethod + def deserialize(byts, protocol_version): + return byts.decode('ascii') + + @staticmethod + def serialize(var, protocol_version): + try: + return var.encode('ascii') + except UnicodeDecodeError: + return var + class FloatType(_CassandraType): typename = 'float' - deserialize = staticmethod(float_unpack) - serialize = staticmethod(float_pack) + @staticmethod + def deserialize(byts, protocol_version): + return float_unpack(byts) + @staticmethod + def serialize(byts, protocol_version): + return float_pack(byts) + + @classmethod + def serial_size(cls): + return 4 class DoubleType(_CassandraType): typename = 'double' - deserialize = staticmethod(double_unpack) - serialize = staticmethod(double_pack) + @staticmethod + def deserialize(byts, protocol_version): + return double_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return double_pack(byts) + @classmethod + def serial_size(cls): + return 8 class LongType(_CassandraType): typename = 'bigint' - deserialize = staticmethod(int64_unpack) - serialize = staticmethod(int64_pack) + @staticmethod + def deserialize(byts, protocol_version): + return int64_unpack(byts) + @staticmethod + def serialize(byts, protocol_version): + return int64_pack(byts) + + @classmethod + def serial_size(cls): + return 8 class Int32Type(_CassandraType): typename = 'int' - deserialize = staticmethod(int32_unpack) - serialize = staticmethod(int32_pack) + @staticmethod + def deserialize(byts, protocol_version): + return int32_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return int32_pack(byts) + @classmethod + def serial_size(cls): + return 4 class IntegerType(_CassandraType): typename = 'varint' - deserialize = staticmethod(varint_unpack) - serialize = staticmethod(varint_pack) + @staticmethod + def deserialize(byts, protocol_version): + return varint_unpack(byts) + @staticmethod + def serialize(byts, protocol_version): + return varint_pack(byts) -have_ipv6_packing = hasattr(socket, 'inet_ntop') class InetAddressType(_CassandraType): typename = 'inet' - # TODO: implement basic ipv6 support for Windows? - # inet_ntop and inet_pton aren't available on Windows - @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): if len(byts) == 16: - if not have_ipv6_packing: - raise Exception( - "IPv6 addresses cannot currently be handled on Windows") - return socket.inet_ntop(socket.AF_INET6, byts) + return util.inet_ntop(socket.AF_INET6, byts) else: + # util.inet_pton could also handle, but this is faster + # since we've already determined the AF return socket.inet_ntoa(byts) @staticmethod - def serialize(addr): - if ':' in addr: - fam = socket.AF_INET6 - if not have_ipv6_packing: - raise Exception( - "IPv6 addresses cannot currently be handled on Windows") - return socket.inet_pton(fam, addr) - else: - fam = socket.AF_INET - return socket.inet_aton(addr) - - -class CounterColumnType(_CassandraType): + def serialize(addr, protocol_version): + try: + if ':' in addr: + return util.inet_pton(socket.AF_INET6, addr) + else: + # util.inet_pton could also handle, but this is faster + # since we've already determined the AF + return socket.inet_aton(addr) + except: + if isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)): + return addr.packed + raise ValueError("can't interpret %r as an inet address" % (addr,)) + + +class CounterColumnType(LongType): typename = 'counter' - deserialize = staticmethod(int64_unpack) - serialize = staticmethod(int64_pack) - - -cql_time_formats = ( +cql_timestamp_formats = ( '%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%dT%H:%M', @@ -426,66 +618,156 @@ class CounterColumnType(_CassandraType): '%Y-%m-%d' ) +_have_warned_about_timestamps = False + class DateType(_CassandraType): typename = 'timestamp' - @classmethod - def validate(cls, date): - if isinstance(date, basestring): - date = cls.interpret_datestring(date) - return date - @staticmethod - def interpret_datestring(date): - if date[-5] in ('+', '-'): - offset = (int(date[-4:-2]) * 3600 + int(date[-2:]) * 60) * int(date[-5] + '1') - date = date[:-5] + def interpret_datestring(val): + if val[-5] in ('+', '-'): + offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') + val = val[:-5] else: offset = -time.timezone - for tformat in cql_time_formats: + for tformat in cql_timestamp_formats: try: - tval = time.strptime(date, tformat) + tval = time.strptime(val, tformat) except ValueError: continue - return calendar.timegm(tval) + offset + # scale seconds to millis for the raw value + return (calendar.timegm(tval) + offset) * 1e3 else: - raise ValueError("can't interpret %r as a date" % (date,)) - - def my_timestamp(self): - return self.val + raise ValueError("can't interpret %r as a date" % (val,)) @staticmethod - def deserialize(byts): - return datetime.utcfromtimestamp(int64_unpack(byts) / 1000.0) + def deserialize(byts, protocol_version): + timestamp = int64_unpack(byts) / 1000.0 + return util.datetime_from_timestamp(timestamp) @staticmethod - def serialize(v): + def serialize(v, protocol_version): try: - converted = calendar.timegm(v.utctimetuple()) - converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3 + # v is datetime + timestamp_seconds = calendar.timegm(v.utctimetuple()) + timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 except AttributeError: - # Ints and floats are valid timestamps too - if type(v) not in _number_types: - raise TypeError('DateType arguments must be a datetime or timestamp') + try: + timestamp = calendar.timegm(v.timetuple()) * 1e3 + except AttributeError: + # Ints and floats are valid timestamps too + if type(v) not in _number_types: + raise TypeError('DateType arguments must be a datetime, date, or timestamp') + timestamp = v + + return int64_pack(int(timestamp)) + + @classmethod + def serial_size(cls): + return 8 - converted = v * 1e3 +class TimestampType(DateType): + pass - return int64_pack(long(converted)) class TimeUUIDType(DateType): typename = 'timeuuid' def my_timestamp(self): - return unix_time_from_uuid1(self.val) + return util.unix_time_from_uuid1(self.val) @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod - def serialize(timeuuid): - return timeuuid.bytes + def serialize(timeuuid, protocol_version): + try: + return timeuuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") + + @classmethod + def serial_size(cls): + return 16 + +class SimpleDateType(_CassandraType): + typename = 'date' + date_format = "%Y-%m-%d" + + # Values of the 'date'` type are encoded as 32-bit unsigned integers + # representing a number of days with epoch (January 1st, 1970) at the center of the + # range (2^31). + EPOCH_OFFSET_DAYS = 2 ** 31 + + @staticmethod + def deserialize(byts, protocol_version): + days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS + return util.Date(days) + + @staticmethod + def serialize(val, protocol_version): + try: + days = val.days_from_epoch + except AttributeError: + if isinstance(val, int): + # the DB wants offset int values, but util.Date init takes days from epoch + # here we assume int values are offset, as they would appear in CQL + # short circuit to avoid subtracting just to add offset + return uint32_pack(val) + days = util.Date(val).days_from_epoch + return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS) + + +class ShortType(_CassandraType): + typename = 'smallint' + + @staticmethod + def deserialize(byts, protocol_version): + return int16_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return int16_pack(byts) + +class TimeType(_CassandraType): + typename = 'time' + # Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as + # variable size... and we have to match what the server expects since the server + # uses that specification to encode data of that type. + #@classmethod + #def serial_size(cls): + # return 8 + + @staticmethod + def deserialize(byts, protocol_version): + return util.Time(int64_unpack(byts)) + + @staticmethod + def serialize(val, protocol_version): + try: + nano = val.nanosecond_time + except AttributeError: + nano = util.Time(val).nanosecond_time + return int64_pack(nano) + + +class DurationType(_CassandraType): + typename = 'duration' + + @staticmethod + def deserialize(byts, protocol_version): + months, days, nanoseconds = vints_unpack(byts) + return util.Duration(months, days, nanoseconds) + + @staticmethod + def serialize(duration, protocol_version): + try: + m, d, n = duration.months, duration.days, duration.nanoseconds + except AttributeError: + raise TypeError('DurationType arguments must be a Duration.') + return vints_pack([m, d, n]) class UTF8Type(_CassandraType): @@ -493,63 +775,78 @@ class UTF8Type(_CassandraType): empty_binary_ok = True @staticmethod - def deserialize(byts): + def deserialize(byts, protocol_version): return byts.decode('utf8') @staticmethod - def serialize(ustr): - return ustr.encode('utf8') + def serialize(ustr, protocol_version): + try: + return ustr.encode('utf-8') + except UnicodeDecodeError: + # already utf-8 + return ustr + + +class VarcharType(UTF8Type): + typename = 'varchar' class _ParameterizedType(_CassandraType): - def __init__(self, val): - if not self.subtypes: - raise ValueError("%s type with no parameters can't be instantiated" % (self.typename,)) - _CassandraType.__init__(self, val) + num_subtypes = 'UNKNOWN' @classmethod - def deserialize(cls, byts): + def deserialize(cls, byts, protocol_version): if not cls.subtypes: raise NotImplementedError("can't deserialize unparameterized %s" % cls.typename) - return cls.deserialize_safe(byts) + return cls.deserialize_safe(byts, protocol_version) @classmethod - def serialize(cls, val): + def serialize(cls, val, protocol_version): if not cls.subtypes: raise NotImplementedError("can't serialize unparameterized %s" % cls.typename) - return cls.serialize_safe(val) + return cls.serialize_safe(val, protocol_version) class _SimpleParameterizedType(_ParameterizedType): @classmethod - def validate(cls, val): + def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes - return cls.adapter([subtype.validate(subval) for subval in val]) - - @classmethod - def deserialize_safe(cls, byts): - subtype, = cls.subtypes - numelements = uint16_unpack(byts[:2]) - p = 2 + if protocol_version >= 3: + unpack = int32_unpack + length = 4 + else: + unpack = uint16_unpack + length = 2 + numelements = unpack(byts[:length]) + p = length result = [] - for n in xrange(numelements): - itemlen = uint16_unpack(byts[p:p + 2]) - p += 2 - item = byts[p:p + itemlen] - p += itemlen - result.append(subtype.from_binary(item)) + inner_proto = max(3, protocol_version) + for _ in range(numelements): + itemlen = unpack(byts[p:p + length]) + p += length + if itemlen < 0: + result.append(None) + else: + item = byts[p:p + itemlen] + p += itemlen + result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @classmethod - def serialize_safe(cls, items): + def serialize_safe(cls, items, protocol_version): + if isinstance(items, str): + raise TypeError("Received a string for a type that expects a sequence") + subtype, = cls.subtypes - buf = StringIO() - buf.write(uint16_pack(len(items))) + pack = int32_pack if protocol_version >= 3 else uint16_pack + buf = io.BytesIO() + buf.write(pack(len(items))) + inner_proto = max(3, protocol_version) for item in items: - itembytes = subtype.to_binary(item) - buf.write(uint16_pack(len(itembytes))) + itembytes = subtype.to_binary(item, inner_proto) + buf.write(pack(len(itembytes))) buf.write(itembytes) return buf.getvalue() @@ -557,13 +854,13 @@ def serialize_safe(cls, items): class ListType(_SimpleParameterizedType): typename = 'list' num_subtypes = 1 - adapter = tuple + adapter = list class SetType(_SimpleParameterizedType): typename = 'set' num_subtypes = 1 - adapter = sortedset + adapter = util.sortedset class MapType(_ParameterizedType): @@ -571,53 +868,251 @@ class MapType(_ParameterizedType): num_subtypes = 2 @classmethod - def validate(cls, val): - subkeytype, subvaltype = cls.subtypes - return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in val.iteritems()) - - @classmethod - def deserialize_safe(cls, byts): - subkeytype, subvaltype = cls.subtypes - numelements = uint16_unpack(byts[:2]) - p = 2 - themap = OrderedDict() - for n in xrange(numelements): - key_len = uint16_unpack(byts[p:p + 2]) - p += 2 - keybytes = byts[p:p + key_len] - p += key_len - val_len = uint16_unpack(byts[p:p + 2]) - p += 2 - valbytes = byts[p:p + val_len] - p += val_len - key = subkeytype.from_binary(keybytes) - val = subvaltype.from_binary(valbytes) - themap[key] = val + def deserialize_safe(cls, byts, protocol_version): + key_type, value_type = cls.subtypes + if protocol_version >= 3: + unpack = int32_unpack + length = 4 + else: + unpack = uint16_unpack + length = 2 + numelements = unpack(byts[:length]) + p = length + themap = util.OrderedMapSerializedKey(key_type, protocol_version) + inner_proto = max(3, protocol_version) + for _ in range(numelements): + key_len = unpack(byts[p:p + length]) + p += length + if key_len < 0: + keybytes = None + key = None + else: + keybytes = byts[p:p + key_len] + p += key_len + key = key_type.from_binary(keybytes, inner_proto) + + val_len = unpack(byts[p:p + length]) + p += length + if val_len < 0: + val = None + else: + valbytes = byts[p:p + val_len] + p += val_len + val = value_type.from_binary(valbytes, inner_proto) + + themap._insert_unchecked(key, keybytes, val) return themap @classmethod - def serialize_safe(cls, themap): - subkeytype, subvaltype = cls.subtypes - buf = StringIO() - buf.write(uint16_pack(len(themap))) - for key, val in themap.iteritems(): - keybytes = subkeytype.to_binary(key) - valbytes = subvaltype.to_binary(val) - buf.write(uint16_pack(len(keybytes))) + def serialize_safe(cls, themap, protocol_version): + key_type, value_type = cls.subtypes + pack = int32_pack if protocol_version >= 3 else uint16_pack + buf = io.BytesIO() + buf.write(pack(len(themap))) + try: + items = themap.items() + except AttributeError: + raise TypeError("Got a non-map object for a map value") + inner_proto = max(3, protocol_version) + for key, val in items: + keybytes = key_type.to_binary(key, inner_proto) + valbytes = value_type.to_binary(val, inner_proto) + buf.write(pack(len(keybytes))) buf.write(keybytes) - buf.write(uint16_pack(len(valbytes))) + buf.write(pack(len(valbytes))) buf.write(valbytes) return buf.getvalue() +class TupleType(_ParameterizedType): + typename = 'tuple' + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + proto_version = max(3, protocol_version) + p = 0 + values = [] + for col_type in cls.subtypes: + if p == len(byts): + break + itemlen = int32_unpack(byts[p:p + 4]) + p += 4 + if itemlen >= 0: + item = byts[p:p + itemlen] + p += itemlen + else: + item = None + # collections inside UDTs are always encoded with at least the + # version 3 format + values.append(col_type.from_binary(item, proto_version)) + + if len(values) < len(cls.subtypes): + nones = [None] * (len(cls.subtypes) - len(values)) + values = values + nones + + return tuple(values) + + @classmethod + def serialize_safe(cls, val, protocol_version): + if len(val) > len(cls.subtypes): + raise ValueError("Expected %d items in a tuple, but got %d: %s" % + (len(cls.subtypes), len(val), val)) + + proto_version = max(3, protocol_version) + buf = io.BytesIO() + for item, subtype in zip(val, cls.subtypes): + if item is not None: + packed_item = subtype.to_binary(item, proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) + return buf.getvalue() + + @classmethod + def cql_parameterized_type(cls): + subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) + return 'frozen>' % (subtypes_string,) + + +class UserType(TupleType): + typename = "org.apache.cassandra.db.marshal.UserType" + + _cache = {} + _module = sys.modules[__name__] + + @classmethod + def make_udt_class(cls, keyspace, udt_name, field_names, field_types): + assert len(field_names) == len(field_types) + + instance = cls._cache.get((keyspace, udt_name)) + if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: + instance = type(udt_name, (cls,), {'subtypes': field_types, + 'cassname': cls.cassname, + 'typename': udt_name, + 'fieldnames': field_names, + 'keyspace': keyspace, + 'mapped_class': None, + 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) + cls._cache[(keyspace, udt_name)] = instance + return instance + + @classmethod + def evict_udt_class(cls, keyspace, udt_name): + try: + del cls._cache[(keyspace, udt_name)] + except KeyError: + pass + + @classmethod + def apply_parameters(cls, subtypes, names): + keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back + udt_name = _name_from_hex_string(subtypes[1].cassname) + field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) + return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:])) + + @classmethod + def cql_parameterized_type(cls): + return "frozen<%s>" % (cls.typename,) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + values = super(UserType, cls).deserialize_safe(byts, protocol_version) + if cls.mapped_class: + return cls.mapped_class(**dict(zip(cls.fieldnames, values))) + elif cls.tuple_type: + return cls.tuple_type(*values) + else: + return tuple(values) + + @classmethod + def serialize_safe(cls, val, protocol_version): + proto_version = max(3, protocol_version) + buf = io.BytesIO() + for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)): + # first treat as a tuple, else by custom type + try: + item = val[i] + except TypeError: + item = getattr(val, fieldname, None) + if item is None and not hasattr(val, fieldname): + log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") + + if item is not None: + packed_item = subtype.to_binary(item, proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) + return buf.getvalue() + + @classmethod + def _make_registered_udt_namedtuple(cls, keyspace, name, field_names): + # this is required to make the type resolvable via this module... + # required when unregistered udts are pickled for use as keys in + # util.OrderedMap + t = cls._make_udt_tuple_type(name, field_names) + if t: + qualified_name = "%s_%s" % (keyspace, name) + setattr(cls._module, qualified_name, t) + return t + + @classmethod + def _make_udt_tuple_type(cls, name, field_names): + # fallback to positional named, then unnamed tuples + # for CQL identifiers that aren't valid in Python, + try: + t = namedtuple(name, field_names) + except ValueError: + try: + t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) + log.warning("could not create a namedtuple for '%s' because one or more " + "field names are not valid Python identifiers (%s); " + "returning positionally-named fields" % (name, field_names)) + except ValueError: + t = None + log.warning("could not create a namedtuple for '%s' because the name is " + "not a valid Python identifier; will return tuples in " + "its place" % (name,)) + return t + + class CompositeType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.CompositeType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.CompositeType" + + @classmethod + def cql_parameterized_type(cls): + """ + There is no CQL notation for Composites, so we override this. + """ + typestring = cls.cass_parameterized_type(full=True) + return "'%s'" % (typestring,) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + result = [] + for subtype in cls.subtypes: + if not byts: + # CompositeType can have missing elements at the end + break + + element_length = uint16_unpack(byts[:2]) + element = byts[2:2 + element_length] + + # skip element length, element, and the EOC (one byte) + byts = byts[2 + element_length + 1:] + result.append(subtype.from_binary(element, protocol_version)) + + return tuple(result) class DynamicCompositeType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.DynamicCompositeType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" + + @classmethod + def cql_parameterized_type(cls): + sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) + return "'%s(%s)'" % (cls.typename, sublist) class ColumnToCollectionType(_ParameterizedType): @@ -626,27 +1121,41 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. """ - typename = "'org.apache.cassandra.db.marshal.ColumnToCollectionType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" class ReversedType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.ReversedType'" + typename = "org.apache.cassandra.db.marshal.ReversedType" + num_subtypes = 1 + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + subtype, = cls.subtypes + return subtype.from_binary(byts, protocol_version) + + @classmethod + def serialize_safe(cls, val, protocol_version): + subtype, = cls.subtypes + return subtype.to_binary(val, protocol_version) + + +class FrozenType(_ParameterizedType): + typename = "frozen" num_subtypes = 1 @classmethod - def deserialize_safe(cls, byts): + def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes - return subtype.from_binary(byts) + return subtype.from_binary(byts, protocol_version) @classmethod - def serialize_safe(cls, val): + def serialize_safe(cls, val, protocol_version): subtype, = cls.subtypes - return subtype.to_binary(val) + return subtype.to_binary(val, protocol_version) def is_counter_type(t): - if isinstance(t, basestring): + if isinstance(t, str): t = lookup_casstype(t) return issubclass(t, CounterColumnType) @@ -663,3 +1172,327 @@ def cql_typename(casstypename): 'list' """ return lookup_casstype(casstypename).cql_parameterized_type() + + +class WKBGeometryType(object): + POINT = 1 + LINESTRING = 2 + POLYGON = 3 + + +class PointType(CassandraType): + typename = 'PointType' + + _type = struct.pack('[[]] + type_ = int8_unpack(byts[0:1]) + + if type_ in (BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), + BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN)): + time0 = precision0 = None + else: + time0 = int64_unpack(byts[1:9]) + precision0 = int8_unpack(byts[9:10]) + + if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): + time1 = int64_unpack(byts[10:18]) + precision1 = int8_unpack(byts[18:19]) + else: + time1 = precision1 = None + + if time0 is not None: + date_range_bound0 = util.DateRangeBound( + time0, + cls._decode_precision(precision0) + ) + if time1 is not None: + date_range_bound1 = util.DateRangeBound( + time1, + cls._decode_precision(precision1) + ) + + if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE): + return util.DateRange(value=date_range_bound0) + if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): + return util.DateRange(lower_bound=date_range_bound0, + upper_bound=date_range_bound1) + if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_HIGH): + return util.DateRange(lower_bound=date_range_bound0, + upper_bound=util.OPEN_BOUND) + if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_LOW): + return util.DateRange(lower_bound=util.OPEN_BOUND, + upper_bound=date_range_bound0) + if type_ == BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE): + return util.DateRange(lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND) + if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN): + return util.DateRange(value=util.OPEN_BOUND) + raise ValueError('Could not deserialize %r' % (byts,)) + + @classmethod + def serialize(cls, v, protocol_version): + buf = io.BytesIO() + bound_kind, bounds = None, () + + try: + value = v.value + except AttributeError: + raise ValueError( + '%s.serialize expects an object with a value attribute; got' + '%r' % (cls.__name__, v) + ) + + if value is None: + try: + lower_bound, upper_bound = v.lower_bound, v.upper_bound + except AttributeError: + raise ValueError( + '%s.serialize expects an object with lower_bound and ' + 'upper_bound attributes; got %r' % (cls.__name__, v) + ) + if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND: + bound_kind = BoundKind.BOTH_OPEN_RANGE + elif lower_bound == util.OPEN_BOUND: + bound_kind = BoundKind.OPEN_RANGE_LOW + bounds = (upper_bound,) + elif upper_bound == util.OPEN_BOUND: + bound_kind = BoundKind.OPEN_RANGE_HIGH + bounds = (lower_bound,) + else: + bound_kind = BoundKind.CLOSED_RANGE + bounds = lower_bound, upper_bound + else: # value is not None + if value == util.OPEN_BOUND: + bound_kind = BoundKind.SINGLE_DATE_OPEN + else: + bound_kind = BoundKind.SINGLE_DATE + bounds = (value,) + + if bound_kind is None: + raise ValueError( + 'Cannot serialize %r; could not find bound kind' % (v,) + ) + + buf.write(int8_pack(BoundKind.to_int(bound_kind))) + for bound in bounds: + buf.write(int64_pack(bound.milliseconds)) + buf.write(int8_pack(cls._encode_precision(bound.precision))) + + return buf.getvalue() + +class VectorType(_CassandraType): + typename = 'org.apache.cassandra.db.marshal.VectorType' + vector_size = 0 + subtype = None + + @classmethod + def serial_size(cls): + serialized_size = cls.subtype.serial_size() + return cls.vector_size * serialized_size if serialized_size is not None else None + + @classmethod + def apply_parameters(cls, params, names): + assert len(params) == 2 + subtype = lookup_casstype(params[0]) + vsize = params[1] + return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) + + @classmethod + def deserialize(cls, byts, protocol_version): + serialized_size = cls.subtype.serial_size() + if serialized_size is not None: + expected_byte_size = serialized_size * cls.vector_size + if len(byts) != expected_byte_size: + raise ValueError( + "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ + .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) + indexes = (serialized_size * x for x in range(0, cls.vector_size)) + return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + + idx = 0 + rv = [] + while (len(rv) < cls.vector_size): + try: + size, bytes_read = uvint_unpack(byts[idx:]) + idx += bytes_read + rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + idx += size + except: + raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ + .format(len(rv))) + + # If we have any additional data in the serialized vector treat that as an error as well + if idx < len(byts): + raise ValueError("Additional bytes remaining after vector deserialization completed") + return rv + + @classmethod + def serialize(cls, v, protocol_version): + v_length = len(v) + if cls.vector_size != v_length: + raise ValueError( + "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\ + .format(cls.vector_size, cls.subtype.typename, v_length)) + + serialized_size = cls.subtype.serial_size() + buf = io.BytesIO() + for item in v: + item_bytes = cls.subtype.serialize(item, protocol_version) + if serialized_size is None: + buf.write(uvint_pack(len(item_bytes))) + buf.write(item_bytes) + return buf.getvalue() + + @classmethod + def cql_parameterized_type(cls): + return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size) diff --git a/cassandra/cython_deps.py b/cassandra/cython_deps.py new file mode 100644 index 0000000000..5cc86fe706 --- /dev/null +++ b/cassandra/cython_deps.py @@ -0,0 +1,11 @@ +try: + from cassandra.row_parser import make_recv_results_rows + HAVE_CYTHON = True +except ImportError: + HAVE_CYTHON = False + +try: + import numpy + HAVE_NUMPY = True +except ImportError: + HAVE_NUMPY = False diff --git a/cassandra/cython_marshal.pyx b/cassandra/cython_marshal.pyx new file mode 100644 index 0000000000..4733a47935 --- /dev/null +++ b/cassandra/cython_marshal.pyx @@ -0,0 +1,72 @@ +# -- cython: profile=True +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t) +from libc.string cimport memcpy +from cassandra.buffer cimport Buffer, buf_read, to_bytes + +cdef bint is_little_endian +from cassandra.util import is_little_endian + +ctypedef fused num_t: + int64_t + int32_t + int16_t + int8_t + uint64_t + uint32_t + uint16_t + uint8_t + double + float + +cdef inline num_t unpack_num(Buffer *buf, num_t *dummy=NULL): # dummy pointer because cython wants the fused type as an arg + """ + Copy to aligned destination, conditionally swapping to native byte order + """ + cdef Py_ssize_t start, end, i + cdef char *src = buf_read(buf, sizeof(num_t)) + cdef num_t ret = 0 + cdef char *out = &ret + + if is_little_endian: + for i in range(sizeof(num_t)): + out[sizeof(num_t) - i - 1] = src[i] + else: + memcpy(out, src, sizeof(num_t)) + + return ret + +cdef varint_unpack(Buffer *term): + """Unpack a variable-sized integer""" + return varint_unpack_py3(to_bytes(term)) + +# TODO: Optimize these two functions +cdef varint_unpack_py3(bytes term): + val = int(''.join(["%02x" % i for i in term]), 16) + if (term[0] & 128) != 0: + shift = len(term) * 8 # * Note below + val -= 1 << shift + return val + +# * Note * +# '1 << (len(term) * 8)' Cython tries to do native +# integer shifts, which overflows. We need this to +# emulate Python shifting, which will expand the long +# to accommodate diff --git a/cassandra/cython_utils.pxd b/cassandra/cython_utils.pxd new file mode 100644 index 0000000000..4a1e71dba5 --- /dev/null +++ b/cassandra/cython_utils.pxd @@ -0,0 +1,2 @@ +from libc.stdint cimport int64_t +cdef datetime_from_timestamp(double timestamp) diff --git a/cassandra/cython_utils.pyx b/cassandra/cython_utils.pyx new file mode 100644 index 0000000000..1b6a136c69 --- /dev/null +++ b/cassandra/cython_utils.pyx @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Duplicate module of util.py, with some accelerated functions +used for deserialization. +""" + +from libc.math cimport modf, round, fabs + +from cpython.datetime cimport ( + timedelta_new, + # cdef inline object timedelta_new(int days, int seconds, int useconds) + # Create timedelta object using DateTime CAPI factory function. + # Note, there are no range checks for any of the arguments. + import_datetime, + # Datetime C API initialization function. + # You have to call it before any usage of DateTime CAPI functions. + ) + +import datetime +import sys + +cdef bint is_little_endian +from cassandra.util import is_little_endian + +import_datetime() + +DEF DAY_IN_SECONDS = 86400 + +DATETIME_EPOC = datetime.datetime(1970, 1, 1) + + +cdef datetime_from_timestamp(double timestamp): + cdef int days = (timestamp / DAY_IN_SECONDS) + cdef int64_t days_in_seconds = ( days) * DAY_IN_SECONDS + cdef int seconds = (timestamp - days_in_seconds) + cdef double tmp + cdef double micros_left = modf(timestamp, &tmp) * 1000000. + micros_left = modf(micros_left, &tmp) + cdef int microseconds = tmp + + # rounding to emulate fp math in delta_new + cdef int x_odd + tmp = round(micros_left) + if fabs(tmp - micros_left) == 0.5: + x_odd = microseconds & 1 + tmp = 2.0 * round((micros_left + x_odd) * 0.5) - x_odd + microseconds += tmp + + return DATETIME_EPOC + timedelta_new(days, seconds, microseconds) diff --git a/cassandra/datastax/__init__.py b/cassandra/datastax/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/cassandra/datastax/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cassandra/datastax/cloud/__init__.py b/cassandra/datastax/cloud/__init__.py new file mode 100644 index 0000000000..e175b2928b --- /dev/null +++ b/cassandra/datastax/cloud/__init__.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import sys +import tempfile +import shutil +from urllib.request import urlopen + +_HAS_SSL = True +try: + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED +except: + _HAS_SSL = False + +from zipfile import ZipFile + +# 2.7 vs 3.x +try: + from zipfile import BadZipFile +except: + from zipfile import BadZipfile as BadZipFile + +from cassandra import DriverException + +log = logging.getLogger(__name__) + +__all__ = ['get_cloud_config'] + +DATASTAX_CLOUD_PRODUCT_TYPE = "DATASTAX_APOLLO" + + +class CloudConfig(object): + + username = None + password = None + host = None + port = None + keyspace = None + local_dc = None + ssl_context = None + + sni_host = None + sni_port = None + host_ids = None + + @classmethod + def from_dict(cls, d): + c = cls() + + c.port = d.get('port', None) + try: + c.port = int(d['port']) + except: + pass + + c.username = d.get('username', None) + c.password = d.get('password', None) + c.host = d.get('host', None) + c.keyspace = d.get('keyspace', None) + c.local_dc = d.get('localDC', None) + + return c + + +def get_cloud_config(cloud_config, create_pyopenssl_context=False): + if not _HAS_SSL: + raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.") + + if 'secure_connect_bundle' not in cloud_config: + raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.") + + try: + config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context) + except BadZipFile: + raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.") + + config = read_metadata_info(config, cloud_config) + if create_pyopenssl_context: + config.ssl_context = config.pyopenssl_context + return config + + +def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context): + secure_bundle = cloud_config['secure_connect_bundle'] + use_default_tempdir = cloud_config.get('use_default_tempdir', None) + with ZipFile(secure_bundle) as zipfile: + base_dir = tempfile.gettempdir() if use_default_tempdir else os.path.dirname(secure_bundle) + tmp_dir = tempfile.mkdtemp(dir=base_dir) + try: + zipfile.extractall(path=tmp_dir) + return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config, create_pyopenssl_context) + finally: + shutil.rmtree(tmp_dir) + + +def parse_cloud_config(path, cloud_config, create_pyopenssl_context): + with open(path, 'r') as stream: + data = json.load(stream) + + config = CloudConfig.from_dict(data) + config_dir = os.path.dirname(path) + + if 'ssl_context' in cloud_config: + config.ssl_context = cloud_config['ssl_context'] + else: + # Load the ssl_context before we delete the temporary directory + ca_cert_location = os.path.join(config_dir, 'ca.crt') + cert_location = os.path.join(config_dir, 'cert') + key_location = os.path.join(config_dir, 'key') + # Regardless of if we create a pyopenssl context, we still need the builtin one + # to connect to the metadata service + config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location) + if create_pyopenssl_context: + config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location) + + return config + + +def read_metadata_info(config, cloud_config): + url = "https://{}:{}/metadata".format(config.host, config.port) + timeout = cloud_config['connect_timeout'] if 'connect_timeout' in cloud_config else 5 + try: + response = urlopen(url, context=config.ssl_context, timeout=timeout) + except Exception as e: + log.exception(e) + raise DriverException("Unable to connect to the metadata service at %s. " + "Check the cluster status in the cloud console. " % url) + + if response.code != 200: + raise DriverException(("Error while fetching the metadata at: %s. " + "The service returned error code %d." % (url, response.code))) + return parse_metadata_info(config, response.read().decode('utf-8')) + + +def parse_metadata_info(config, http_data): + try: + data = json.loads(http_data) + except: + msg = "Failed to load cluster metadata" + raise DriverException(msg) + + contact_info = data['contact_info'] + config.local_dc = contact_info['local_dc'] + + proxy_info = contact_info['sni_proxy_address'].split(':') + config.sni_host = proxy_info[0] + try: + config.sni_port = int(proxy_info[1]) + except: + config.sni_port = 9042 + + config.host_ids = [host_id for host_id in contact_info['contact_points']] + + return config + + +def _ssl_context_from_cert(ca_cert_location, cert_location, key_location): + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations(ca_cert_location) + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location) + + return ssl_context + + +def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location): + try: + from OpenSSL import SSL + except ImportError as e: + raise ImportError( + "PyOpenSSL must be installed to connect to Astra with the Eventlet or Twisted event loops")\ + .with_traceback(e.__traceback__) + ssl_context = SSL.Context(SSL.TLSv1_METHOD) + ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok) + ssl_context.use_certificate_file(cert_location) + ssl_context.use_privatekey_file(key_location) + ssl_context.load_verify_locations(ca_cert_location) + + return ssl_context \ No newline at end of file diff --git a/cassandra/datastax/graph/__init__.py b/cassandra/datastax/graph/__init__.py new file mode 100644 index 0000000000..8315843a36 --- /dev/null +++ b/cassandra/datastax/graph/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from cassandra.datastax.graph.types import Element, Vertex, VertexProperty, Edge, Path, T +from cassandra.datastax.graph.query import ( + GraphOptions, GraphProtocol, GraphStatement, SimpleGraphStatement, Result, + graph_object_row_factory, single_object_row_factory, + graph_result_row_factory, graph_graphson2_row_factory, + graph_graphson3_row_factory +) +from cassandra.datastax.graph.graphson import * diff --git a/cassandra/datastax/graph/fluent/__init__.py b/cassandra/datastax/graph/fluent/__init__.py new file mode 100644 index 0000000000..0dfd5230e5 --- /dev/null +++ b/cassandra/datastax/graph/fluent/__init__.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import copy + +from concurrent.futures import Future + +HAVE_GREMLIN = False +try: + import gremlin_python + HAVE_GREMLIN = True +except ImportError: + # gremlinpython is not installed. + pass + +if HAVE_GREMLIN: + from gremlin_python.structure.graph import Graph + from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal + from gremlin_python.process.traversal import Traverser, TraversalSideEffects + from gremlin_python.process.graph_traversal import GraphTraversal + + from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphOptions, GraphProtocol + from cassandra.datastax.graph.query import _GraphSONContextRowFactory + + from cassandra.datastax.graph.fluent.serializers import ( + GremlinGraphSONReaderV2, + GremlinGraphSONReaderV3, + dse_graphson2_deserializers, + gremlin_graphson2_deserializers, + dse_graphson3_deserializers, + gremlin_graphson3_deserializers + ) + from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal + + log = logging.getLogger(__name__) + + __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory', + 'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph'] + + # Traversal result keys + _bulk_key = 'bulk' + _result_key = 'result' + + + class BaseGraphRowFactory(_GraphSONContextRowFactory): + """ + Base row factory for graph traversal. This class basically wraps a + graphson reader function to handle additional features of Gremlin/DSE + and is callable as a normal row factory. + + Currently supported: + - bulk results + """ + + def __call__(self, column_names, rows): + for row in rows: + parsed_row = self.graphson_reader.readObject(row[0]) + yield parsed_row[_result_key] + bulk = parsed_row.get(_bulk_key, 1) + for _ in range(bulk - 1): + yield copy.deepcopy(parsed_row[_result_key]) + + + class _GremlinGraphSON2RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2.""" + graphson_reader_class = GremlinGraphSONReaderV2 + graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers} + + + class _DseGraphSON2RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2 as DSE types.""" + graphson_reader_class = GremlinGraphSONReaderV2 + graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers} + + gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory + # TODO remove in next major + graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory + + dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory + # TODO remove in next major + graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory + + + class _GremlinGraphSON3RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2.""" + graphson_reader_class = GremlinGraphSONReaderV3 + graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers} + + + class _DseGraphSON3RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson3 as DSE types.""" + graphson_reader_class = GremlinGraphSONReaderV3 + graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers} + + + gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory + dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory + + + class DSESessionRemoteGraphConnection(RemoteConnection): + """ + A Tinkerpop RemoteConnection to execute traversal queries on DSE. + + :param session: A DSE session + :param graph_name: (Optional) DSE Graph name. + :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`. + """ + + session = None + graph_name = None + execution_profile = None + + def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + super(DSESessionRemoteGraphConnection, self).__init__(None, None) + + if not isinstance(session, Session): + raise ValueError('A DSE Session must be provided to execute graph traversal queries.') + + self.session = session + self.graph_name = graph_name + self.execution_profile = execution_profile + + @staticmethod + def _traversers_generator(traversers): + for t in traversers: + yield Traverser(t) + + def _prepare_query(self, bytecode): + ep = self.session.execution_profile_clone_update(self.execution_profile) + graph_options = ep.graph_options + graph_options.graph_name = self.graph_name or graph_options.graph_name + graph_options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE + # We resolve the execution profile options here , to know how what gremlin factory to set + self.session._resolve_execution_profile_options(ep) + + context = None + if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0: + row_factory = gremlin_graphson2_traversal_row_factory + elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0: + row_factory = gremlin_graphson3_traversal_row_factory + context = { + 'cluster': self.session.cluster, + 'graph_name': graph_options.graph_name.decode('utf-8') + } + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_options.graph_protocol)) + + ep.row_factory = row_factory + query = DseGraph.query_from_traversal(bytecode, graph_options.graph_protocol, context) + + return query, ep + + @staticmethod + def _handle_query_results(result_set, gremlin_future): + try: + gremlin_future.set_result( + RemoteTraversal(DSESessionRemoteGraphConnection._traversers_generator(result_set), TraversalSideEffects()) + ) + except Exception as e: + gremlin_future.set_exception(e) + + @staticmethod + def _handle_query_error(response, gremlin_future): + gremlin_future.set_exception(response) + + def submit(self, bytecode): + # the only reason I don't use submitAsync here + # is to avoid an unuseful future wrap + query, ep = self._prepare_query(bytecode) + + traversers = self.session.execute_graph(query, execution_profile=ep) + return RemoteTraversal(self._traversers_generator(traversers), TraversalSideEffects()) + + def submitAsync(self, bytecode): + query, ep = self._prepare_query(bytecode) + + # to be compatible with gremlinpython, we need to return a concurrent.futures.Future + gremlin_future = Future() + response_future = self.session.execute_graph_async(query, execution_profile=ep) + response_future.add_callback(self._handle_query_results, gremlin_future) + response_future.add_errback(self._handle_query_error, gremlin_future) + + return gremlin_future + + def __str__(self): + return "".format(self.graph_name) + + __repr__ = __str__ + + + class DseGraph(object): + """ + Dse Graph utility class for GraphTraversal construction and execution. + """ + + DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json' + """ + Graph query language, Default is 'bytecode-json' (GraphSON). + """ + + DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0 + """ + Graph query language, Default is GraphProtocol.GRAPHSON_2_0. + """ + + @staticmethod + def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None): + """ + From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`. + + :param traversal: The GraphTraversal object + :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`. + :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt). + e.g: {'cluster': cluster, 'graph_name': name} + """ + + if isinstance(traversal, GraphTraversal): + for strategy in traversal.traversal_strategies.traversal_strategies: + rc = strategy.remote_connection + if (isinstance(rc, DSESessionRemoteGraphConnection) and + rc.session or rc.graph_name or rc.execution_profile): + log.warning("GraphTraversal session, graph_name and execution_profile are " + "only taken into account when executed with TinkerPop.") + + return _query_from_traversal(traversal, graph_protocol, context) + + @staticmethod + def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, + traversal_class=None): + """ + Returns a TinkerPop GraphTraversalSource binded to the session and graph_name if provided. + + :param session: (Optional) A DSE session + :param graph_name: (Optional) DSE Graph name + :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`. + :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL). + + .. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + c = Cluster() + session = c.connect() + + g = DseGraph.traversal_source(session, 'my_graph') + print(g.V().valueMap().toList()) + + """ + + graph = Graph() + traversal_source = graph.traversal(traversal_class) + + if session: + traversal_source = traversal_source.withRemote( + DSESessionRemoteGraphConnection(session, graph_name, execution_profile)) + + return traversal_source + + @staticmethod + def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs): + """ + Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the + cluster by using `cluster.add_execution_profile`. + + :param graph_name: The graph name + :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`. + """ + + if graph_protocol == GraphProtocol.GRAPHSON_2_0: + row_factory = dse_graphson2_traversal_row_factory + elif graph_protocol == GraphProtocol.GRAPHSON_3_0: + row_factory = dse_graphson3_traversal_row_factory + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_protocol)) + + ep = GraphExecutionProfile(row_factory=row_factory, + graph_options=GraphOptions(graph_name=graph_name, + graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE, + graph_protocol=graph_protocol), + **kwargs) + return ep + + @staticmethod + def batch(*args, **kwargs): + """ + Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to + execute multiple traversals in the same transaction. + """ + return _DefaultTraversalBatch(*args, **kwargs) diff --git a/cassandra/datastax/graph/fluent/_predicates.py b/cassandra/datastax/graph/fluent/_predicates.py new file mode 100644 index 0000000000..1c7825455a --- /dev/null +++ b/cassandra/datastax/graph/fluent/_predicates.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from gremlin_python.process.traversal import P + +from cassandra.util import Distance + +__all__ = ['GeoP', 'TextDistanceP', 'Search', 'GeoUnit', 'Geo', 'CqlCollection'] + + +class GeoP(object): + + def __init__(self, operator, value, other=None): + self.operator = operator + self.value = value + self.other = other + + @staticmethod + def inside(*args, **kwargs): + return GeoP("inside", *args, **kwargs) + + def __eq__(self, other): + return isinstance(other, + self.__class__) and self.operator == other.operator and self.value == other.value and self.other == other.other + + def __repr__(self): + return self.operator + "(" + str(self.value) + ")" if self.other is None else self.operator + "(" + str( + self.value) + "," + str(self.other) + ")" + + +class TextDistanceP(object): + + def __init__(self, operator, value, distance): + self.operator = operator + self.value = value + self.distance = distance + + @staticmethod + def fuzzy(*args): + return TextDistanceP("fuzzy", *args) + + @staticmethod + def token_fuzzy(*args): + return TextDistanceP("tokenFuzzy", *args) + + @staticmethod + def phrase(*args): + return TextDistanceP("phrase", *args) + + def __eq__(self, other): + return isinstance(other, + self.__class__) and self.operator == other.operator and self.value == other.value and self.distance == other.distance + + def __repr__(self): + return self.operator + "(" + str(self.value) + "," + str(self.distance) + ")" + + +class Search(object): + + @staticmethod + def token(value): + """ + Search any instance of a certain token within the text property targeted. + :param value: the value to look for. + """ + return P('token', value) + + @staticmethod + def token_prefix(value): + """ + Search any instance of a certain token prefix withing the text property targeted. + :param value: the value to look for. + """ + return P('tokenPrefix', value) + + @staticmethod + def token_regex(value): + """ + Search any instance of the provided regular expression for the targeted property. + :param value: the value to look for. + """ + return P('tokenRegex', value) + + @staticmethod + def prefix(value): + """ + Search for a specific prefix at the beginning of the text property targeted. + :param value: the value to look for. + """ + return P('prefix', value) + + @staticmethod + def regex(value): + """ + Search for this regular expression inside the text property targeted. + :param value: the value to look for. + """ + return P('regex', value) + + @staticmethod + def fuzzy(value, distance): + """ + Search for a fuzzy string inside the text property targeted. + :param value: the value to look for. + :param distance: The distance for the fuzzy search. ie. 1, to allow a one-letter misspellings. + """ + return TextDistanceP.fuzzy(value, distance) + + @staticmethod + def token_fuzzy(value, distance): + """ + Search for a token fuzzy inside the text property targeted. + :param value: the value to look for. + :param distance: The distance for the token fuzzy search. ie. 1, to allow a one-letter misspellings. + """ + return TextDistanceP.token_fuzzy(value, distance) + + @staticmethod + def phrase(value, proximity): + """ + Search for a phrase inside the text property targeted. + :param value: the value to look for. + :param proximity: The proximity for the phrase search. ie. phrase('David Felcey', 2).. to find 'David Felcey' with up to two middle names. + """ + return TextDistanceP.phrase(value, proximity) + + +class CqlCollection(object): + + @staticmethod + def contains(value): + """ + Search for a value inside a cql list/set column. + :param value: the value to look for. + """ + return P('contains', value) + + @staticmethod + def contains_value(value): + """ + Search for a map value. + :param value: the value to look for. + """ + return P('containsValue', value) + + @staticmethod + def contains_key(value): + """ + Search for a map key. + :param value: the value to look for. + """ + return P('containsKey', value) + + @staticmethod + def entry_eq(value): + """ + Search for a map entry. + :param value: the value to look for. + """ + return P('entryEq', value) + + +class GeoUnit(object): + _EARTH_MEAN_RADIUS_KM = 6371.0087714 + _DEGREES_TO_RADIANS = math.pi / 180 + _DEG_TO_KM = _DEGREES_TO_RADIANS * _EARTH_MEAN_RADIUS_KM + _KM_TO_DEG = 1 / _DEG_TO_KM + _MILES_TO_KM = 1.609344001 + + MILES = _MILES_TO_KM * _KM_TO_DEG + KILOMETERS = _KM_TO_DEG + METERS = _KM_TO_DEG / 1000.0 + DEGREES = 1 + + +class Geo(object): + + @staticmethod + def inside(value, units=GeoUnit.DEGREES): + """ + Search any instance of geometry inside the Distance targeted. + :param value: A Distance to look for. + :param units: The units for ``value``. See GeoUnit enum. (Can also + provide an integer to use as a multiplier to convert ``value`` to + degrees.) + """ + return GeoP.inside( + value=Distance(x=value.x, y=value.y, radius=value.radius * units) + ) diff --git a/cassandra/datastax/graph/fluent/_query.py b/cassandra/datastax/graph/fluent/_query.py new file mode 100644 index 0000000000..c476653541 --- /dev/null +++ b/cassandra/datastax/graph/fluent/_query.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from cassandra.graph import SimpleGraphStatement, GraphProtocol +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT + +from gremlin_python.process.graph_traversal import GraphTraversal +from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2 +from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3 + +from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \ + dse_graphson2_serializers, dse_graphson3_serializers + +log = logging.getLogger(__name__) + + +__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch'] + + +class _GremlinGraphSONWriterAdapter(object): + + def __init__(self, context, **kwargs): + super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs) + self.context = context + self.user_types = None + + def serialize(self, value, _): + return self.toDict(value) + + def get_serializer(self, value): + serializer = None + try: + serializer = self.serializers[type(value)] + except KeyError: + for key, ser in self.serializers.items(): + if isinstance(value, key): + serializer = ser + + if self.context: + # Check if UDT + if self.user_types is None: + try: + user_types = self.context['cluster']._user_types[self.context['graph_name']] + self.user_types = dict(map(reversed, user_types.items())) + except KeyError: + self.user_types = {} + + # Custom detection to map a namedtuple to udt + if (tuple in self.serializers and serializer is self.serializers[tuple] and hasattr(value, '_fields') or + (not serializer and type(value) in self.user_types)): + serializer = GremlinUserTypeIO + + if serializer: + try: + # A serializer can have specialized serializers (e.g for Int32 and Int64, so value dependant) + serializer = serializer.get_specialized_serializer(value) + except AttributeError: + pass + + return serializer + + def toDict(self, obj): + serializer = self.get_serializer(obj) + return serializer.dictify(obj, self) if serializer else obj + + def definition(self, value): + serializer = self.get_serializer(value) + return serializer.definition(value, self) + + +class GremlinGraphSON2Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV2): + pass + + +class GremlinGraphSON3Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV3): + pass + + +graphson2_writer = GremlinGraphSON2Writer +graphson3_writer = GremlinGraphSON3Writer + + +def _query_from_traversal(traversal, graph_protocol, context=None): + """ + From a GraphTraversal, return a query string. + + :param traversal: The GraphTraversal object + :param graphson_protocol: The graph protocol to determine the output format. + """ + if graph_protocol == GraphProtocol.GRAPHSON_2_0: + graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers) + elif graph_protocol == GraphProtocol.GRAPHSON_3_0: + if context is None: + raise ValueError('Missing context for GraphSON3 serialization requires.') + graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers) + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_protocol)) + + try: + query = graphson_writer.writeObject(traversal) + except Exception: + log.exception("Error preparing graphson traversal query:") + raise + + return query + + +class TraversalBatch(object): + """ + A `TraversalBatch` is used to execute multiple graph traversals in a + single transaction. If any traversal in the batch fails, the entire + batch will fail to apply. + + If a TraversalBatch is bounded to a DSE session, it can be executed using + `traversal_batch.execute()`. + """ + + _session = None + _execution_profile = None + + def __init__(self, session=None, execution_profile=None): + """ + :param session: (Optional) A DSE session + :param execution_profile: (Optional) The execution profile to use for the batch execution + """ + self._session = session + self._execution_profile = execution_profile + + def add(self, traversal): + """ + Add a traversal to the batch. + + :param traversal: A gremlin GraphTraversal + """ + raise NotImplementedError() + + def add_all(self, traversals): + """ + Adds a sequence of traversals to the batch. + + :param traversals: A sequence of gremlin GraphTraversal + """ + raise NotImplementedError() + + def execute(self): + """ + Execute the traversal batch if bounded to a `DSE Session`. + """ + raise NotImplementedError() + + def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0): + """ + Return the traversal batch as GraphStatement. + + :param graph_protocol: The graph protocol for the GraphSONWriter. Default is GraphProtocol.GRAPHSON_2_0. + """ + raise NotImplementedError() + + def clear(self): + """ + Clear a traversal batch for reuse. + """ + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + def __str__(self): + return u''.format(len(self)) + __repr__ = __str__ + + +class _DefaultTraversalBatch(TraversalBatch): + + _traversals = None + + def __init__(self, *args, **kwargs): + super(_DefaultTraversalBatch, self).__init__(*args, **kwargs) + self._traversals = [] + + def add(self, traversal): + if not isinstance(traversal, GraphTraversal): + raise ValueError('traversal should be a gremlin GraphTraversal') + + self._traversals.append(traversal) + return self + + def add_all(self, traversals): + for traversal in traversals: + self.add(traversal) + + def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None): + statements = [_query_from_traversal(t, graph_protocol, context) for t in self._traversals] + query = u"[{0}]".format(','.join(statements)) + return SimpleGraphStatement(query) + + def execute(self): + if self._session is None: + raise ValueError('A DSE Session must be provided to execute the traversal batch.') + + execution_profile = self._execution_profile if self._execution_profile else EXEC_PROFILE_GRAPH_DEFAULT + graph_options = self._session.get_execution_profile(execution_profile).graph_options + context = { + 'cluster': self._session.cluster, + 'graph_name': graph_options.graph_name + } + statement = self.as_graph_statement(graph_options.graph_protocol, context=context) \ + if graph_options.graph_protocol else self.as_graph_statement(context=context) + return self._session.execute_graph(statement, execution_profile=execution_profile) + + def clear(self): + del self._traversals[:] + + def __len__(self): + return len(self._traversals) diff --git a/cassandra/datastax/graph/fluent/_serializers.py b/cassandra/datastax/graph/fluent/_serializers.py new file mode 100644 index 0000000000..b6c705771f --- /dev/null +++ b/cassandra/datastax/graph/fluent/_serializers.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +from gremlin_python.structure.io.graphsonV2d0 import ( + GraphSONReader as GraphSONReaderV2, + GraphSONUtil as GraphSONUtil, # no difference between v2 and v3 + VertexDeserializer as VertexDeserializerV2, + VertexPropertyDeserializer as VertexPropertyDeserializerV2, + PropertyDeserializer as PropertyDeserializerV2, + EdgeDeserializer as EdgeDeserializerV2, + PathDeserializer as PathDeserializerV2 +) + +from gremlin_python.structure.io.graphsonV3d0 import ( + GraphSONReader as GraphSONReaderV3, + VertexDeserializer as VertexDeserializerV3, + VertexPropertyDeserializer as VertexPropertyDeserializerV3, + PropertyDeserializer as PropertyDeserializerV3, + EdgeDeserializer as EdgeDeserializerV3, + PathDeserializer as PathDeserializerV3 +) + +try: + from gremlin_python.structure.io.graphsonV2d0 import ( + TraversalMetricsDeserializer as TraversalMetricsDeserializerV2, + MetricsDeserializer as MetricsDeserializerV2 + ) + from gremlin_python.structure.io.graphsonV3d0 import ( + TraversalMetricsDeserializer as TraversalMetricsDeserializerV3, + MetricsDeserializer as MetricsDeserializerV3 + ) +except ImportError: + TraversalMetricsDeserializerV2 = MetricsDeserializerV2 = None + TraversalMetricsDeserializerV3 = MetricsDeserializerV3 = None + +from cassandra.graph import ( + GraphSON2Serializer, + GraphSON2Deserializer, + GraphSON3Serializer, + GraphSON3Deserializer +) +from cassandra.graph.graphson import UserTypeIO, TypeWrapperTypeIO +from cassandra.datastax.graph.fluent.predicates import GeoP, TextDistanceP +from cassandra.util import Distance + + +__all__ = ['GremlinGraphSONReader', 'GeoPSerializer', 'TextDistancePSerializer', + 'DistanceIO', 'gremlin_deserializers', 'deserializers', 'serializers', + 'GremlinGraphSONReaderV2', 'GremlinGraphSONReaderV3', 'dse_graphson2_serializers', + 'dse_graphson2_deserializers', 'dse_graphson3_serializers', 'dse_graphson3_deserializers', + 'gremlin_graphson2_deserializers', 'gremlin_graphson3_deserializers', 'GremlinUserTypeIO'] + + +class _GremlinGraphSONTypeSerializer(object): + TYPE_KEY = "@type" + VALUE_KEY = "@value" + serializer = None + + def __init__(self, serializer): + self.serializer = serializer + + def dictify(self, v, writer): + value = self.serializer.serialize(v, writer) + if self.serializer is TypeWrapperTypeIO: + graphson_base_type = v.type_io.graphson_base_type + graphson_type = v.type_io.graphson_type + else: + graphson_base_type = self.serializer.graphson_base_type + graphson_type = self.serializer.graphson_type + + if graphson_base_type is None: + out = value + else: + out = {self.TYPE_KEY: graphson_type} + if value is not None: + out[self.VALUE_KEY] = value + + return out + + def definition(self, value, writer=None): + return self.serializer.definition(value, writer) + + def get_specialized_serializer(self, value): + ser = self.serializer.get_specialized_serializer(value) + if ser is not self.serializer: + return _GremlinGraphSONTypeSerializer(ser) + return self + + +class _GremlinGraphSONTypeDeserializer(object): + + deserializer = None + + def __init__(self, deserializer): + self.deserializer = deserializer + + def objectify(self, v, reader): + return self.deserializer.deserialize(v, reader) + + +def _make_gremlin_graphson2_deserializer(graphson_type): + return _GremlinGraphSONTypeDeserializer( + GraphSON2Deserializer.get_deserializer(graphson_type.graphson_type) + ) + + +def _make_gremlin_graphson3_deserializer(graphson_type): + return _GremlinGraphSONTypeDeserializer( + GraphSON3Deserializer.get_deserializer(graphson_type.graphson_type) + ) + + +class _GremlinGraphSONReader(object): + """Gremlin GraphSONReader Adapter, required to use gremlin types""" + + context = None + + def __init__(self, context, deserializer_map=None): + self.context = context + super(_GremlinGraphSONReader, self).__init__(deserializer_map) + + def deserialize(self, obj): + return self.toObject(obj) + + +class GremlinGraphSONReaderV2(_GremlinGraphSONReader, GraphSONReaderV2): + pass + +# TODO remove next major +GremlinGraphSONReader = GremlinGraphSONReaderV2 + +class GremlinGraphSONReaderV3(_GremlinGraphSONReader, GraphSONReaderV3): + pass + + +class GeoPSerializer(object): + @classmethod + def dictify(cls, p, writer): + out = { + "predicateType": "Geo", + "predicate": p.operator, + "value": [writer.toDict(p.value), writer.toDict(p.other)] if p.other is not None else writer.toDict(p.value) + } + return GraphSONUtil.typedValue("P", out, prefix='dse') + + +class TextDistancePSerializer(object): + @classmethod + def dictify(cls, p, writer): + out = { + "predicate": p.operator, + "value": { + 'query': writer.toDict(p.value), + 'distance': writer.toDict(p.distance) + } + } + return GraphSONUtil.typedValue("P", out) + + +class DistanceIO(object): + @classmethod + def dictify(cls, v, _): + return GraphSONUtil.typedValue('Distance', str(v), prefix='dse') + + +GremlinUserTypeIO = _GremlinGraphSONTypeSerializer(UserTypeIO) + +# GraphSON2 +dse_graphson2_serializers = OrderedDict([ + (t, _GremlinGraphSONTypeSerializer(s)) + for t, s in GraphSON2Serializer.get_type_definitions().items() +]) + +dse_graphson2_serializers.update(OrderedDict([ + (Distance, DistanceIO), + (GeoP, GeoPSerializer), + (TextDistanceP, TextDistancePSerializer) +])) + +# TODO remove next major, this is just in case someone was using it +serializers = dse_graphson2_serializers + +dse_graphson2_deserializers = { + k: _make_gremlin_graphson2_deserializer(v) + for k, v in GraphSON2Deserializer.get_type_definitions().items() +} + +dse_graphson2_deserializers.update({ + "dse:Distance": DistanceIO, +}) + +# TODO remove next major, this is just in case someone was using it +deserializers = dse_graphson2_deserializers + +gremlin_graphson2_deserializers = dse_graphson2_deserializers.copy() +gremlin_graphson2_deserializers.update({ + 'g:Vertex': VertexDeserializerV2, + 'g:VertexProperty': VertexPropertyDeserializerV2, + 'g:Edge': EdgeDeserializerV2, + 'g:Property': PropertyDeserializerV2, + 'g:Path': PathDeserializerV2 +}) + +if TraversalMetricsDeserializerV2: + gremlin_graphson2_deserializers.update({ + 'g:TraversalMetrics': TraversalMetricsDeserializerV2, + 'g:lMetrics': MetricsDeserializerV2 + }) + +# TODO remove next major, this is just in case someone was using it +gremlin_deserializers = gremlin_graphson2_deserializers + +# GraphSON3 +dse_graphson3_serializers = OrderedDict([ + (t, _GremlinGraphSONTypeSerializer(s)) + for t, s in GraphSON3Serializer.get_type_definitions().items() +]) + +dse_graphson3_serializers.update(OrderedDict([ + (Distance, DistanceIO), + (GeoP, GeoPSerializer), + (TextDistanceP, TextDistancePSerializer) +])) + +dse_graphson3_deserializers = { + k: _make_gremlin_graphson3_deserializer(v) + for k, v in GraphSON3Deserializer.get_type_definitions().items() +} + +dse_graphson3_deserializers.update({ + "dse:Distance": DistanceIO +}) + +gremlin_graphson3_deserializers = dse_graphson3_deserializers.copy() +gremlin_graphson3_deserializers.update({ + 'g:Vertex': VertexDeserializerV3, + 'g:VertexProperty': VertexPropertyDeserializerV3, + 'g:Edge': EdgeDeserializerV3, + 'g:Property': PropertyDeserializerV3, + 'g:Path': PathDeserializerV3 +}) + +if TraversalMetricsDeserializerV3: + gremlin_graphson3_deserializers.update({ + 'g:TraversalMetrics': TraversalMetricsDeserializerV3, + 'g:Metrics': MetricsDeserializerV3 + }) diff --git a/cassandra/datastax/graph/fluent/predicates.py b/cassandra/datastax/graph/fluent/predicates.py new file mode 100644 index 0000000000..8dca8b84ce --- /dev/null +++ b/cassandra/datastax/graph/fluent/predicates.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._predicates import * +except ImportError: + # gremlinpython is not installed. + pass diff --git a/cassandra/datastax/graph/fluent/query.py b/cassandra/datastax/graph/fluent/query.py new file mode 100644 index 0000000000..f599f2c979 --- /dev/null +++ b/cassandra/datastax/graph/fluent/query.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._query import * +except ImportError: + # gremlinpython is not installed. + pass diff --git a/cassandra/datastax/graph/fluent/serializers.py b/cassandra/datastax/graph/fluent/serializers.py new file mode 100644 index 0000000000..3c175f92d4 --- /dev/null +++ b/cassandra/datastax/graph/fluent/serializers.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._serializers import * +except ImportError: + # gremlinpython is not installed. + pass diff --git a/cassandra/datastax/graph/graphson.py b/cassandra/datastax/graph/graphson.py new file mode 100644 index 0000000000..7b284c4c26 --- /dev/null +++ b/cassandra/datastax/graph/graphson.py @@ -0,0 +1,1134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import base64 +import uuid +import re +import json +from decimal import Decimal +from collections import OrderedDict +import logging +import itertools +from functools import partial + +import ipaddress + + +from cassandra.cqltypes import cql_types_from_string +from cassandra.metadata import UserType +from cassandra.util import Polygon, Point, LineString, Duration +from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path, T + +__all__ = ['GraphSON1Serializer', 'GraphSON1Deserializer', 'GraphSON1TypeDeserializer', + 'GraphSON2Serializer', 'GraphSON2Deserializer', 'GraphSON2Reader', + 'GraphSON3Serializer', 'GraphSON3Deserializer', 'GraphSON3Reader', + 'to_bigint', 'to_int', 'to_double', 'to_float', 'to_smallint', + 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO', + 'FloatTypeIO', 'UUIDTypeIO', 'BigDecimalTypeIO', 'DurationTypeIO', 'InetTypeIO', + 'InstantTypeIO', 'LocalDateTypeIO', 'LocalTimeTypeIO', 'Int64TypeIO', 'BigIntegerTypeIO', + 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO', + 'GraphSON3Serializer', 'GraphSON3Deserializer', 'UserTypeIO', 'TypeWrapperTypeIO'] + +""" +Supported types: + +DSE Graph GraphSON 2.0 GraphSON 3.0 | Python Driver +------------ | -------------- | -------------- | ------------ +text | string | string | str +boolean | | | bool +bigint | g:Int64 | g:Int64 | long +int | g:Int32 | g:Int32 | int +double | g:Double | g:Double | float +float | g:Float | g:Float | float +uuid | g:UUID | g:UUID | UUID +bigdecimal | gx:BigDecimal | gx:BigDecimal | Decimal +duration | gx:Duration | N/A | timedelta (Classic graph only) +DSE Duration | N/A | dse:Duration | Duration (Core graph only) +inet | gx:InetAddress | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3) +timestamp | gx:Instant | gx:Instant | datetime.datetime +date | gx:LocalDate | gx:LocalDate | datetime.date +time | gx:LocalTime | gx:LocalTime | datetime.time +smallint | gx:Int16 | gx:Int16 | int +varint | gx:BigInteger | gx:BigInteger | long +date | gx:LocalDate | gx:LocalDate | Date +polygon | dse:Polygon | dse:Polygon | Polygon +point | dse:Point | dse:Point | Point +linestring | dse:Linestring | dse:LineString | LineString +blob | dse:Blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +list | N/A | g:List | list (Core graph only) +map | N/A | g:Map | dict (Core graph only) +set | N/A | g:Set | set or list (Core graph only) + Can return a list due to numerical values returned by Java +tuple | N/A | dse:Tuple | tuple (Core graph only) +udt | N/A | dse:UDT | class or namedtuple (Core graph only) +""" + +MAX_INT32 = 2 ** 32 - 1 +MIN_INT32 = -2 ** 31 + +log = logging.getLogger(__name__) + + +class _GraphSONTypeType(type): + """GraphSONType metaclass, required to create a class property.""" + + @property + def graphson_type(cls): + return "{0}:{1}".format(cls.prefix, cls.graphson_base_type) + + +class GraphSONTypeIO(object, metaclass=_GraphSONTypeType): + """Represent a serializable GraphSON type""" + + prefix = 'g' + graphson_base_type = None + cql_type = None + + @classmethod + def definition(cls, value, writer=None): + return {'cqlType': cls.cql_type} + + @classmethod + def serialize(cls, value, writer=None): + return str(value) + + @classmethod + def deserialize(cls, value, reader=None): + return value + + @classmethod + def get_specialized_serializer(cls, value): + return cls + + +class TextTypeIO(GraphSONTypeIO): + cql_type = 'text' + + +class BooleanTypeIO(GraphSONTypeIO): + graphson_base_type = None + cql_type = 'boolean' + + @classmethod + def serialize(cls, value, writer=None): + return bool(value) + + +class IntegerTypeIO(GraphSONTypeIO): + + @classmethod + def serialize(cls, value, writer=None): + return value + + @classmethod + def get_specialized_serializer(cls, value): + if type(value) is int and (value > MAX_INT32 or value < MIN_INT32): + return Int64TypeIO + + return Int32TypeIO + + +class Int16TypeIO(IntegerTypeIO): + prefix = 'gx' + graphson_base_type = 'Int16' + cql_type = 'smallint' + + +class Int32TypeIO(IntegerTypeIO): + graphson_base_type = 'Int32' + cql_type = 'int' + + +class Int64TypeIO(IntegerTypeIO): + graphson_base_type = 'Int64' + cql_type = 'bigint' + + @classmethod + def deserialize(cls, value, reader=None): + return value + + +class FloatTypeIO(GraphSONTypeIO): + graphson_base_type = 'Float' + cql_type = 'float' + + @classmethod + def serialize(cls, value, writer=None): + return value + + @classmethod + def deserialize(cls, value, reader=None): + return float(value) + + +class DoubleTypeIO(FloatTypeIO): + graphson_base_type = 'Double' + cql_type = 'double' + + +class BigIntegerTypeIO(IntegerTypeIO): + prefix = 'gx' + graphson_base_type = 'BigInteger' + + +class LocalDateTypeIO(GraphSONTypeIO): + FORMAT = '%Y-%m-%d' + + prefix = 'gx' + graphson_base_type = 'LocalDate' + cql_type = 'date' + + @classmethod + def serialize(cls, value, writer=None): + return value.isoformat() + + @classmethod + def deserialize(cls, value, reader=None): + try: + return datetime.datetime.strptime(value, cls.FORMAT).date() + except ValueError: + # negative date + return value + + +class InstantTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'Instant' + cql_type = 'timestamp' + + @classmethod + def serialize(cls, value, writer=None): + if isinstance(value, datetime.datetime): + value = datetime.datetime(*value.utctimetuple()[:6]).replace(microsecond=value.microsecond) + else: + value = datetime.datetime.combine(value, datetime.datetime.min.time()) + + return "{0}Z".format(value.isoformat()) + + @classmethod + def deserialize(cls, value, reader=None): + try: + d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ') + except ValueError: + d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ') + return d + + +class LocalTimeTypeIO(GraphSONTypeIO): + FORMATS = [ + '%H:%M', + '%H:%M:%S', + '%H:%M:%S.%f' + ] + + prefix = 'gx' + graphson_base_type = 'LocalTime' + cql_type = 'time' + + @classmethod + def serialize(cls, value, writer=None): + return value.strftime(cls.FORMATS[2]) + + @classmethod + def deserialize(cls, value, reader=None): + dt = None + for f in cls.FORMATS: + try: + dt = datetime.datetime.strptime(value, f) + break + except ValueError: + continue + + if dt is None: + raise ValueError('Unable to decode LocalTime: {0}'.format(value)) + + return dt.time() + + +class BlobTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Blob' + cql_type = 'blob' + + @classmethod + def serialize(cls, value, writer=None): + value = base64.b64encode(value) + value = value.decode('utf-8') + return value + + @classmethod + def deserialize(cls, value, reader=None): + return bytearray(base64.b64decode(value)) + + +class ByteBufferTypeIO(BlobTypeIO): + prefix = 'gx' + graphson_base_type = 'ByteBuffer' + + +class UUIDTypeIO(GraphSONTypeIO): + graphson_base_type = 'UUID' + cql_type = 'uuid' + + @classmethod + def deserialize(cls, value, reader=None): + return uuid.UUID(value) + + +class BigDecimalTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'BigDecimal' + cql_type = 'bigdecimal' + + @classmethod + def deserialize(cls, value, reader=None): + return Decimal(value) + + +class DurationTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'Duration' + cql_type = 'duration' + + _duration_regex = re.compile(r""" + ^P((?P\d+)D)? + T((?P\d+)H)? + ((?P\d+)M)? + ((?P[0-9.]+)S)?$ + """, re.VERBOSE) + _duration_format = "P{days}DT{hours}H{minutes}M{seconds}S" + + _seconds_in_minute = 60 + _seconds_in_hour = 60 * _seconds_in_minute + _seconds_in_day = 24 * _seconds_in_hour + + @classmethod + def serialize(cls, value, writer=None): + total_seconds = int(value.total_seconds()) + days, total_seconds = divmod(total_seconds, cls._seconds_in_day) + hours, total_seconds = divmod(total_seconds, cls._seconds_in_hour) + minutes, total_seconds = divmod(total_seconds, cls._seconds_in_minute) + total_seconds += value.microseconds / 1e6 + + return cls._duration_format.format( + days=int(days), hours=int(hours), minutes=int(minutes), seconds=total_seconds + ) + + @classmethod + def deserialize(cls, value, reader=None): + duration = cls._duration_regex.match(value) + if duration is None: + raise ValueError('Invalid duration: {0}'.format(value)) + + duration = {k: float(v) if v is not None else 0 + for k, v in duration.groupdict().items()} + return datetime.timedelta(days=duration['days'], hours=duration['hours'], + minutes=duration['minutes'], seconds=duration['seconds']) + + +class DseDurationTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Duration' + cql_type = 'duration' + + @classmethod + def serialize(cls, value, writer=None): + return { + 'months': value.months, + 'days': value.days, + 'nanos': value.nanoseconds + } + + @classmethod + def deserialize(cls, value, reader=None): + return Duration( + reader.deserialize(value['months']), + reader.deserialize(value['days']), + reader.deserialize(value['nanos']) + ) + + +class TypeWrapperTypeIO(GraphSONTypeIO): + + @classmethod + def definition(cls, value, writer=None): + return {'cqlType': value.type_io.cql_type} + + @classmethod + def serialize(cls, value, writer=None): + return value.type_io.serialize(value.value) + + @classmethod + def deserialize(cls, value, reader=None): + return value.type_io.deserialize(value.value) + + +class PointTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Point' + cql_type = "org.apache.cassandra.db.marshal.PointType" + + @classmethod + def deserialize(cls, value, reader=None): + return Point.from_wkt(value) + + +class LineStringTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'LineString' + cql_type = "org.apache.cassandra.db.marshal.LineStringType" + + @classmethod + def deserialize(cls, value, reader=None): + return LineString.from_wkt(value) + + +class PolygonTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Polygon' + cql_type = "org.apache.cassandra.db.marshal.PolygonType" + + @classmethod + def deserialize(cls, value, reader=None): + return Polygon.from_wkt(value) + + +class InetTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'InetAddress' + cql_type = 'inet' + + +class VertexTypeIO(GraphSONTypeIO): + graphson_base_type = 'Vertex' + + @classmethod + def deserialize(cls, value, reader=None): + vertex = Vertex(id=reader.deserialize(value["id"]), + label=value["label"] if "label" in value else "vertex", + type='vertex', + properties={}) + # avoid the properties processing in Vertex.__init__ + vertex.properties = reader.deserialize(value.get('properties', {})) + return vertex + + +class VertexPropertyTypeIO(GraphSONTypeIO): + graphson_base_type = 'VertexProperty' + + @classmethod + def deserialize(cls, value, reader=None): + return VertexProperty(label=value['label'], + value=reader.deserialize(value["value"]), + properties=reader.deserialize(value.get('properties', {}))) + + +class EdgeTypeIO(GraphSONTypeIO): + graphson_base_type = 'Edge' + + @classmethod + def deserialize(cls, value, reader=None): + in_vertex = Vertex(id=reader.deserialize(value["inV"]), + label=value['inVLabel'], + type='vertex', + properties={}) + out_vertex = Vertex(id=reader.deserialize(value["outV"]), + label=value['outVLabel'], + type='vertex', + properties={}) + return Edge( + id=reader.deserialize(value["id"]), + label=value["label"] if "label" in value else "vertex", + type='edge', + properties=reader.deserialize(value.get("properties", {})), + inV=in_vertex, + inVLabel=value['inVLabel'], + outV=out_vertex, + outVLabel=value['outVLabel'] + ) + + +class PropertyTypeIO(GraphSONTypeIO): + graphson_base_type = 'Property' + + @classmethod + def deserialize(cls, value, reader=None): + return {value["key"]: reader.deserialize(value["value"])} + + +class PathTypeIO(GraphSONTypeIO): + graphson_base_type = 'Path' + + @classmethod + def deserialize(cls, value, reader=None): + labels = [set(label) for label in reader.deserialize(value['labels'])] + objects = [obj for obj in reader.deserialize(value['objects'])] + p = Path(labels, []) + p.objects = objects # avoid the object processing in Path.__init__ + return p + + +class TraversalMetricsTypeIO(GraphSONTypeIO): + graphson_base_type = 'TraversalMetrics' + + @classmethod + def deserialize(cls, value, reader=None): + return reader.deserialize(value) + + +class MetricsTypeIO(GraphSONTypeIO): + graphson_base_type = 'Metrics' + + @classmethod + def deserialize(cls, value, reader=None): + return reader.deserialize(value) + + +class JsonMapTypeIO(GraphSONTypeIO): + """In GraphSON2, dict are simply serialized as json map""" + + @classmethod + def serialize(cls, value, writer=None): + out = {} + for k, v in value.items(): + out[k] = writer.serialize(v, writer) + + return out + + +class MapTypeIO(GraphSONTypeIO): + """In GraphSON3, dict has its own type""" + + graphson_base_type = 'Map' + cql_type = 'map' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + for k, v in value.items(): + # we just need the first pair to write the def + out['definition'].append(writer.definition(k)) + out['definition'].append(writer.definition(v)) + break + return out + + @classmethod + def serialize(cls, value, writer=None): + out = [] + for k, v in value.items(): + out.append(writer.serialize(k, writer)) + out.append(writer.serialize(v, writer)) + + return out + + @classmethod + def deserialize(cls, value, reader=None): + out = {} + a, b = itertools.tee(value) + for key, val in zip( + itertools.islice(a, 0, None, 2), + itertools.islice(b, 1, None, 2) + ): + out[reader.deserialize(key)] = reader.deserialize(val) + return out + + +class ListTypeIO(GraphSONTypeIO): + """In GraphSON3, list has its own type""" + + graphson_base_type = 'List' + cql_type = 'list' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + if value: + out['definition'].append(writer.definition(value[0])) + return out + + @classmethod + def serialize(cls, value, writer=None): + return [writer.serialize(v, writer) for v in value] + + @classmethod + def deserialize(cls, value, reader=None): + return [reader.deserialize(obj) for obj in value] + + +class SetTypeIO(GraphSONTypeIO): + """In GraphSON3, set has its own type""" + + graphson_base_type = 'Set' + cql_type = 'set' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + for v in value: + # we only take into account the first value for the definition + out['definition'].append(writer.definition(v)) + break + return out + + @classmethod + def serialize(cls, value, writer=None): + return [writer.serialize(v, writer) for v in value] + + @classmethod + def deserialize(cls, value, reader=None): + lst = [reader.deserialize(obj) for obj in value] + + s = set(lst) + if len(s) != len(lst): + log.warning("Coercing g:Set to list due to numerical values returned by Java. " + "See TINKERPOP-1844 for details.") + return lst + + return s + + +class BulkSetTypeIO(GraphSONTypeIO): + graphson_base_type = "BulkSet" + + @classmethod + def deserialize(cls, value, reader=None): + out = [] + + a, b = itertools.tee(value) + for val, bulk in zip( + itertools.islice(a, 0, None, 2), + itertools.islice(b, 1, None, 2) + ): + val = reader.deserialize(val) + bulk = reader.deserialize(bulk) + for n in range(bulk): + out.append(val) + + return out + + +class TupleTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Tuple' + cql_type = 'tuple' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict() + out['cqlType'] = cls.cql_type + serializers = [writer.get_serializer(s) for s in value] + out['definition'] = [s.definition(v, writer) for v, s in zip(value, serializers)] + return out + + @classmethod + def serialize(cls, value, writer=None): + out = cls.definition(value, writer) + out['value'] = [writer.serialize(v, writer) for v in value] + return out + + @classmethod + def deserialize(cls, value, reader=None): + return tuple(reader.deserialize(obj) for obj in value['value']) + + +class UserTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'UDT' + cql_type = 'udt' + + FROZEN_REMOVAL_REGEX = re.compile(r'frozen<"*([^"]+)"*>') + + @classmethod + def cql_types_from_string(cls, typ): + # sanitizing: remove frozen references and double quotes... + return cql_types_from_string( + re.sub(cls.FROZEN_REMOVAL_REGEX, r'\1', typ) + ) + + @classmethod + def get_udt_definition(cls, value, writer): + user_type_name = writer.user_types[type(value)] + keyspace = writer.context['graph_name'] + return writer.context['cluster'].metadata.keyspaces[keyspace].user_types[user_type_name] + + @classmethod + def is_collection(cls, typ): + return typ in ['list', 'tuple', 'map', 'set'] + + @classmethod + def is_udt(cls, typ, writer): + keyspace = writer.context['graph_name'] + if keyspace in writer.context['cluster'].metadata.keyspaces: + return typ in writer.context['cluster'].metadata.keyspaces[keyspace].user_types + return False + + @classmethod + def field_definition(cls, types, writer, name=None): + """ + Build the udt field definition. This is required when we have a complex udt type. + """ + index = -1 + out = [OrderedDict() if name is None else OrderedDict([('fieldName', name)])] + + while types: + index += 1 + typ = types.pop(0) + if index > 0: + out.append(OrderedDict()) + + if cls.is_udt(typ, writer): + keyspace = writer.context['graph_name'] + udt = writer.context['cluster'].metadata.keyspaces[keyspace].user_types[typ] + out[index].update(cls.definition(udt, writer)) + elif cls.is_collection(typ): + out[index]['cqlType'] = typ + definition = cls.field_definition(types, writer) + out[index]['definition'] = definition if isinstance(definition, list) else [definition] + else: + out[index]['cqlType'] = typ + + return out if len(out) > 1 else out[0] + + @classmethod + def definition(cls, value, writer=None): + udt = value if isinstance(value, UserType) else cls.get_udt_definition(value, writer) + return OrderedDict([ + ('cqlType', cls.cql_type), + ('keyspace', udt.keyspace), + ('name', udt.name), + ('definition', [ + cls.field_definition(cls.cql_types_from_string(typ), writer, name=name) + for name, typ in zip(udt.field_names, udt.field_types)]) + ]) + + @classmethod + def serialize(cls, value, writer=None): + udt = cls.get_udt_definition(value, writer) + out = cls.definition(value, writer) + out['value'] = [] + for name, typ in zip(udt.field_names, udt.field_types): + out['value'].append(writer.serialize(getattr(value, name), writer)) + return out + + @classmethod + def deserialize(cls, value, reader=None): + udt_class = reader.context['cluster']._user_types[value['keyspace']][value['name']] + kwargs = zip( + list(map(lambda v: v['fieldName'], value['definition'])), + [reader.deserialize(v) for v in value['value']] + ) + return udt_class(**dict(kwargs)) + + +class TTypeIO(GraphSONTypeIO): + prefix = 'g' + graphson_base_type = 'T' + + @classmethod + def deserialize(cls, value, reader=None): + return T.name_to_value[value] + + +class _BaseGraphSONSerializer(object): + + _serializers = OrderedDict() + + @classmethod + def register(cls, type, serializer): + cls._serializers[type] = serializer + + @classmethod + def get_type_definitions(cls): + return cls._serializers.copy() + + @classmethod + def get_serializer(cls, value): + """ + Get the serializer for a python object. + + :param value: The python object. + """ + + # The serializer matching logic is as follow: + # 1. Try to find the python type by direct access. + # 2. Try to find the first serializer by class inheritance. + # 3. If no serializer found, return the raw value. + + # Note that when trying to find the serializer by class inheritance, + # the order that serializers are registered is important. The use of + # an OrderedDict is to avoid the difference between executions. + serializer = None + try: + serializer = cls._serializers[type(value)] + except KeyError: + for key, serializer_ in cls._serializers.items(): + if isinstance(value, key): + serializer = serializer_ + break + + if serializer: + # A serializer can have specialized serializers (e.g for Int32 and Int64, so value dependant) + serializer = serializer.get_specialized_serializer(value) + + return serializer + + @classmethod + def serialize(cls, value, writer=None): + """ + Serialize a python object to GraphSON. + + e.g 'P42DT10H5M37S' + e.g. {'key': value} + + :param value: The python object to serialize. + :param writer: A graphson serializer for recursive types (Optional) + """ + serializer = cls.get_serializer(value) + if serializer: + return serializer.serialize(value, writer or cls) + + return value + + +class GraphSON1Serializer(_BaseGraphSONSerializer): + """ + Serialize python objects to graphson types. + """ + + # When we fall back to a superclass's serializer, we iterate over this map. + # We want that iteration order to be consistent, so we use an OrderedDict, + # not a dict. + _serializers = OrderedDict([ + (str, TextTypeIO), + (bool, BooleanTypeIO), + (bytearray, ByteBufferTypeIO), + (Decimal, BigDecimalTypeIO), + (datetime.date, LocalDateTypeIO), + (datetime.time, LocalTimeTypeIO), + (datetime.timedelta, DurationTypeIO), + (datetime.datetime, InstantTypeIO), + (uuid.UUID, UUIDTypeIO), + (Polygon, PolygonTypeIO), + (Point, PointTypeIO), + (LineString, LineStringTypeIO), + (dict, JsonMapTypeIO), + (float, FloatTypeIO) + ]) + + +GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO) +GraphSON1Serializer.register(ipaddress.IPv6Address, InetTypeIO) +GraphSON1Serializer.register(memoryview, ByteBufferTypeIO) +GraphSON1Serializer.register(bytes, ByteBufferTypeIO) + + +class _BaseGraphSONDeserializer(object): + + _deserializers = {} + + @classmethod + def get_type_definitions(cls): + return cls._deserializers.copy() + + @classmethod + def register(cls, graphson_type, serializer): + cls._deserializers[graphson_type] = serializer + + @classmethod + def get_deserializer(cls, graphson_type): + try: + return cls._deserializers[graphson_type] + except KeyError: + raise ValueError('Invalid `graphson_type` specified: {}'.format(graphson_type)) + + @classmethod + def deserialize(cls, graphson_type, value): + """ + Deserialize a `graphson_type` value to a python object. + + :param graphson_base_type: The graphson graphson_type. e.g. 'gx:Instant' + :param value: The graphson value to deserialize. + """ + return cls.get_deserializer(graphson_type).deserialize(value) + + +class GraphSON1Deserializer(_BaseGraphSONDeserializer): + """ + Deserialize graphson1 types to python objects. + """ + _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO, ByteBufferTypeIO, + PointTypeIO, LineStringTypeIO, PolygonTypeIO, LocalDateTypeIO, + LocalTimeTypeIO, DurationTypeIO, InetTypeIO] + + _deserializers = { + t.graphson_type: t + for t in _TYPES + } + + @classmethod + def deserialize_date(cls, value): + return cls._deserializers[LocalDateTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_time(cls, value): + return cls._deserializers[LocalTimeTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_timestamp(cls, value): + return cls._deserializers[InstantTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_duration(cls, value): + return cls._deserializers[DurationTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_int(cls, value): + return int(value) + + deserialize_smallint = deserialize_int + + deserialize_varint = deserialize_int + + @classmethod + def deserialize_bigint(cls, value): + return cls.deserialize_int(value) + + @classmethod + def deserialize_double(cls, value): + return float(value) + + deserialize_float = deserialize_double + + @classmethod + def deserialize_uuid(cls, value): + return cls._deserializers[UUIDTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_decimal(cls, value): + return cls._deserializers[BigDecimalTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_blob(cls, value): + return cls._deserializers[ByteBufferTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_point(cls, value): + return cls._deserializers[PointTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_linestring(cls, value): + return cls._deserializers[LineStringTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_polygon(cls, value): + return cls._deserializers[PolygonTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_inet(cls, value): + return value + + @classmethod + def deserialize_boolean(cls, value): + return value + + +# TODO Remove in the next major +GraphSON1TypeDeserializer = GraphSON1Deserializer +GraphSON1TypeSerializer = GraphSON1Serializer + + +class GraphSON2Serializer(_BaseGraphSONSerializer): + TYPE_KEY = "@type" + VALUE_KEY = "@value" + + _serializers = GraphSON1Serializer.get_type_definitions() + + def serialize(self, value, writer=None): + """ + Serialize a type to GraphSON2. + + e.g {'@type': 'gx:Duration', '@value': 'P2DT4H'} + + :param value: The python object to serialize. + """ + serializer = self.get_serializer(value) + if not serializer: + raise ValueError("Unable to find a serializer for value of type: ".format(type(value))) + + val = serializer.serialize(value, writer or self) + if serializer is TypeWrapperTypeIO: + graphson_base_type = value.type_io.graphson_base_type + graphson_type = value.type_io.graphson_type + else: + graphson_base_type = serializer.graphson_base_type + graphson_type = serializer.graphson_type + + if graphson_base_type is None: + out = val + else: + out = {self.TYPE_KEY: graphson_type} + if val is not None: + out[self.VALUE_KEY] = val + + return out + + +GraphSON2Serializer.register(int, IntegerTypeIO) + + +class GraphSON2Deserializer(_BaseGraphSONDeserializer): + + _TYPES = GraphSON1Deserializer._TYPES + [ + Int16TypeIO, Int32TypeIO, Int64TypeIO, DoubleTypeIO, FloatTypeIO, + BigIntegerTypeIO, VertexTypeIO, VertexPropertyTypeIO, EdgeTypeIO, + PathTypeIO, PropertyTypeIO, TraversalMetricsTypeIO, MetricsTypeIO] + + _deserializers = { + t.graphson_type: t + for t in _TYPES + } + + +class GraphSON2Reader(object): + """ + GraphSON2 Reader that parse json and deserialize to python objects. + """ + + def __init__(self, context, extra_deserializer_map=None): + """ + :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize` + """ + self.context = context + self.deserializers = GraphSON2Deserializer.get_type_definitions() + if extra_deserializer_map: + self.deserializers.update(extra_deserializer_map) + + def read(self, json_data): + """ + Read and deserialize ``json_data``. + """ + return self.deserialize(json.loads(json_data)) + + def deserialize(self, obj): + """ + Deserialize GraphSON type-tagged dict values into objects mapped in self.deserializers + """ + if isinstance(obj, dict): + try: + des = self.deserializers[obj[GraphSON2Serializer.TYPE_KEY]] + return des.deserialize(obj[GraphSON2Serializer.VALUE_KEY], self) + except KeyError: + pass + # list and map are treated as normal json objs (could be isolated deserializers) + return {self.deserialize(k): self.deserialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self.deserialize(o) for o in obj] + else: + return obj + + +class TypeIOWrapper(object): + """Used to force a graphson type during serialization""" + + type_io = None + value = None + + def __init__(self, type_io, value): + self.type_io = type_io + self.value = value + + +def _wrap_value(type_io, value): + return TypeIOWrapper(type_io, value) + + +to_bigint = partial(_wrap_value, Int64TypeIO) +to_int = partial(_wrap_value, Int32TypeIO) +to_smallint = partial(_wrap_value, Int16TypeIO) +to_double = partial(_wrap_value, DoubleTypeIO) +to_float = partial(_wrap_value, FloatTypeIO) + + +class GraphSON3Serializer(GraphSON2Serializer): + + _serializers = GraphSON2Serializer.get_type_definitions() + + context = None + """A dict of the serialization context""" + + def __init__(self, context): + self.context = context + self.user_types = None + + def definition(self, value): + serializer = self.get_serializer(value) + return serializer.definition(value, self) + + def get_serializer(self, value): + """Custom get_serializer to support UDT/Tuple""" + + serializer = super(GraphSON3Serializer, self).get_serializer(value) + is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, '_fields') + if not serializer or is_namedtuple_udt: + # Check if UDT + if self.user_types is None: + try: + user_types = self.context['cluster']._user_types[self.context['graph_name']] + self.user_types = dict(map(reversed, user_types.items())) + except KeyError: + self.user_types = {} + + serializer = UserTypeIO if (is_namedtuple_udt or (type(value) in self.user_types)) else serializer + + return serializer + + +GraphSON3Serializer.register(dict, MapTypeIO) +GraphSON3Serializer.register(list, ListTypeIO) +GraphSON3Serializer.register(set, SetTypeIO) +GraphSON3Serializer.register(tuple, TupleTypeIO) +GraphSON3Serializer.register(Duration, DseDurationTypeIO) +GraphSON3Serializer.register(TypeIOWrapper, TypeWrapperTypeIO) + + +class GraphSON3Deserializer(GraphSON2Deserializer): + _TYPES = GraphSON2Deserializer._TYPES + [MapTypeIO, ListTypeIO, + SetTypeIO, TupleTypeIO, + UserTypeIO, DseDurationTypeIO, + TTypeIO, BulkSetTypeIO] + + _deserializers = {t.graphson_type: t for t in _TYPES} + + +class GraphSON3Reader(GraphSON2Reader): + """ + GraphSON3 Reader that parse json and deserialize to python objects. + """ + + def __init__(self, context, extra_deserializer_map=None): + """ + :param context: A dict of the context, mostly used as context for udt deserialization. + :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize` + """ + self.context = context + self.deserializers = GraphSON3Deserializer.get_type_definitions() + if extra_deserializer_map: + self.deserializers.update(extra_deserializer_map) diff --git a/cassandra/datastax/graph/query.py b/cassandra/datastax/graph/query.py new file mode 100644 index 0000000000..d5f2a594b3 --- /dev/null +++ b/cassandra/datastax/graph/query.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from warnings import warn + +from cassandra import ConsistencyLevel +from cassandra.query import Statement, SimpleStatement +from cassandra.datastax.graph.types import Vertex, Edge, Path, VertexProperty +from cassandra.datastax.graph.graphson import GraphSON2Reader, GraphSON3Reader + + +__all__ = [ + 'GraphProtocol', 'GraphOptions', 'GraphStatement', 'SimpleGraphStatement', + 'single_object_row_factory', 'graph_result_row_factory', 'graph_object_row_factory', + 'graph_graphson2_row_factory', 'Result', 'graph_graphson3_row_factory' +] + +# (attr, description, server option) +_graph_options = ( + ('graph_name', 'name of the targeted graph.', 'graph-name'), + ('graph_source', 'choose the graph traversal source, configured on the server side.', 'graph-source'), + ('graph_language', 'the language used in the queries (default "gremlin-groovy")', 'graph-language'), + ('graph_protocol', 'the graph protocol that the server should use for query results (default "graphson-1-0")', 'graph-results'), + ('graph_read_consistency_level', '''read `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). +Setting this overrides the native `Statement.consistency_level `_ for read operations from Cassandra persistence''', 'graph-read-consistency'), + ('graph_write_consistency_level', '''write `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). +Setting this overrides the native `Statement.consistency_level `_ for write operations to Cassandra persistence.''', 'graph-write-consistency') +) +_graph_option_names = tuple(option[0] for option in _graph_options) + +# this is defined by the execution profile attribute, not in graph options +_request_timeout_key = 'request-timeout' + + +class GraphProtocol(object): + + GRAPHSON_1_0 = b'graphson-1.0' + """ + GraphSON1 + """ + + GRAPHSON_2_0 = b'graphson-2.0' + """ + GraphSON2 + """ + + GRAPHSON_3_0 = b'graphson-3.0' + """ + GraphSON3 + """ + + +class GraphOptions(object): + """ + Options for DSE Graph Query handler. + """ + # See _graph_options map above for notes on valid options + + DEFAULT_GRAPH_PROTOCOL = GraphProtocol.GRAPHSON_1_0 + DEFAULT_GRAPH_LANGUAGE = b'gremlin-groovy' + + def __init__(self, **kwargs): + self._graph_options = {} + kwargs.setdefault('graph_source', 'g') + kwargs.setdefault('graph_language', GraphOptions.DEFAULT_GRAPH_LANGUAGE) + for attr, value in kwargs.items(): + if attr not in _graph_option_names: + warn("Unknown keyword argument received for GraphOptions: {0}".format(attr)) + setattr(self, attr, value) + + def copy(self): + new_options = GraphOptions() + new_options._graph_options = self._graph_options.copy() + return new_options + + def update(self, options): + self._graph_options.update(options._graph_options) + + def get_options_map(self, other_options=None): + """ + Returns a map for these options updated with other options, + and mapped to graph payload types. + """ + options = self._graph_options.copy() + if other_options: + options.update(other_options._graph_options) + + # cls are special-cased so they can be enums in the API, and names in the protocol + for cl in ('graph-write-consistency', 'graph-read-consistency'): + cl_enum = options.get(cl) + if cl_enum is not None: + options[cl] = ConsistencyLevel.value_to_name[cl_enum].encode() + return options + + def set_source_default(self): + """ + Sets ``graph_source`` to the server-defined default traversal source ('default') + """ + self.graph_source = 'default' + + def set_source_analytics(self): + """ + Sets ``graph_source`` to the server-defined analytic traversal source ('a') + """ + self.graph_source = 'a' + + def set_source_graph(self): + """ + Sets ``graph_source`` to the server-defined graph traversal source ('g') + """ + self.graph_source = 'g' + + def set_graph_protocol(self, protocol): + """ + Sets ``graph_protocol`` as server graph results format (See :class:`cassandra.datastax.graph.GraphProtocol`) + """ + self.graph_protocol = protocol + + @property + def is_default_source(self): + return self.graph_source in (b'default', None) + + @property + def is_analytics_source(self): + """ + True if ``graph_source`` is set to the server-defined analytics traversal source ('a') + """ + return self.graph_source == b'a' + + @property + def is_graph_source(self): + """ + True if ``graph_source`` is set to the server-defined graph traversal source ('g') + """ + return self.graph_source == b'g' + + +for opt in _graph_options: + + def get(self, key=opt[2]): + return self._graph_options.get(key) + + def set(self, value, key=opt[2]): + if value is not None: + # normalize text here so it doesn't have to be done every time we get options map + if isinstance(value, str): + value = value.encode() + self._graph_options[key] = value + else: + self._graph_options.pop(key, None) + + def delete(self, key=opt[2]): + self._graph_options.pop(key, None) + + setattr(GraphOptions, opt[0], property(get, set, delete, opt[1])) + + +class GraphStatement(Statement): + """ An abstract class representing a graph query.""" + + @property + def query(self): + raise NotImplementedError() + + def __str__(self): + return u''.format(self.query) + __repr__ = __str__ + + +class SimpleGraphStatement(GraphStatement, SimpleStatement): + """ + Simple graph statement for :meth:`.Session.execute_graph`. + Takes the same parameters as :class:`.SimpleStatement`. + """ + @property + def query(self): + return self._query_string + + +def single_object_row_factory(column_names, rows): + """ + returns the JSON string value of graph results + """ + return [row[0] for row in rows] + + +def graph_result_row_factory(column_names, rows): + """ + Returns a :class:`Result ` object that can load graph results and produce specific types. + The Result JSON is deserialized and unpacked from the top-level 'result' dict. + """ + return [Result(json.loads(row[0])['result']) for row in rows] + + +def graph_object_row_factory(column_names, rows): + """ + Like :func:`~.graph_result_row_factory`, except known element types (:class:`~.Vertex`, :class:`~.Edge`) are + converted to their simplified objects. Some low-level metadata is shed in this conversion. Unknown result types are + still returned as :class:`Result `. + """ + return _graph_object_sequence(json.loads(row[0])['result'] for row in rows) + + +def _graph_object_sequence(objects): + for o in objects: + res = Result(o) + if isinstance(o, dict): + typ = res.value.get('type') + if typ == 'vertex': + res = res.as_vertex() + elif typ == 'edge': + res = res.as_edge() + yield res + + +class _GraphSONContextRowFactory(object): + graphson_reader_class = None + graphson_reader_kwargs = None + + def __init__(self, cluster): + context = {'cluster': cluster} + kwargs = self.graphson_reader_kwargs or {} + self.graphson_reader = self.graphson_reader_class(context, **kwargs) + + def __call__(self, column_names, rows): + return [self.graphson_reader.read(row[0])['result'] for row in rows] + + +class _GraphSON2RowFactory(_GraphSONContextRowFactory): + """Row factory to deserialize GraphSON2 results.""" + graphson_reader_class = GraphSON2Reader + + +class _GraphSON3RowFactory(_GraphSONContextRowFactory): + """Row factory to deserialize GraphSON3 results.""" + graphson_reader_class = GraphSON3Reader + + +graph_graphson2_row_factory = _GraphSON2RowFactory +graph_graphson3_row_factory = _GraphSON3RowFactory + + +class Result(object): + """ + Represents deserialized graph results. + Property and item getters are provided for convenience. + """ + + value = None + """ + Deserialized value from the result + """ + + def __init__(self, value): + self.value = value + + def __getattr__(self, attr): + if not isinstance(self.value, dict): + raise ValueError("Value cannot be accessed as a dict") + + if attr in self.value: + return self.value[attr] + + raise AttributeError("Result has no top-level attribute %r" % (attr,)) + + def __getitem__(self, item): + if isinstance(self.value, dict) and isinstance(item, str): + return self.value[item] + elif isinstance(self.value, list) and isinstance(item, int): + return self.value[item] + else: + raise ValueError("Result cannot be indexed by %r" % (item,)) + + def __str__(self): + return str(self.value) + + def __repr__(self): + return "%s(%r)" % (Result.__name__, self.value) + + def __eq__(self, other): + return self.value == other.value + + def as_vertex(self): + """ + Return a :class:`Vertex` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). + """ + try: + return Vertex(self.id, self.label, self.type, self.value.get('properties', {})) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Vertex from %r" % (self,)) + + def as_edge(self): + """ + Return a :class:`Edge` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). + """ + try: + return Edge(self.id, self.label, self.type, self.value.get('properties', {}), + self.inV, self.inVLabel, self.outV, self.outVLabel) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Edge from %r" % (self,)) + + def as_path(self): + """ + Return a :class:`Path` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). + """ + try: + return Path(self.labels, self.objects) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Path from %r" % (self,)) + + def as_vertex_property(self): + return VertexProperty(self.value.get('label'), self.value.get('value'), self.value.get('properties', {})) diff --git a/cassandra/datastax/graph/types.py b/cassandra/datastax/graph/types.py new file mode 100644 index 0000000000..75902c6622 --- /dev/null +++ b/cassandra/datastax/graph/types.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['Element', 'Vertex', 'Edge', 'VertexProperty', 'Path', 'T'] + + +class Element(object): + + element_type = None + + _attrs = ('id', 'label', 'type', 'properties') + + def __init__(self, id, label, type, properties): + if type != self.element_type: + raise TypeError("Attempted to create %s from %s element", (type, self.element_type)) + + self.id = id + self.label = label + self.type = type + self.properties = self._extract_properties(properties) + + @staticmethod + def _extract_properties(properties): + return dict(properties) + + def __eq__(self, other): + return all(getattr(self, attr) == getattr(other, attr) for attr in self._attrs) + + def __str__(self): + return str(dict((k, getattr(self, k)) for k in self._attrs)) + + +class Vertex(Element): + """ + Represents a Vertex element from a graph query. + + Vertex ``properties`` are extracted into a ``dict`` of property names to list of :class:`~VertexProperty` (list + because they are always encoded that way, and sometimes have multiple cardinality; VertexProperty because sometimes + the properties themselves have property maps). + """ + + element_type = 'vertex' + + @staticmethod + def _extract_properties(properties): + # vertex properties are always encoded as a list, regardless of Cardinality + return dict((k, [VertexProperty(k, p['value'], p.get('properties')) for p in v]) for k, v in properties.items()) + + def __repr__(self): + properties = dict((name, [{'label': prop.label, 'value': prop.value, 'properties': prop.properties} for prop in prop_list]) + for name, prop_list in self.properties.items()) + return "%s(%r, %r, %r, %r)" % (self.__class__.__name__, + self.id, self.label, + self.type, properties) + + +class VertexProperty(object): + """ + Vertex properties have a top-level value and an optional ``dict`` of properties. + """ + + label = None + """ + label of the property + """ + + value = None + """ + Value of the property + """ + + properties = None + """ + dict of properties attached to the property + """ + + def __init__(self, label, value, properties=None): + self.label = label + self.value = value + self.properties = properties or {} + + def __eq__(self, other): + return isinstance(other, VertexProperty) and self.label == other.label and self.value == other.value and self.properties == other.properties + + def __repr__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.label, self.value, self.properties) + + +class Edge(Element): + """ + Represents an Edge element from a graph query. + + Attributes match initializer parameters. + """ + + element_type = 'edge' + + _attrs = Element._attrs + ('inV', 'inVLabel', 'outV', 'outVLabel') + + def __init__(self, id, label, type, properties, + inV, inVLabel, outV, outVLabel): + super(Edge, self).__init__(id, label, type, properties) + self.inV = inV + self.inVLabel = inVLabel + self.outV = outV + self.outVLabel = outVLabel + + def __repr__(self): + return "%s(%r, %r, %r, %r, %r, %r, %r, %r)" %\ + (self.__class__.__name__, + self.id, self.label, + self.type, self.properties, + self.inV, self.inVLabel, + self.outV, self.outVLabel) + + +class Path(object): + """ + Represents a graph path. + + Labels list is taken verbatim from the results. + + Objects are either :class:`~.Result` or :class:`~.Vertex`/:class:`~.Edge` for recognized types + """ + + labels = None + """ + List of labels in the path + """ + + objects = None + """ + List of objects in the path + """ + + def __init__(self, labels, objects): + # TODO fix next major + # The Path class should not do any deserialization by itself. To fix in the next major. + from cassandra.datastax.graph.query import _graph_object_sequence + self.labels = labels + self.objects = list(_graph_object_sequence(objects)) + + def __eq__(self, other): + return self.labels == other.labels and self.objects == other.objects + + def __str__(self): + return str({'labels': self.labels, 'objects': self.objects}) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.labels, [o.value for o in self.objects]) + + +class T(object): + """ + Represents a collection of tokens for more concise Traversal definitions. + """ + + name = None + val = None + + # class attributes + id = None + """ + """ + + key = None + """ + """ + label = None + """ + """ + value = None + """ + """ + + def __init__(self, name, val): + self.name = name + self.val = val + + def __str__(self): + return self.name + + def __repr__(self): + return "T.%s" % (self.name, ) + + +T.id = T("id", 1) +T.id_ = T("id_", 2) +T.key = T("key", 3) +T.label = T("label", 4) +T.value = T("value", 5) + +T.name_to_value = { + 'id': T.id, + 'id_': T.id_, + 'key': T.key, + 'label': T.label, + 'value': T.value +} diff --git a/cassandra/datastax/insights/__init__.py b/cassandra/datastax/insights/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/cassandra/datastax/insights/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cassandra/datastax/insights/registry.py b/cassandra/datastax/insights/registry.py new file mode 100644 index 0000000000..523af4dc84 --- /dev/null +++ b/cassandra/datastax/insights/registry.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from warnings import warn + +from cassandra.datastax.insights.util import namespace + +_NOT_SET = object() + + +def _default_serializer_for_object(obj, policy): + # the insights server expects an 'options' dict for policy + # objects, but not for other objects + if policy: + return {'type': obj.__class__.__name__, + 'namespace': namespace(obj.__class__), + 'options': {}} + else: + return {'type': obj.__class__.__name__, + 'namespace': namespace(obj.__class__)} + + +class InsightsSerializerRegistry(object): + + initialized = False + + def __init__(self, mapping_dict=None): + mapping_dict = mapping_dict or {} + class_order = self._class_topological_sort(mapping_dict) + self._mapping_dict = OrderedDict( + ((cls, mapping_dict[cls]) for cls in class_order) + ) + + def serialize(self, obj, policy=False, default=_NOT_SET, cls=None): + try: + return self._get_serializer(cls if cls is not None else obj.__class__)(obj) + except Exception: + if default is _NOT_SET: + result = _default_serializer_for_object(obj, policy) + else: + result = default + + return result + + def _get_serializer(self, cls): + try: + return self._mapping_dict[cls] + except KeyError: + for registered_cls, serializer in self._mapping_dict.items(): + if issubclass(cls, registered_cls): + return self._mapping_dict[registered_cls] + raise ValueError + + def register(self, cls, serializer): + self._mapping_dict[cls] = serializer + self._mapping_dict = OrderedDict( + ((cls, self._mapping_dict[cls]) + for cls in self._class_topological_sort(self._mapping_dict)) + ) + + def register_serializer_for(self, cls): + """ + Parameterized registration helper decorator. Given a class `cls`, + produces a function that registers the decorated function as a + serializer for it. + """ + def decorator(serializer): + self.register(cls, serializer) + return serializer + + return decorator + + @staticmethod + def _class_topological_sort(classes): + """ + A simple topological sort for classes. Takes an iterable of class objects + and returns a list A of those classes, ordered such that A[X] is never a + superclass of A[Y] for X < Y. + + This is an inefficient sort, but that's ok because classes are infrequently + registered. It's more important that this be maintainable than fast. + + We can't use `.sort()` or `sorted()` with a custom `key` -- those assume + a total ordering, which we don't have. + """ + unsorted, sorted_ = list(classes), [] + while unsorted: + head, tail = unsorted[0], unsorted[1:] + + # if head has no subclasses remaining, it can safely go in the list + if not any(issubclass(x, head) for x in tail): + sorted_.append(head) + else: + # move to the back -- head has to wait until all its subclasses + # are sorted into the list + tail.append(head) + + unsorted = tail + + # check that sort is valid + for i, head in enumerate(sorted_): + for after_head_value in sorted_[(i + 1):]: + if issubclass(after_head_value, head): + warn('Sorting classes produced an invalid ordering.\n' + 'In: {classes}\n' + 'Out: {sorted_}'.format(classes=classes, sorted_=sorted_)) + return sorted_ + + +insights_registry = InsightsSerializerRegistry() diff --git a/cassandra/datastax/insights/reporter.py b/cassandra/datastax/insights/reporter.py new file mode 100644 index 0000000000..607c723a1a --- /dev/null +++ b/cassandra/datastax/insights/reporter.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter +import datetime +import json +import logging +import multiprocessing +import random +import platform +import socket +import ssl +import sys +from threading import Event, Thread +import time + +from cassandra.policies import HostDistance +from cassandra.util import ms_timestamp_from_datetime +from cassandra.datastax.insights.registry import insights_registry +from cassandra.datastax.insights.serializers import initialize_registry + +log = logging.getLogger(__name__) + + +class MonitorReporter(Thread): + + def __init__(self, interval_sec, session): + """ + takes an int indicating interval between requests, a function returning + the connection to be used, and the timeout per request + """ + # Thread is an old-style class so we can't super() + Thread.__init__(self, name='monitor_reporter') + + initialize_registry(insights_registry) + + self._interval, self._session = interval_sec, session + + self._shutdown_event = Event() + self.daemon = True + self.start() + + def run(self): + self._send_via_rpc(self._get_startup_data()) + + # introduce some jitter -- send up to 1/10 of _interval early + self._shutdown_event.wait(self._interval * random.uniform(.9, 1)) + + while not self._shutdown_event.is_set(): + start_time = time.time() + + self._send_via_rpc(self._get_status_data()) + + elapsed = time.time() - start_time + self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) + + # TODO: redundant with ConnectionHeartbeat.ShutdownException + class ShutDownException(Exception): + pass + + def _send_via_rpc(self, data): + try: + self._session.execute( + "CALL InsightsRpc.reportInsight(%s)", (json.dumps(data),) + ) + log.debug('Insights RPC data: {}'.format(data)) + except Exception as e: + log.debug('Insights RPC send failed with {}'.format(e)) + log.debug('Insights RPC data: {}'.format(data)) + + def _get_status_data(self): + cc = self._session.cluster.control_connection + + connected_nodes = { + host.address: { + 'connections': state['open_count'], + 'inFlightQueries': state['in_flights'] + } + for (host, state) in self._session.get_pool_state().items() + } + + return { + 'metadata': { + # shared across drivers; never change + 'name': 'driver.status', + # format version + 'insightMappingId': 'v1', + 'insightType': 'EVENT', + # since epoch + 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()), + 'tags': { + 'language': 'python' + } + }, + # // 'clientId', 'sessionId' and 'controlConnection' are mandatory + # // the rest of the properties are optional + 'data': { + # // 'clientId' must be the same as the one provided in the startup message + 'clientId': str(self._session.cluster.client_id), + # // 'sessionId' must be the same as the one provided in the startup message + 'sessionId': str(self._session.session_id), + 'controlConnection': cc._connection.host if cc._connection else None, + 'connectedNodes': connected_nodes + } + } + + def _get_startup_data(self): + cc = self._session.cluster.control_connection + try: + local_ipaddr = cc._connection._socket.getsockname()[0] + except Exception as e: + local_ipaddr = None + log.debug('Unable to get local socket addr from {}: {}'.format(cc._connection, e)) + hostname = socket.getfqdn() + + host_distances_counter = Counter( + self._session.cluster.profile_manager.distance(host) + for host in self._session.hosts + ) + host_distances_dict = { + 'local': host_distances_counter[HostDistance.LOCAL], + 'remote': host_distances_counter[HostDistance.REMOTE], + 'ignored': host_distances_counter[HostDistance.IGNORED] + } + + try: + compression_type = cc._connection._compression_type + except AttributeError: + compression_type = 'NONE' + + cert_validation = None + try: + if self._session.cluster.ssl_context: + if isinstance(self._session.cluster.ssl_context, ssl.SSLContext): + cert_validation = self._session.cluster.ssl_context.verify_mode == ssl.CERT_REQUIRED + else: # pyopenssl + from OpenSSL import SSL + cert_validation = self._session.cluster.ssl_context.get_verify_mode() != SSL.VERIFY_NONE + elif self._session.cluster.ssl_options: + cert_validation = self._session.cluster.ssl_options.get('cert_reqs') == ssl.CERT_REQUIRED + except Exception as e: + log.debug('Unable to get the cert validation: {}'.format(e)) + + uname_info = platform.uname() + + return { + 'metadata': { + 'name': 'driver.startup', + 'insightMappingId': 'v1', + 'insightType': 'EVENT', + 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()), + 'tags': { + 'language': 'python' + }, + }, + 'data': { + 'driverName': 'DataStax Python Driver', + 'driverVersion': sys.modules['cassandra'].__version__, + 'clientId': str(self._session.cluster.client_id), + 'sessionId': str(self._session.session_id), + 'applicationName': self._session.cluster.application_name or 'python', + 'applicationNameWasGenerated': not self._session.cluster.application_name, + 'applicationVersion': self._session.cluster.application_version, + 'contactPoints': self._session.cluster._endpoint_map_for_insights, + 'dataCenters': list(set(h.datacenter for h in self._session.cluster.metadata.all_hosts() + if (h.datacenter and + self._session.cluster.profile_manager.distance(h) == HostDistance.LOCAL))), + 'initialControlConnection': cc._connection.host if cc._connection else None, + 'protocolVersion': self._session.cluster.protocol_version, + 'localAddress': local_ipaddr, + 'hostName': hostname, + 'executionProfiles': insights_registry.serialize(self._session.cluster.profile_manager), + 'configuredConnectionLength': host_distances_dict, + 'heartbeatInterval': self._session.cluster.idle_heartbeat_interval, + 'compression': compression_type.upper() if compression_type else 'NONE', + 'reconnectionPolicy': insights_registry.serialize(self._session.cluster.reconnection_policy), + 'sslConfigured': { + 'enabled': bool(self._session.cluster.ssl_options or self._session.cluster.ssl_context), + 'certValidation': cert_validation + }, + 'authProvider': { + 'type': (self._session.cluster.auth_provider.__class__.__name__ + if self._session.cluster.auth_provider else + None) + }, + 'otherOptions': { + }, + 'platformInfo': { + 'os': { + 'name': uname_info.system, + 'version': uname_info.release, + 'arch': uname_info.machine + }, + 'cpus': { + 'length': multiprocessing.cpu_count(), + 'model': platform.processor() + }, + 'runtime': { + 'python': sys.version, + 'event_loop': self._session.cluster.connection_class.__name__ + } + }, + 'periodicStatusInterval': self._interval + } + } + + def stop(self): + log.debug("Shutting down Monitor Reporter") + self._shutdown_event.set() + self.join() diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py new file mode 100644 index 0000000000..b1fe0ac5e9 --- /dev/null +++ b/cassandra/datastax/insights/serializers.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def initialize_registry(insights_registry): + # This will be called from the cluster module, so we put all this behavior + # in a function to avoid circular imports + + if insights_registry.initialized: + return False + + from cassandra import ConsistencyLevel + from cassandra.cluster import ( + ExecutionProfile, GraphExecutionProfile, + ProfileManager, ContinuousPagingOptions, + EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + _NOT_SET + ) + from cassandra.datastax.graph import GraphOptions + from cassandra.datastax.insights.registry import insights_registry + from cassandra.datastax.insights.util import namespace + from cassandra.policies import ( + RoundRobinPolicy, + DCAwareRoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, + HostFilterPolicy, + ConstantReconnectionPolicy, + ExponentialReconnectionPolicy, + RetryPolicy, + SpeculativeExecutionPolicy, + ConstantSpeculativeExecutionPolicy, + WrapperPolicy + ) + + import logging + + log = logging.getLogger(__name__) + + @insights_registry.register_serializer_for(RoundRobinPolicy) + def round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(DCAwareRoundRobinPolicy) + def dc_aware_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'local_dc': policy.local_dc, + 'used_hosts_per_remote_dc': policy.used_hosts_per_remote_dc} + } + + @insights_registry.register_serializer_for(TokenAwarePolicy) + def token_aware_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'child_policy': insights_registry.serialize(policy._child_policy, + policy=True), + 'shuffle_replicas': policy.shuffle_replicas} + } + + @insights_registry.register_serializer_for(WhiteListRoundRobinPolicy) + def whitelist_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'allowed_hosts': policy._allowed_hosts} + } + + @insights_registry.register_serializer_for(HostFilterPolicy) + def host_filter_policy_insights_serializer(policy): + return { + 'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'child_policy': insights_registry.serialize(policy._child_policy, + policy=True), + 'predicate': policy.predicate.__name__} + } + + @insights_registry.register_serializer_for(ConstantReconnectionPolicy) + def constant_reconnection_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'delay': policy.delay, + 'max_attempts': policy.max_attempts} + } + + @insights_registry.register_serializer_for(ExponentialReconnectionPolicy) + def exponential_reconnection_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'base_delay': policy.base_delay, + 'max_delay': policy.max_delay, + 'max_attempts': policy.max_attempts} + } + + @insights_registry.register_serializer_for(RetryPolicy) + def retry_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(SpeculativeExecutionPolicy) + def speculative_execution_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(ConstantSpeculativeExecutionPolicy) + def constant_speculative_execution_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'delay': policy.delay, + 'max_attempts': policy.max_attempts} + } + + @insights_registry.register_serializer_for(WrapperPolicy) + def wrapper_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': { + 'child_policy': insights_registry.serialize(policy._child_policy, + policy=True) + }} + + @insights_registry.register_serializer_for(ExecutionProfile) + def execution_profile_insights_serializer(profile): + return { + 'loadBalancing': insights_registry.serialize(profile.load_balancing_policy, + policy=True), + 'retry': insights_registry.serialize(profile.retry_policy, + policy=True), + 'readTimeout': profile.request_timeout, + 'consistency': ConsistencyLevel.value_to_name.get(profile.consistency_level, None), + 'serialConsistency': ConsistencyLevel.value_to_name.get(profile.serial_consistency_level, None), + 'continuousPagingOptions': (insights_registry.serialize(profile.continuous_paging_options) + if (profile.continuous_paging_options is not None and + profile.continuous_paging_options is not _NOT_SET) else + None), + 'speculativeExecution': insights_registry.serialize(profile.speculative_execution_policy), + 'graphOptions': None + } + + @insights_registry.register_serializer_for(GraphExecutionProfile) + def graph_execution_profile_insights_serializer(profile): + rv = insights_registry.serialize(profile, cls=ExecutionProfile) + rv['graphOptions'] = insights_registry.serialize(profile.graph_options) + return rv + + _EXEC_PROFILE_DEFAULT_KEYS = (EXEC_PROFILE_DEFAULT, + EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT) + + @insights_registry.register_serializer_for(ProfileManager) + def profile_manager_insights_serializer(manager): + defaults = { + # Insights's expected default + 'default': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]), + # remaining named defaults for driver's defaults, including duplicated default + 'EXEC_PROFILE_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]), + 'EXEC_PROFILE_GRAPH_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT]), + 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT': insights_registry.serialize( + manager.profiles[EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT] + ), + 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT': insights_registry.serialize( + manager.profiles[EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT] + ) + } + other = { + key: insights_registry.serialize(value) + for key, value in manager.profiles.items() + if key not in _EXEC_PROFILE_DEFAULT_KEYS + } + overlapping_keys = set(defaults) & set(other) + if overlapping_keys: + log.debug('The following key names overlap default key sentinel keys ' + 'and these non-default EPs will not be displayed in Insights ' + ': {}'.format(list(overlapping_keys))) + + other.update(defaults) + return other + + @insights_registry.register_serializer_for(GraphOptions) + def graph_options_insights_serializer(options): + rv = { + 'source': options.graph_source, + 'language': options.graph_language, + 'graphProtocol': options.graph_protocol + } + updates = {k: v.decode('utf-8') for k, v in rv.items() + if isinstance(v, bytes)} + rv.update(updates) + return rv + + @insights_registry.register_serializer_for(ContinuousPagingOptions) + def continuous_paging_options_insights_serializer(paging_options): + return { + 'page_unit': paging_options.page_unit, + 'max_pages': paging_options.max_pages, + 'max_pages_per_second': paging_options.max_pages_per_second, + 'max_queue_size': paging_options.max_queue_size + } + + insights_registry.initialized = True + return True diff --git a/cassandra/datastax/insights/util.py b/cassandra/datastax/insights/util.py new file mode 100644 index 0000000000..0ce96c7edf --- /dev/null +++ b/cassandra/datastax/insights/util.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import traceback +from warnings import warn + +from cassandra.util import Version + + +DSE_60 = Version('6.0.0') +DSE_51_MIN_SUPPORTED = Version('5.1.13') +DSE_60_MIN_SUPPORTED = Version('6.0.5') + + +log = logging.getLogger(__name__) + + +def namespace(cls): + """ + Best-effort method for getting the namespace in which a class is defined. + """ + try: + # __module__ can be None + module = cls.__module__ or '' + except Exception: + warn("Unable to obtain namespace for {cls} for Insights, returning ''. " + "Exception: \n{e}".format(e=traceback.format_exc(), cls=cls)) + module = '' + + module_internal_namespace = _module_internal_namespace_or_emtpy_string(cls) + if module_internal_namespace: + return '.'.join((module, module_internal_namespace)) + return module + + +def _module_internal_namespace_or_emtpy_string(cls): + """ + Best-effort method for getting the module-internal namespace in which a + class is defined -- i.e. the namespace _inside_ the module. + """ + try: + qualname = cls.__qualname__ + except AttributeError: + return '' + + return '.'.join( + # the last segment is the name of the class -- use everything else + qualname.split('.')[:-1] + ) + + +def version_supports_insights(dse_version): + if dse_version: + try: + dse_version = Version(dse_version) + return (DSE_51_MIN_SUPPORTED <= dse_version < DSE_60 + or + DSE_60_MIN_SUPPORTED <= dse_version) + except Exception: + warn("Unable to check version {v} for Insights compatibility, returning False. " + "Exception: \n{e}".format(e=traceback.format_exc(), v=dse_version)) + + return False diff --git a/cassandra/decoder.py b/cassandra/decoder.py deleted file mode 100644 index eb807ab09a..0000000000 --- a/cassandra/decoder.py +++ /dev/null @@ -1,832 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple -try: - from collections import OrderedDict -except ImportError: # Python <2.7 - from cassandra.util import OrderedDict # NOQA - -import datetime -import logging -import socket -import types -from uuid import UUID -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO # ignore flake8 warning: # NOQA - -from cassandra import (ConsistencyLevel, Unavailable, WriteTimeout, ReadTimeout, - AlreadyExists, InvalidRequest, Unauthorized) -from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, - int8_pack, int8_unpack) -from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, - CounterColumnType, DateType, DecimalType, - DoubleType, FloatType, Int32Type, - InetAddressType, IntegerType, ListType, - LongType, MapType, SetType, TimeUUIDType, - UTF8Type, UUIDType) - -log = logging.getLogger(__name__) - -class NotSupportedError(Exception): - pass - - -class InternalError(Exception): - pass - - -PROTOCOL_VERSION = 0x01 -PROTOCOL_VERSION_MASK = 0x7f - -HEADER_DIRECTION_FROM_CLIENT = 0x00 -HEADER_DIRECTION_TO_CLIENT = 0x80 -HEADER_DIRECTION_MASK = 0x80 - - -def tuple_factory(colnames, rows): - return rows - - -def named_tuple_factory(colnames, rows): - Row = namedtuple('Row', colnames) - return [Row(*row) for row in rows] - - -def dict_factory(colnames, rows): - return [dict(zip(colnames, row)) for row in rows] - - -def ordered_dict_factory(colnames, rows): - return [OrderedDict(zip(colnames, row)) for row in rows] - - -_message_types_by_name = {} -_message_types_by_opcode = {} - -class _register_msg_type(type): - def __init__(cls, name, bases, dct): - if not name.startswith('_'): - _message_types_by_name[cls.name] = cls - _message_types_by_opcode[cls.opcode] = cls - - -class _MessageType(object): - __metaclass__ = _register_msg_type - params = () - - tracing = False - - def __init__(self, **kwargs): - for pname in self.params: - try: - pval = kwargs[pname] - except KeyError: - raise ValueError("%s instances need the %s keyword parameter" - % (self.__class__.__name__, pname)) - setattr(self, pname, pval) - - def to_string(self, stream_id, compression=None): - body = StringIO() - self.send_body(body) - body = body.getvalue() - version = PROTOCOL_VERSION | HEADER_DIRECTION_FROM_CLIENT - flags = 0 - if compression is not None and len(body) > 0: - body = compression(body) - flags |= 0x01 - if self.tracing: - flags |= 0x02 - msglen = int32_pack(len(body)) - msg_parts = map(int8_pack, (version, flags, stream_id, self.opcode)) + [msglen, body] - return ''.join(msg_parts) - - def send(self, f, streamid, compression=None): - body = StringIO() - self.send_body(body) - body = body.getvalue() - version = PROTOCOL_VERSION | HEADER_DIRECTION_FROM_CLIENT - flags = 0 - if compression is not None and len(body) > 0: - body = compression(body) - flags |= 0x01 - if self.tracing: - flags |= 0x02 - msglen = int32_pack(len(body)) - header = ''.join(map(int8_pack, (version, flags, streamid, self.opcode))) \ - + msglen - f.write(header) - if len(body) > 0: - f.write(body) - - def __str__(self): - paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in self.params] - return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs)) - __repr__ = __str__ - - -def decode_response(stream_id, flags, opcode, body, decompressor=None): - if flags & 0x01: - if decompressor is None: - raise Exception("No decompressor available for compressed frame!") - body = decompressor(body) - flags ^= 0x01 - - body = StringIO(body) - if flags & 0x02: - trace_id = UUID(bytes=body.read(16)) - flags ^= 0x02 - else: - trace_id = None - - if flags: - log.warn("Unknown protocol flags set: %02x. May cause problems." % flags) - - msg_class = _message_types_by_opcode[opcode] - msg = msg_class.recv_body(body) - msg.stream_id = stream_id - msg.trace_id = trace_id - return msg - - -error_classes = {} - -class ErrorMessage(_MessageType, Exception): - opcode = 0x00 - name = 'ERROR' - params = ('code', 'message', 'info') - summary = 'Unknown' - - @classmethod - def recv_body(cls, f): - code = read_int(f) - msg = read_string(f) - subcls = error_classes.get(code, cls) - extra_info = subcls.recv_error_info(f) - return subcls(code=code, message=msg, info=extra_info) - - def summary_msg(self): - msg = 'code=%04x [%s] message="%s"' \ - % (self.code, self.summary, self.message) - if self.info is not None: - msg += (' info=' + repr(self.info)) - return msg - - def __str__(self): - return '' % self.summary_msg() - __repr__ = __str__ - - @staticmethod - def recv_error_info(f): - pass - - def to_exception(self): - return self - - -class ErrorMessageSubclass(_register_msg_type): - def __init__(cls, name, bases, dct): - if cls.error_code is not None: - error_classes[cls.error_code] = cls - - -class ErrorMessageSub(ErrorMessage): - __metaclass__ = ErrorMessageSubclass - error_code = None - - -class RequestExecutionException(ErrorMessageSub): - pass - - -class RequestValidationException(ErrorMessageSub): - pass - - -class ServerError(ErrorMessageSub): - summary = 'Server error' - error_code = 0x0000 - - -class ProtocolException(ErrorMessageSub): - summary = 'Protocol error' - error_code = 0x000A - - -class UnavailableErrorMessage(RequestExecutionException): - summary = 'Unavailable exception' - error_code = 0x1000 - - @staticmethod - def recv_error_info(f): - return { - 'consistency': read_consistency_level(f), - 'required_replicas': read_int(f), - 'alive_replicas': read_int(f), - } - - def to_exception(self): - return Unavailable(self.summary_msg(), **self.info) - - -class OverloadedErrorMessage(RequestExecutionException): - summary = 'Coordinator node overloaded' - error_code = 0x1001 - - -class IsBootstrappingErrorMessage(RequestExecutionException): - summary = 'Coordinator node is bootstrapping' - error_code = 0x1002 - - -class TruncateError(RequestExecutionException): - summary = 'Error during truncate' - error_code = 0x1003 - - -class WriteTimeoutErrorMessage(RequestExecutionException): - summary = 'Timeout during write request' - error_code = 0x1100 - - @staticmethod - def recv_error_info(f): - return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'write_type': read_string(f), - } - - def to_exception(self): - return WriteTimeout(self.summary_msg(), **self.info) - - -class ReadTimeoutErrorMessage(RequestExecutionException): - summary = 'Timeout during read request' - error_code = 0x1200 - - @staticmethod - def recv_error_info(f): - return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'data_retrieved': bool(read_byte(f)), - } - - def to_exception(self): - return ReadTimeout(self.summary_msg(), **self.info) - - -class SyntaxException(RequestValidationException): - summary = 'Syntax error in CQL query' - error_code = 0x2000 - - -class UnauthorizedErrorMessage(RequestValidationException): - summary = 'Unauthorized' - error_code = 0x2100 - - def to_exception(self): - return Unauthorized(self.summary_msg()) - - -class InvalidRequestException(RequestValidationException): - summary = 'Invalid query' - error_code = 0x2200 - - def to_exception(self): - return InvalidRequest(self.summary_msg()) - - -class ConfigurationException(RequestValidationException): - summary = 'Query invalid because of configuration issue' - error_code = 0x2300 - - -class PreparedQueryNotFound(RequestValidationException): - summary = 'Matching prepared statement not found on this node' - error_code = 0x2500 - - @staticmethod - def recv_error_info(f): - # return the query ID - return read_binary_string(f) - - -class AlreadyExistsException(ConfigurationException): - summary = 'Item already exists' - error_code = 0x2400 - - @staticmethod - def recv_error_info(f): - return { - 'keyspace': read_string(f), - 'table': read_string(f), - } - - def to_exception(self): - return AlreadyExists(**self.info) - - -class StartupMessage(_MessageType): - opcode = 0x01 - name = 'STARTUP' - params = ('cqlversion', 'options') - - KNOWN_OPTION_KEYS = set(( - 'CQL_VERSION', - 'COMPRESSION', - )) - - def send_body(self, f): - optmap = self.options.copy() - optmap['CQL_VERSION'] = self.cqlversion - write_stringmap(f, optmap) - - -class ReadyMessage(_MessageType): - opcode = 0x02 - name = 'READY' - params = () - - @classmethod - def recv_body(cls, f): - return cls() - - -class AuthenticateMessage(_MessageType): - opcode = 0x03 - name = 'AUTHENTICATE' - params = ('authenticator',) - - @classmethod - def recv_body(cls, f): - authname = read_string(f) - return cls(authenticator=authname) - - -class CredentialsMessage(_MessageType): - opcode = 0x04 - name = 'CREDENTIALS' - params = ('creds',) - - def send_body(self, f): - write_short(f, len(self.creds)) - for credkey, credval in self.creds.items(): - write_string(f, credkey) - write_string(f, credval) - - -class OptionsMessage(_MessageType): - opcode = 0x05 - name = 'OPTIONS' - params = () - - def send_body(self, f): - pass - - -class SupportedMessage(_MessageType): - opcode = 0x06 - name = 'SUPPORTED' - params = ('cql_versions', 'options',) - - @classmethod - def recv_body(cls, f): - options = read_stringmultimap(f) - cql_versions = options.pop('CQL_VERSION') - return cls(cql_versions=cql_versions, options=options) - - -class QueryMessage(_MessageType): - opcode = 0x07 - name = 'QUERY' - params = ('query', 'consistency_level',) - - def send_body(self, f): - write_longstring(f, self.query) - write_consistency_level(f, self.consistency_level) - - -class ResultMessage(_MessageType): - opcode = 0x08 - name = 'RESULT' - params = ('kind', 'results',) - - KIND_VOID = 0x0001 - KIND_ROWS = 0x0002 - KIND_SET_KEYSPACE = 0x0003 - KIND_PREPARED = 0x0004 - KIND_SCHEMA_CHANGE = 0x0005 - - type_codes = { - 0x0001: AsciiType, - 0x0002: LongType, - 0x0003: BytesType, - 0x0004: BooleanType, - 0x0005: CounterColumnType, - 0x0006: DecimalType, - 0x0007: DoubleType, - 0x0008: FloatType, - 0x0009: Int32Type, - 0x000A: UTF8Type, - 0x000B: DateType, - 0x000C: UUIDType, - 0x000D: UTF8Type, - 0x000E: IntegerType, - 0x000F: TimeUUIDType, - 0x0010: InetAddressType, - 0x0020: ListType, - 0x0021: MapType, - 0x0022: SetType, - } - - FLAGS_GLOBAL_TABLES_SPEC = 0x0001 - - @classmethod - def recv_body(cls, f): - kind = read_int(f) - if kind == cls.KIND_VOID: - results = None - elif kind == cls.KIND_ROWS: - results = cls.recv_results_rows(f) - elif kind == cls.KIND_SET_KEYSPACE: - ksname = read_string(f) - results = ksname - elif kind == cls.KIND_PREPARED: - results = cls.recv_results_prepared(f) - elif kind == cls.KIND_SCHEMA_CHANGE: - results = cls.recv_results_schema_change(f) - return cls(kind=kind, results=results) - - @classmethod - def recv_results_rows(cls, f): - column_metadata = cls.recv_results_metadata(f) - rowcount = read_int(f) - rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)] - colnames = [c[2] for c in column_metadata] - coltypes = [c[3] for c in column_metadata] - return (colnames, [tuple(ctype.from_binary(val) for ctype, val in zip(coltypes, row)) - for row in rows]) - - @classmethod - def recv_results_prepared(cls, f): - query_id = read_binary_string(f) - column_metadata = cls.recv_results_metadata(f) - return (query_id, column_metadata) - - @classmethod - def recv_results_metadata(cls, f): - flags = read_int(f) - glob_tblspec = bool(flags & cls.FLAGS_GLOBAL_TABLES_SPEC) - colcount = read_int(f) - if glob_tblspec: - ksname = read_string(f) - cfname = read_string(f) - column_metadata = [] - for x in xrange(colcount): - if glob_tblspec: - colksname = ksname - colcfname = cfname - else: - colksname = read_string(f) - colcfname = read_string(f) - colname = read_string(f) - coltype = cls.read_type(f) - column_metadata.append((colksname, colcfname, colname, coltype)) - return column_metadata - - @classmethod - def recv_results_schema_change(cls, f): - change_type = read_string(f) - keyspace = read_string(f) - table = read_string(f) - return dict(change_type=change_type, keyspace=keyspace, table=table) - - @classmethod - def read_type(cls, f): - optid = read_short(f) - try: - typeclass = cls.type_codes[optid] - except KeyError: - raise NotSupportedError("Unknown data type code 0x%x. Have to skip" - " entire result set." % optid) - if typeclass in (ListType, SetType): - subtype = cls.read_type(f) - typeclass = typeclass.apply_parameters(subtype) - elif typeclass == MapType: - keysubtype = cls.read_type(f) - valsubtype = cls.read_type(f) - typeclass = typeclass.apply_parameters(keysubtype, valsubtype) - return typeclass - - @staticmethod - def recv_row(f, colcount): - return [read_value(f) for x in xrange(colcount)] - - -class PrepareMessage(_MessageType): - opcode = 0x09 - name = 'PREPARE' - params = ('query',) - - def send_body(self, f): - write_longstring(f, self.query) - - -class ExecuteMessage(_MessageType): - opcode = 0x0A - name = 'EXECUTE' - params = ('query_id', 'query_params', 'consistency_level',) - - def send_body(self, f): - write_string(f, self.query_id) - write_short(f, len(self.query_params)) - for param in self.query_params: - write_value(f, param) - write_consistency_level(f, self.consistency_level) - - -known_event_types = frozenset(( - 'TOPOLOGY_CHANGE', - 'STATUS_CHANGE', - 'SCHEMA_CHANGE' -)) - - -class RegisterMessage(_MessageType): - opcode = 0x0B - name = 'REGISTER' - params = ('event_list',) - - def send_body(self, f): - write_stringlist(f, self.event_list) - - -class EventMessage(_MessageType): - opcode = 0x0C - name = 'EVENT' - params = ('event_type', 'event_args') - - @classmethod - def recv_body(cls, f): - event_type = read_string(f).upper() - if event_type in known_event_types: - read_method = getattr(cls, 'recv_' + event_type.lower()) - return cls(event_type=event_type, event_args=read_method(f)) - raise NotSupportedError('Unknown event type %r' % event_type) - - @classmethod - def recv_topology_change(cls, f): - # "NEW_NODE" or "REMOVED_NODE" - change_type = read_string(f) - address = read_inet(f) - return dict(change_type=change_type, address=address) - - @classmethod - def recv_status_change(cls, f): - # "UP" or "DOWN" - change_type = read_string(f) - address = read_inet(f) - return dict(change_type=change_type, address=address) - - @classmethod - def recv_schema_change(cls, f): - # "CREATED", "DROPPED", or "UPDATED" - change_type = read_string(f) - keyspace = read_string(f) - table = read_string(f) - return dict(change_type=change_type, keyspace=keyspace, table=table) - - -def read_byte(f): - return int8_unpack(f.read(1)) - - -def write_byte(f, b): - f.write(int8_pack(b)) - - -def read_int(f): - return int32_unpack(f.read(4)) - - -def write_int(f, i): - f.write(int32_pack(i)) - - -def read_short(f): - return uint16_unpack(f.read(2)) - - -def write_short(f, s): - f.write(uint16_pack(s)) - - -def read_consistency_level(f): - return ConsistencyLevel.value_to_name[read_short(f)] - - -def write_consistency_level(f, cl): - write_short(f, cl) - - -def read_string(f): - size = read_short(f) - contents = f.read(size) - return contents.decode('utf8') - - -def read_binary_string(f): - size = read_short(f) - contents = f.read(size) - return contents - - -def write_string(f, s): - if isinstance(s, unicode): - s = s.encode('utf8') - write_short(f, len(s)) - f.write(s) - - -def read_longstring(f): - size = read_int(f) - contents = f.read(size) - return contents.decode('utf8') - - -def write_longstring(f, s): - if isinstance(s, unicode): - s = s.encode('utf8') - write_int(f, len(s)) - f.write(s) - - -def read_stringlist(f): - numstrs = read_short(f) - return [read_string(f) for x in xrange(numstrs)] - - -def write_stringlist(f, stringlist): - write_short(f, len(stringlist)) - for s in stringlist: - write_string(f, s) - - -def read_stringmap(f): - numpairs = read_short(f) - strmap = {} - for x in xrange(numpairs): - k = read_string(f) - strmap[k] = read_string(f) - return strmap - - -def write_stringmap(f, strmap): - write_short(f, len(strmap)) - for k, v in strmap.items(): - write_string(f, k) - write_string(f, v) - - -def read_stringmultimap(f): - numkeys = read_short(f) - strmmap = {} - for x in xrange(numkeys): - k = read_string(f) - strmmap[k] = read_stringlist(f) - return strmmap - - -def write_stringmultimap(f, strmmap): - write_short(f, len(strmmap)) - for k, v in strmmap.items(): - write_string(f, k) - write_stringlist(f, v) - - -def read_value(f): - size = read_int(f) - if size < 0: - return None - return f.read(size) - - -def write_value(f, v): - if v is None: - write_int(f, -1) - else: - write_int(f, len(v)) - f.write(v) - - -def read_inet(f): - size = read_byte(f) - addrbytes = f.read(size) - port = read_int(f) - if size == 4: - addrfam = socket.AF_INET - elif size == 16: - addrfam = socket.AF_INET6 - else: - raise InternalError("bad inet address: %r" % (addrbytes,)) - return (socket.inet_ntop(addrfam, addrbytes), port) - - -def write_inet(f, addrtuple): - addr, port = addrtuple - if ':' in addr: - addrfam = socket.AF_INET6 - else: - addrfam = socket.AF_INET - addrbytes = socket.inet_pton(addrfam, addr) - write_byte(f, len(addrbytes)) - f.write(addrbytes) - write_int(f, port) - - -def cql_quote(term): - if isinstance(term, unicode): - return "'%s'" % term.encode('utf8').replace("'", "''") - elif isinstance(term, (str, bool)): - return "'%s'" % str(term).replace("'", "''") - else: - return str(term) - - -def cql_encode_none(val): - return 'NULL' - - -def cql_encode_unicode(val): - return cql_quote(val.encode('utf-8')) - - -def cql_encode_str(val): - return cql_quote(val) - - -def cql_encode_object(val): - return str(val) - - -def cql_encode_datetime(val): - return "'%s'" % val.strftime('%Y-%m-%d %H:%M:%S-0000') - - -def cql_encode_date(val): - return "'%s'" % val.strftime('%Y-%m-%d-0000') - - -def cql_encode_sequence(val): - return '( %s )' % ' , '.join(cql_encoders.get(type(v), cql_encode_object)(v) - for v in val) - - -def cql_encode_map_collection(val): - return '{ %s }' % ' , '.join('%s : %s' % (cql_quote(k), cql_quote(v)) - for k, v in val.iteritems()) - - -def cql_encode_list_collection(val): - return '[ %s ]' % ' , '.join(map(cql_quote, val)) - - -def cql_encode_set_collection(val): - return '{ %s }' % ' , '.join(map(cql_quote, val)) - - -cql_encoders = { - float: cql_encode_object, - str: cql_encode_str, - unicode: cql_encode_unicode, - types.NoneType: cql_encode_none, - int: cql_encode_object, - long: cql_encode_object, - UUID: cql_encode_object, - datetime.datetime: cql_encode_datetime, - datetime.date: cql_encode_date, - dict: cql_encode_map_collection, - list: cql_encode_list_collection, - tuple: cql_encode_list_collection, - set: cql_encode_set_collection, - frozenset: cql_encode_set_collection, - types.GeneratorType: cql_encode_sequence -} diff --git a/cassandra/deserializers.pxd b/cassandra/deserializers.pxd new file mode 100644 index 0000000000..c8408a57b6 --- /dev/null +++ b/cassandra/deserializers.pxd @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.buffer cimport Buffer + +cdef class Deserializer: + # The cqltypes._CassandraType corresponding to this deserializer + cdef object cqltype + + # String may be empty, whereas other values may not be. + # Other values may be NULL, in which case the integer length + # of the binary data is negative. However, non-string types + # may also return a zero length for legacy reasons + # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec + # paragraph 6) + cdef bint empty_binary_ok + + cdef deserialize(self, Buffer *buf, int protocol_version) + # cdef deserialize(self, CString byts, protocol_version) + + +cdef inline object from_binary(Deserializer deserializer, + Buffer *buf, + int protocol_version): + if buf.size < 0: + return None + elif buf.size == 0 and not deserializer.empty_binary_ok: + return _ret_empty(deserializer, buf.size) + else: + return deserializer.deserialize(buf, protocol_version) + +cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx new file mode 100644 index 0000000000..c07d67be91 --- /dev/null +++ b/cassandra/deserializers.pyx @@ -0,0 +1,519 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from libc.stdint cimport int32_t, uint16_t + +include 'cython_marshal.pyx' +from cassandra.buffer cimport Buffer, to_bytes, slice_buffer +from cassandra.cython_utils cimport datetime_from_timestamp + +from cython.view cimport array as cython_array +from cassandra.tuple cimport tuple_new, tuple_set + +import socket +from decimal import Decimal +from uuid import UUID + +from cassandra import cqltypes +from cassandra import util + +cdef class Deserializer: + """Cython-based deserializer class for a cqltype""" + + def __init__(self, cqltype): + self.cqltype = cqltype + self.empty_binary_ok = cqltype.empty_binary_ok + + cdef deserialize(self, Buffer *buf, int protocol_version): + raise NotImplementedError + + +cdef class DesBytesType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return b"" + return to_bytes(buf) + +# this is to facilitate cqlsh integration, which requires bytearrays for BytesType +# It is switched in by simply overwriting DesBytesType: +# deserializers.DesBytesType = deserializers.DesBytesTypeByteArray +cdef class DesBytesTypeByteArray(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return bytearray() + return bytearray(buf.ptr[:buf.size]) + +# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html +cdef class DesDecimalType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Buffer varint_buf + slice_buffer(buf, &varint_buf, 4, buf.size - 4) + + cdef int32_t scale = unpack_num[int32_t](buf) + unscaled = varint_unpack(&varint_buf) + + return Decimal('%de%d' % (unscaled, -scale)) + + +cdef class DesUUIDType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return UUID(bytes=to_bytes(buf)) + + +cdef class DesBooleanType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if unpack_num[int8_t](buf): + return True + return False + + +cdef class DesByteType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[int8_t](buf) + + +cdef class DesAsciiType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return "" + return to_bytes(buf).decode('ascii') + + +cdef class DesFloatType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[float](buf) + + +cdef class DesDoubleType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[double](buf) + + +cdef class DesLongType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[int64_t](buf) + + +cdef class DesInt32Type(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[int32_t](buf) + + +cdef class DesIntegerType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return varint_unpack(buf) + + +cdef class DesInetAddressType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef bytes byts = to_bytes(buf) + + # TODO: optimize inet_ntop, inet_ntoa + if buf.size == 16: + return util.inet_ntop(socket.AF_INET6, byts) + else: + # util.inet_pton could also handle, but this is faster + # since we've already determined the AF + return socket.inet_ntoa(byts) + + +cdef class DesCounterColumnType(DesLongType): + pass + + +cdef class DesDateType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef double timestamp = unpack_num[int64_t](buf) / 1000.0 + return datetime_from_timestamp(timestamp) + + +cdef class TimestampType(DesDateType): + pass + + +cdef class TimeUUIDType(DesDateType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return UUID(bytes=to_bytes(buf)) + + +# Values of the 'date'` type are encoded as 32-bit unsigned integers +# representing a number of days with epoch (January 1st, 1970) at the center of the +# range (2^31). +EPOCH_OFFSET_DAYS = 2 ** 31 + +cdef class DesSimpleDateType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + days = unpack_num[uint32_t](buf) - EPOCH_OFFSET_DAYS + return util.Date(days) + + +cdef class DesShortType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return unpack_num[int16_t](buf) + + +cdef class DesTimeType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return util.Time(unpack_num[int64_t](buf)) + + +cdef class DesUTF8Type(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return "" + cdef val = to_bytes(buf) + return val.decode('utf8') + + +cdef class DesVarcharType(DesUTF8Type): + pass + + +cdef class _DesParameterizedType(Deserializer): + + cdef object subtypes + cdef Deserializer[::1] deserializers + cdef Py_ssize_t subtypes_len + + def __init__(self, cqltype): + super().__init__(cqltype) + self.subtypes = cqltype.subtypes + self.deserializers = make_deserializers(cqltype.subtypes) + self.subtypes_len = len(self.subtypes) + + +cdef class _DesSingleParamType(_DesParameterizedType): + cdef Deserializer deserializer + + def __init__(self, cqltype): + assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes + super().__init__(cqltype) + self.deserializer = self.deserializers[0] + + +#-------------------------------------------------------------------------- +# List and set deserialization + +cdef class DesListType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef uint16_t v2_and_below = 2 + cdef int32_t v3_and_above = 3 + + if protocol_version >= 3: + result = _deserialize_list_or_set[int32_t]( + v3_and_above, buf, protocol_version, self.deserializer) + else: + result = _deserialize_list_or_set[uint16_t]( + v2_and_below, buf, protocol_version, self.deserializer) + + return result + +cdef class DesSetType(DesListType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return util.sortedset(DesListType.deserialize(self, buf, protocol_version)) + + +ctypedef fused itemlen_t: + uint16_t # protocol <= v2 + int32_t # protocol >= v3 + +cdef list _deserialize_list_or_set(itemlen_t dummy_version, + Buffer *buf, int protocol_version, + Deserializer deserializer): + """ + Deserialize a list or set. + + The 'dummy' parameter is needed to make fused types work, so that + we can specialize on the protocol version. + """ + cdef Buffer itemlen_buf + cdef Buffer elem_buf + + cdef itemlen_t numelements + cdef int offset + cdef list result = [] + + _unpack_len[itemlen_t](buf, 0, &numelements) + offset = sizeof(itemlen_t) + protocol_version = max(3, protocol_version) + for _ in range(numelements): + subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version) + result.append(from_binary(deserializer, &elem_buf, protocol_version)) + + return result + + +cdef inline int subelem( + Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1: + """ + Read the next element from the buffer: first read the size (in bytes) of the + element, then fill elem_buf with a newly sliced buffer of this size (and the + right offset). + """ + cdef itemlen_t elemlen + + _unpack_len[itemlen_t](buf, offset[0], &elemlen) + offset[0] += sizeof(itemlen_t) + slice_buffer(buf, elem_buf, offset[0], elemlen) + offset[0] += elemlen + return 0 + + +cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1: + cdef Buffer itemlen_buf + slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t)) + + if itemlen_t is uint16_t: + output[0] = unpack_num[uint16_t](&itemlen_buf) + else: + output[0] = unpack_num[int32_t](&itemlen_buf) + + return 0 + +#-------------------------------------------------------------------------- +# Map deserialization + +cdef class DesMapType(_DesParameterizedType): + + cdef Deserializer key_deserializer, val_deserializer + + def __init__(self, cqltype): + super().__init__(cqltype) + self.key_deserializer = self.deserializers[0] + self.val_deserializer = self.deserializers[1] + + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef uint16_t v2_and_below = 0 + cdef int32_t v3_and_above = 0 + key_type, val_type = self.cqltype.subtypes + + if protocol_version >= 3: + result = _deserialize_map[int32_t]( + v3_and_above, buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) + else: + result = _deserialize_map[uint16_t]( + v2_and_below, buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) + + return result + + +cdef _deserialize_map(itemlen_t dummy_version, + Buffer *buf, int protocol_version, + Deserializer key_deserializer, Deserializer val_deserializer, + key_type, val_type): + cdef Buffer key_buf, val_buf + cdef Buffer itemlen_buf + + cdef itemlen_t numelements + cdef int offset + cdef list result = [] + + _unpack_len[itemlen_t](buf, 0, &numelements) + offset = sizeof(itemlen_t) + themap = util.OrderedMapSerializedKey(key_type, protocol_version) + protocol_version = max(3, protocol_version) + for _ in range(numelements): + subelem[itemlen_t](buf, &key_buf, &offset, dummy_version) + subelem[itemlen_t](buf, &val_buf, &offset, numelements) + key = from_binary(key_deserializer, &key_buf, protocol_version) + val = from_binary(val_deserializer, &val_buf, protocol_version) + themap._insert_unchecked(key, to_bytes(&key_buf), val) + + return themap + +#-------------------------------------------------------------------------- + +cdef class DesTupleType(_DesParameterizedType): + + # TODO: Use TupleRowParser to parse these tuples + + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Py_ssize_t i, p + cdef int32_t itemlen + cdef tuple res = tuple_new(self.subtypes_len) + cdef Buffer item_buf + cdef Buffer itemlen_buf + cdef Deserializer deserializer + + # collections inside UDTs are always encoded with at least the + # version 3 format + protocol_version = max(3, protocol_version) + + p = 0 + values = [] + for i in range(self.subtypes_len): + item = None + if p < buf.size: + slice_buffer(buf, &itemlen_buf, p, 4) + itemlen = unpack_num[int32_t](&itemlen_buf) + p += 4 + if itemlen >= 0: + slice_buffer(buf, &item_buf, p, itemlen) + p += itemlen + + deserializer = self.deserializers[i] + item = from_binary(deserializer, &item_buf, protocol_version) + + tuple_set(res, i, item) + + return res + + +cdef class DesUserType(DesTupleType): + cdef deserialize(self, Buffer *buf, int protocol_version): + typ = self.cqltype + values = DesTupleType.deserialize(self, buf, protocol_version) + if typ.mapped_class: + return typ.mapped_class(**dict(zip(typ.fieldnames, values))) + elif typ.tuple_type: + return typ.tuple_type(*values) + else: + return tuple(values) + + +cdef class DesCompositeType(_DesParameterizedType): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Py_ssize_t i, idx, start + cdef Buffer elem_buf + cdef int16_t element_length + cdef Deserializer deserializer + cdef tuple res = tuple_new(self.subtypes_len) + + idx = 0 + for i in range(self.subtypes_len): + if not buf.size: + # CompositeType can have missing elements at the end + + # Fill the tuple with None values and slice it + # + # (I'm not sure a tuple needs to be fully initialized before + # it can be destroyed, so play it safe) + for j in range(i, self.subtypes_len): + tuple_set(res, j, None) + res = res[:i] + break + + element_length = unpack_num[uint16_t](buf) + slice_buffer(buf, &elem_buf, 2, element_length) + + deserializer = self.deserializers[i] + item = from_binary(deserializer, &elem_buf, protocol_version) + tuple_set(res, i, item) + + # skip element length, element, and the EOC (one byte) + start = 2 + element_length + 1 + slice_buffer(buf, buf, start, buf.size - start) + + return res + + +DesDynamicCompositeType = DesCompositeType + + +cdef class DesReversedType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return from_binary(self.deserializer, buf, protocol_version) + + +cdef class DesFrozenType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return from_binary(self.deserializer, buf, protocol_version) + +#-------------------------------------------------------------------------- + +cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size): + """ + Decide whether to return None or EMPTY when a buffer size is + zero or negative. This is used by from_binary in deserializers.pxd. + """ + if buf_size < 0: + return None + elif deserializer.cqltype.support_empty_values: + return cqltypes.EMPTY + else: + return None + +#-------------------------------------------------------------------------- +# Generic deserialization + +cdef class GenericDeserializer(Deserializer): + """ + Wrap a generic datatype for deserialization + """ + + cdef deserialize(self, Buffer *buf, int protocol_version): + return self.cqltype.deserialize(to_bytes(buf), protocol_version) + + def __repr__(self): + return "GenericDeserializer(%s)" % (self.cqltype,) + +#-------------------------------------------------------------------------- +# Helper utilities + +def make_deserializers(cqltypes): + """Create an array of Deserializers for each given cqltype in cqltypes""" + cdef Deserializer[::1] deserializers + return obj_array([find_deserializer(ct) for ct in cqltypes]) + + +cdef dict classes = globals() + +cpdef Deserializer find_deserializer(cqltype): + """Find a deserializer for a cqltype""" + name = 'Des' + cqltype.__name__ + + if name in globals(): + cls = classes[name] + elif issubclass(cqltype, cqltypes.ListType): + cls = DesListType + elif issubclass(cqltype, cqltypes.SetType): + cls = DesSetType + elif issubclass(cqltype, cqltypes.MapType): + cls = DesMapType + elif issubclass(cqltype, cqltypes.UserType): + # UserType is a subclass of TupleType, so should precede it + cls = DesUserType + elif issubclass(cqltype, cqltypes.TupleType): + cls = DesTupleType + elif issubclass(cqltype, cqltypes.DynamicCompositeType): + # DynamicCompositeType is a subclass of CompositeType, so should precede it + cls = DesDynamicCompositeType + elif issubclass(cqltype, cqltypes.CompositeType): + cls = DesCompositeType + elif issubclass(cqltype, cqltypes.ReversedType): + cls = DesReversedType + elif issubclass(cqltype, cqltypes.FrozenType): + cls = DesFrozenType + else: + cls = GenericDeserializer + + return cls(cqltype) + + +def obj_array(list objs): + """Create a (Cython) array of objects given a list of objects""" + cdef object[:] arr + cdef Py_ssize_t i + arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O") + # arr[:] = objs # This does not work (segmentation faults) + for i, obj in enumerate(objs): + arr[i] = obj + return arr diff --git a/cassandra/encoder.py b/cassandra/encoder.py new file mode 100644 index 0000000000..94093e85b6 --- /dev/null +++ b/cassandra/encoder.py @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +These functions are used to convert Python objects into CQL strings. +When non-prepared statements are executed, these encoder functions are +called on each query parameter. +""" + +import logging +log = logging.getLogger(__name__) + +from binascii import hexlify +from decimal import Decimal +import calendar +import datetime +import math +import sys +import types +from uuid import UUID +import ipaddress + +from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, + sortedset, Time, Date, Point, LineString, Polygon) + + +def cql_quote(term): + if isinstance(term, str): + return "'%s'" % str(term).replace("'", "''") + else: + return str(term) + + +class ValueSequence(list): + pass + + +class Encoder(object): + """ + A container for mapping python types to CQL string literals when working + with non-prepared statements. The type :attr:`~.Encoder.mapping` can be + directly customized by users. + """ + + mapping = None + """ + A map of python types to encoder functions. + """ + + def __init__(self): + self.mapping = { + float: self.cql_encode_float, + Decimal: self.cql_encode_decimal, + bytearray: self.cql_encode_bytes, + str: self.cql_encode_str, + int: self.cql_encode_object, + UUID: self.cql_encode_object, + datetime.datetime: self.cql_encode_datetime, + datetime.date: self.cql_encode_date, + datetime.time: self.cql_encode_time, + Date: self.cql_encode_date_ext, + Time: self.cql_encode_time, + dict: self.cql_encode_map_collection, + OrderedDict: self.cql_encode_map_collection, + OrderedMap: self.cql_encode_map_collection, + OrderedMapSerializedKey: self.cql_encode_map_collection, + list: self.cql_encode_list_collection, + tuple: self.cql_encode_list_collection, # TODO: change to tuple in next major + set: self.cql_encode_set_collection, + sortedset: self.cql_encode_set_collection, + frozenset: self.cql_encode_set_collection, + types.GeneratorType: self.cql_encode_list_collection, + ValueSequence: self.cql_encode_sequence, + Point: self.cql_encode_str_quoted, + LineString: self.cql_encode_str_quoted, + Polygon: self.cql_encode_str_quoted + } + + self.mapping.update({ + memoryview: self.cql_encode_bytes, + bytes: self.cql_encode_bytes, + type(None): self.cql_encode_none, + ipaddress.IPv4Address: self.cql_encode_ipaddress, + ipaddress.IPv6Address: self.cql_encode_ipaddress + }) + + def cql_encode_none(self, val): + """ + Converts :const:`None` to the string 'NULL'. + """ + return 'NULL' + + def cql_encode_unicode(self, val): + """ + Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. + """ + return cql_quote(val.encode('utf-8')) + + def cql_encode_str(self, val): + """ + Escapes quotes in :class:`str` objects. + """ + return cql_quote(val) + + def cql_encode_str_quoted(self, val): + return "'%s'" % val + + def cql_encode_bytes(self, val): + return (b'0x' + hexlify(val)).decode('utf-8') + + def cql_encode_object(self, val): + """ + Default encoder for all objects that do not have a specific encoder function + registered. This function simply calls :meth:`str()` on the object. + """ + return str(val) + + def cql_encode_float(self, val): + """ + Encode floats using repr to preserve precision + """ + if math.isinf(val): + return 'Infinity' if val > 0 else '-Infinity' + elif math.isnan(val): + return 'NaN' + else: + return repr(val) + + def cql_encode_datetime(self, val): + """ + Converts a :class:`datetime.datetime` object to a (string) integer timestamp + with millisecond precision. + """ + timestamp = calendar.timegm(val.utctimetuple()) + return str(int(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) + + def cql_encode_date(self, val): + """ + Converts a :class:`datetime.date` object to a string with format + ``YYYY-MM-DD``. + """ + return "'%s'" % val.strftime('%Y-%m-%d') + + def cql_encode_time(self, val): + """ + Converts a :class:`cassandra.util.Time` object to a string with format + ``HH:MM:SS.mmmuuunnn``. + """ + return "'%s'" % val + + def cql_encode_date_ext(self, val): + """ + Encodes a :class:`cassandra.util.Date` object as an integer + """ + # using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR + return str(val.days_from_epoch + 2 ** 31) + + def cql_encode_sequence(self, val): + """ + Converts a sequence to a string of the form ``(item1, item2, ...)``. This + is suitable for ``IN`` value lists. + """ + return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) + for v in val) + + cql_encode_tuple = cql_encode_sequence + """ + Converts a sequence to a string of the form ``(item1, item2, ...)``. This + is suitable for ``tuple`` type columns. + """ + + def cql_encode_map_collection(self, val): + """ + Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. + This is suitable for ``map`` type columns. + """ + return '{%s}' % ', '.join('%s: %s' % ( + self.mapping.get(type(k), self.cql_encode_object)(k), + self.mapping.get(type(v), self.cql_encode_object)(v) + ) for k, v in val.items()) + + def cql_encode_list_collection(self, val): + """ + Converts a sequence to a string of the form ``[item1, item2, ...]``. This + is suitable for ``list`` type columns. + """ + return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + + def cql_encode_set_collection(self, val): + """ + Converts a sequence to a string of the form ``{item1, item2, ...}``. This + is suitable for ``set`` type columns. + """ + return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + + def cql_encode_all_types(self, val, as_text_type=False): + """ + Converts any type into a CQL string, defaulting to ``cql_encode_object`` + if :attr:`~Encoder.mapping` does not contain an entry for the type. + """ + encoded = self.mapping.get(type(val), self.cql_encode_object)(val) + if as_text_type and not isinstance(encoded, str): + return encoded.decode('utf-8') + return encoded + + def cql_encode_ipaddress(self, val): + """ + Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This + is suitable for ``inet`` type columns. + """ + return "'%s'" % val.compressed + + def cql_encode_decimal(self, val): + return self.cql_encode_float(float(val)) \ No newline at end of file diff --git a/cassandra/graph/__init__.py b/cassandra/graph/__init__.py new file mode 100644 index 0000000000..1d33345aad --- /dev/null +++ b/cassandra/graph/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph import * \ No newline at end of file diff --git a/cassandra/graph/graphson.py b/cassandra/graph/graphson.py new file mode 100644 index 0000000000..576d5063fe --- /dev/null +++ b/cassandra/graph/graphson.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.graphson import * diff --git a/cassandra/graph/query.py b/cassandra/graph/query.py new file mode 100644 index 0000000000..9003fe280f --- /dev/null +++ b/cassandra/graph/query.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.query import * diff --git a/cassandra/graph/types.py b/cassandra/graph/types.py new file mode 100644 index 0000000000..53febe7e9c --- /dev/null +++ b/cassandra/graph/types.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.types import * diff --git a/cassandra/io/__init__.py b/cassandra/io/__init__.py index e69de29bb2..588a655d98 100644 --- a/cassandra/io/__init__.py +++ b/cassandra/io/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py new file mode 100644 index 0000000000..007e10d5c4 --- /dev/null +++ b/cassandra/io/asyncioreactor.py @@ -0,0 +1,221 @@ +from cassandra.connection import Connection, ConnectionShutdown + +import asyncio +import logging +import os +import socket +import ssl +from threading import Lock, Thread, get_ident + + +log = logging.getLogger(__name__) + + +# This module uses ``yield from`` and ``@asyncio.coroutine`` over ``await`` and +# ``async def`` for pre-Python-3.5 compatibility, so keep in mind that the +# managed coroutines are generator-based, not native coroutines. See PEP 492: +# https://www.python.org/dev/peps/pep-0492/#coroutine-objects + + +try: + asyncio.run_coroutine_threadsafe +except AttributeError: + raise ImportError( + 'Cannot use asyncioreactor without access to ' + 'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)' + ) + + +class AsyncioTimer(object): + """ + An ``asyncioreactor``-specific Timer. Similar to :class:`.connection.Timer, + but with a slightly different API due to limitations in the underlying + ``call_later`` interface. Not meant to be used with a + :class:`.connection.TimerManager`. + """ + + @property + def end(self): + raise NotImplementedError('{} is not compatible with TimerManager and ' + 'does not implement .end()') + + def __init__(self, timeout, callback, loop): + delayed = self._call_delayed_coro(timeout=timeout, + callback=callback) + self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop) + + @staticmethod + async def _call_delayed_coro(timeout, callback): + await asyncio.sleep(timeout) + return callback() + + def __lt__(self, other): + try: + return self._handle < other._handle + except AttributeError: + raise NotImplemented + + def cancel(self): + self._handle.cancel() + + def finish(self): + # connection.Timer method not implemented here because we can't inspect + # the Handle returned from call_later + raise NotImplementedError('{} is not compatible with TimerManager and ' + 'does not implement .finish()') + + +class AsyncioConnection(Connection): + """ + An experimental implementation of :class:`.Connection` that uses the + ``asyncio`` module in the Python standard library for its event loop. + + Note that it requires ``asyncio`` features that were only introduced in the + 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. + """ + + _loop = None + _pid = os.getpid() + + _lock = Lock() + _loop_thread = None + + _write_queue = None + _write_queue_lock = None + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self._connect_socket() + self._socket.setblocking(0) + + self._write_queue = asyncio.Queue() + self._write_queue_lock = asyncio.Lock() + + # see initialize_reactor -- loop is running in a separate thread, so we + # have to use a threadsafe call + self._read_watcher = asyncio.run_coroutine_threadsafe( + self.handle_read(), loop=self._loop + ) + self._write_watcher = asyncio.run_coroutine_threadsafe( + self.handle_write(), loop=self._loop + ) + self._send_options_message() + + @classmethod + def initialize_reactor(cls): + with cls._lock: + if cls._pid != os.getpid(): + cls._loop = None + if cls._loop is None: + cls._loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._loop) + + if not cls._loop_thread: + # daemonize so the loop will be shut down on interpreter + # shutdown + cls._loop_thread = Thread(target=cls._loop.run_forever, + daemon=True, name="asyncio_thread") + cls._loop_thread.start() + + @classmethod + def create_timer(cls, timeout, callback): + return AsyncioTimer(timeout, callback, loop=cls._loop) + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + # close from the loop thread to avoid races when removing file + # descriptors + asyncio.run_coroutine_threadsafe( + self._close(), loop=self._loop + ) + + async def _close(self): + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) + if self._write_watcher: + self._write_watcher.cancel() + if self._read_watcher: + self._read_watcher.cancel() + if self._socket: + self._loop.remove_writer(self._socket.fileno()) + self._loop.remove_reader(self._socket.fileno()) + self._socket.close() + + log.debug("Closed socket to %s" % (self.endpoint,)) + + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def push(self, data): + buff_size = self.out_buffer_size + if len(data) > buff_size: + chunks = [] + for i in range(0, len(data), buff_size): + chunks.append(data[i:i + buff_size]) + else: + chunks = [data] + + if self._loop_thread.ident != get_ident(): + asyncio.run_coroutine_threadsafe( + self._push_msg(chunks), + loop=self._loop + ) + else: + # avoid races/hangs by just scheduling this, not using threadsafe + self._loop.create_task(self._push_msg(chunks)) + + async def _push_msg(self, chunks): + # This lock ensures all chunks of a message are sequential in the Queue + async with self._write_queue_lock: + for chunk in chunks: + self._write_queue.put_nowait(chunk) + + + async def handle_write(self): + while True: + try: + next_msg = await self._write_queue.get() + if next_msg: + await self._loop.sock_sendall(self._socket, next_msg) + except socket.error as err: + log.debug("Exception in send for %s: %s", self, err) + self.defunct(err) + return + except asyncio.CancelledError: + return + + async def handle_read(self): + while True: + try: + buf = await self._loop.sock_recv(self._socket, self.in_buffer_size) + self._iobuf.write(buf) + # sock_recv expects EWOULDBLOCK if socket provides no data, but + # nonblocking ssl sockets raise these instead, so we handle them + # ourselves by yielding to the event loop, where the socket will + # get the reading/writing it "wants" before retrying + except (ssl.SSLWantWriteError, ssl.SSLWantReadError): + # Apparently the preferred way to yield to the event loop from within + # a native coroutine based on https://github.com/python/asyncio/issues/284 + await asyncio.sleep(0) + continue + except socket.error as err: + log.debug("Exception during socket recv for %s: %s", + self, err) + self.defunct(err) + return # leave the read loop + except asyncio.CancelledError: + return + + if buf and self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 60fa2484b0..e1bcafb39e 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -1,139 +1,374 @@ -from collections import defaultdict, deque +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import atexit +from collections import deque from functools import partial import logging import os import socket import sys -from threading import Event, Lock, Thread -import traceback -from Queue import Queue -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode +from threading import Lock, Thread, Event +import time +import weakref +import sys +import ssl + +try: + from weakref import WeakSet +except ImportError: + from cassandra.util import WeakSet # noqa + +from cassandra import DependencyException +try: + import asyncore +except ModuleNotFoundError: + raise DependencyException( + "Unable to import asyncore module. Note that this module has been removed in Python 3.12 " + "so when using the driver with this version (or anything newer) you will need to use one of the " + "other event loop implementations." + ) -import asyncore +from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager -from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown, - ConnectionBusy, ConnectionException, NONBLOCKING) -from cassandra.decoder import RegisterMessage -from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) -_loop_started = False -_loop_lock = Lock() +_dispatcher_map = {} -_starting_conns = set() -_starting_conns_lock = Lock() +def _cleanup(loop): + if loop: + loop._cleanup() -def _run_loop(): - global _loop_started - log.debug("Starting asyncore event loop") - with _loop_lock: - while True: - try: - asyncore.loop(timeout=0.001, use_poll=True, count=None) - except Exception: - log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) - break - with _starting_conns_lock: - if not _starting_conns: +class WaitableTimer(Timer): + def __init__(self, timeout, callback): + Timer.__init__(self, timeout, callback) + self.callback = callback + self.event = Event() + + self.final_exception = None + + def finish(self, time_now): + try: + finished = Timer.finish(self, time_now) + if finished: + self.event.set() + return True + return False + + except Exception as e: + self.final_exception = e + self.event.set() + return True + + def wait(self, timeout=None): + self.event.wait(timeout) + if self.final_exception: + raise self.final_exception + + +class _PipeWrapper(object): + + def __init__(self, fd): + self.fd = fd + + def fileno(self): + return self.fd + + def close(self): + os.close(self.fd) + + def getsockopt(self, level, optname, buflen=None): + # act like an unerrored socket for the asyncore error handling + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError() + + +class _AsyncoreDispatcher(asyncore.dispatcher): + + def __init__(self, socket): + asyncore.dispatcher.__init__(self, map=_dispatcher_map) + # inject after to avoid base class validation + self.set_socket(socket) + self._notified = False + + def writable(self): + return False + + def validate(self): + assert not self._notified + self.notify_loop() + assert self._notified + self.loop(0.1) + assert not self._notified + + def loop(self, timeout): + asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1) + + +class _AsyncorePipeDispatcher(_AsyncoreDispatcher): + + def __init__(self): + self.read_fd, self.write_fd = os.pipe() + _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) + + def writable(self): + return False + + def handle_read(self): + while len(os.read(self.read_fd, 4096)) == 4096: + pass + self._notified = False + + def notify_loop(self): + if not self._notified: + self._notified = True + os.write(self.write_fd, b'x') + + +class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): + """ + Experimental alternate dispatcher for avoiding busy wait in the asyncore loop. It is not used by default because + it relies on local port binding. + Port scanning is not implemented, so multiple clients on one host will collide. This address would need to be set per + instance, or this could be specialized to scan until an address is found. + + To use:: + + from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop + AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher + + """ + bind_address = ('localhost', 10000) + + def __init__(self): + self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._socket.bind(self.bind_address) + self._socket.setblocking(0) + _AsyncoreDispatcher.__init__(self, self._socket) + + def handle_read(self): + try: + d = self._socket.recvfrom(1) + while d and d[1]: + d = self._socket.recvfrom(1) + except socket.error as e: + pass + self._notified = False + + def notify_loop(self): + if not self._notified: + self._notified = True + self._socket.sendto(b'', self.bind_address) + + def loop(self, timeout): + asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) + + +class _BusyWaitDispatcher(object): + + max_write_latency = 0.001 + """ + Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check + if anything is writable. + """ + + def notify_loop(self): + pass + + def loop(self, timeout): + if not _dispatcher_map: + time.sleep(0.005) + count = timeout // self.max_write_latency + asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) + + def validate(self): + pass + + def close(self): + pass + + +class AsyncoreLoop(object): + + timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts + + _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher + + def __init__(self): + self._pid = os.getpid() + self._loop_lock = Lock() + self._started = False + self._shutdown = False + + self._thread = None + + self._timers = TimerManager() + + try: + dispatcher = self._loop_dispatch_class() + dispatcher.validate() + log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) + except Exception: + log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) + dispatcher.close() + dispatcher = _BusyWaitDispatcher() + self._loop_dispatcher = dispatcher + + def maybe_start(self): + should_start = False + did_acquire = False + try: + did_acquire = self._loop_lock.acquire(False) + if did_acquire and not self._started: + self._started = True + should_start = True + finally: + if did_acquire: + self._loop_lock.release() + + if should_start: + self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop") + self._thread.daemon = True + self._thread.start() + + def wake_loop(self): + self._loop_dispatcher.notify_loop() + + def _run_loop(self): + log.debug("Starting asyncore event loop") + with self._loop_lock: + while not self._shutdown: + try: + self._loop_dispatcher.loop(self.timer_resolution) + self._timers.service_timeouts() + except Exception as exc: + self._maybe_log_debug("Asyncore event loop stopped unexpectedly", exc_info=exc) break + self._started = False + + self._maybe_log_debug("Asyncore event loop ended") + + def _maybe_log_debug(self, *args, **kwargs): + try: + log.debug(*args, **kwargs) + except Exception: + # TODO: Remove when Python 2 support is removed + # PYTHON-1266. If our logger has disappeared, there's nothing we + # can do, so just log nothing. + pass + + def add_timer(self, timer): + self._timers.add_timer(timer) + + # This function is called from a different thread than the event loop + # thread, so for this call to be thread safe, we must wake up the loop + # in case it's stuck at a select + self.wake_loop() - _loop_started = False - if log: - # this can happen during interpreter shutdown - log.debug("Asyncore event loop ended") + def _cleanup(self): + global _dispatcher_map + self._shutdown = True + if not self._thread: + return + + log.debug("Waiting for event loop thread to join...") + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + log.warning( + "Event loop thread could not be joined, so shutdown may not be clean. " + "Please call Cluster.shutdown() to avoid this.") + + log.debug("Event loop thread was joined") + + # Ensure all connections are closed and in-flight requests cancelled + for conn in tuple(_dispatcher_map.values()): + if conn is not self._loop_dispatcher: + conn.close() + self._timers.service_timeouts() + # Once all the connections are closed, close the dispatcher + self._loop_dispatcher.close() -def _start_loop(): - global _loop_started - should_start = False - did_acquire = False - try: - did_acquire = _loop_lock.acquire(False) - if did_acquire and not _loop_started: - _loop_started = True - should_start = True - finally: - if did_acquire: - _loop_lock.release() + log.debug("Dispatchers were closed") - if should_start: - t = Thread(target=_run_loop, name="event_loop") - t.daemon = True - t.start() + +_global_loop = None +atexit.register(partial(_cleanup, _global_loop)) class AsyncoreConnection(Connection, asyncore.dispatcher): """ - An implementation of :class:`.Connection` that utilizes the ``asyncore`` + An implementation of :class:`.Connection` that uses the ``asyncore`` module in the Python standard library for its event loop. """ - _buf = "" - _total_reqd_bytes = 0 _writable = False _readable = False - _have_listeners = False @classmethod - def factory(cls, *args, **kwargs): - conn = cls(*args, **kwargs) - conn.connected_event.wait() - if conn.last_error: - raise conn.last_error + def initialize_reactor(cls): + global _global_loop + if not _global_loop: + _global_loop = AsyncoreLoop() else: - return conn + current_pid = os.getpid() + if _global_loop._pid != current_pid: + log.debug("Detected fork, clearing and reinitializing reactor state") + cls.handle_fork() + _global_loop = AsyncoreLoop() + + @classmethod + def handle_fork(cls): + global _dispatcher_map, _global_loop + _dispatcher_map = {} + if _global_loop: + _global_loop._cleanup() + _global_loop = None + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + _global_loop.add_timer(timer) + return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - asyncore.dispatcher.__init__(self) - - self.connected_event = Event() - self._callbacks = {} - self._push_watchers = defaultdict(set) self.deque = deque() + self.deque_lock = Lock() - with _starting_conns_lock: - _starting_conns.add(self) + self._connect_socket() - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.connect((self.host, self.port)) + # start the event loop if needed + _global_loop.maybe_start() - if self.sockopts: - for args in self.sockopts: - self.socket.setsockopt(*args) + init_handler = WaitableTimer( + timeout=0, + callback=partial(asyncore.dispatcher.__init__, + self, self._socket, _dispatcher_map) + ) + _global_loop.add_timer(init_handler) + init_handler.wait(kwargs["connect_timeout"]) self._writable = True self._readable = True - # start the global event loop if needed - _start_loop() - - def create_socket(self, family, type): - # copied from asyncore, but with the line to set the socket in - # non-blocking mode removed (we will do that after connecting) - self.family_and_type = family, type - sock = socket.socket(family, type) - self.set_socket(sock) - - def connect(self, address): - # this is copied directly from asyncore.py, except that - # a timeout is set before connecting - self.connected = False - self.connecting = True - self.socket.settimeout(1.0) - err = self.socket.connect_ex(address) - if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ - or err == EINVAL and os.name in ('nt', 'ce'): - raise ConnectionException("Timed out connecting to %s" % (address[0])) - if err in (0, EISCONN): - self.addr = address - self.setblocking(0) - self.handle_connect_event() - else: - raise socket.error(err, errorcode[err]) + self._send_options_message() def close(self): with self.lock: @@ -141,179 +376,103 @@ def close(self): return self.is_closed = True - log.debug("Closing connection to %s" % (self.host,)) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) self._writable = False self._readable = False - asyncore.dispatcher.close(self) - log.debug("Closed socket to %s" % (self.host,)) - - with _starting_conns_lock: - _starting_conns.discard(self) - - # don't leave in-progress operations hanging - self.connected_event.set() - if not self.is_defunct: - self._error_all_callbacks( - ConnectionShutdown("Connection to %s was closed" % self.host)) - - def __del__(self): - try: - self.close() - except TypeError: - pass - def defunct(self, exc): - with self.lock: - if self.is_defunct: - return - self.is_defunct = True + # We don't have to wait for this to be closed, we can just schedule it + self.create_timer(0, partial(asyncore.dispatcher.close, self)) - trace = traceback.format_exc(exc) - if trace != "None": - log.debug("Defuncting connection to %s: %s\n%s", - self.host, exc, traceback.format_exc(exc)) - else: - log.debug("Defuncting connection to %s: %s", self.host, exc) + log.debug("Closed socket to %s", self.endpoint) - self.last_error = exc - self._error_all_callbacks(exc) - self.connected_event.set() - return exc + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) - def _error_all_callbacks(self, exc): - new_exc = ConnectionShutdown(str(exc)) - for cb in self._callbacks.values(): - cb(new_exc) + #This happens when the connection is shutdown while waiting for the ReadyMessage + if not self.connected_event.is_set(): + self.last_error = ConnectionShutdown("Connection to %s was closed" % self.endpoint) - def handle_connect(self): - with _starting_conns_lock: - _starting_conns.discard(self) - self._send_options_message() + # don't leave in-progress operations hanging + self.connected_event.set() def handle_error(self): self.defunct(sys.exc_info()[1]) def handle_close(self): - log.debug("connection closed by server") + log.debug("Connection %s closed by server", self) self.close() def handle_write(self): - try: - next_msg = self.deque.popleft() - except IndexError: - self._writable = False - return + while True: + with self.deque_lock: + try: + next_msg = self.deque.popleft() + except IndexError: + self._writable = False + return - try: - sent = self.send(next_msg) - except socket.error as err: - if (err.args[0] in NONBLOCKING): - self.deque.appendleft(next_msg) + try: + sent = self.send(next_msg) + self._readable = True + except socket.error as err: + if (err.args[0] in NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): + with self.deque_lock: + self.deque.appendleft(next_msg) + else: + self.defunct(err) + return else: - self.defunct(err) - return - else: - if sent < len(next_msg): - self.deque.appendleft(next_msg[sent:]) - - if not self.deque: - self._writable = False - - self._readable = True + if sent < len(next_msg): + with self.deque_lock: + self.deque.appendleft(next_msg[sent:]) + if sent == 0: + return def handle_read(self): try: - buf = self.recv(self.in_buffer_size) - except socket.error as err: - if err.args[0] not in NONBLOCKING: - self.defunct(err) - return - - if buf: - self._buf += buf while True: - if len(self._buf) < 8: - # we don't have a complete header yet - break - elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes: - # we already saw a header, but we don't have a complete message yet + buf = self.recv(self.in_buffer_size) + self._iobuf.write(buf) + if len(buf) < self.in_buffer_size: break + except socket.error as err: + if isinstance(err, ssl.SSLError): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if not self._iobuf.tell(): + return else: - body_len = int32_unpack(self._buf[4:8]) - if len(self._buf) - 8 >= body_len: - msg = self._buf[:8 + body_len] - self._buf = self._buf[8 + body_len:] - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + 8 - - if not self._callbacks: - self._readable = False - else: - self.close() + self.defunct(err) + return + elif err.args[0] in NONBLOCKING: + if not self._iobuf.tell(): + return + else: + self.defunct(err) + return - def handle_pushed(self, response): - log.debug("Message pushed from server: %r", response) - for cb in self._push_watchers.get(response.event_type, []): - try: - cb(response.event_args) - except Exception: - log.exception("Pushed event handler errored, ignoring:") + if self._iobuf.tell(): + self.process_io_buffer() + if not self._requests and not self.is_control_connection: + self._readable = False def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] - for i in xrange(0, len(data), sabs): + for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] - with self.lock: + with self.deque_lock: self.deque.extend(chunks) - - self._writable = True + self._writable = True + _global_loop.wake_loop() def writable(self): return self._writable def readable(self): - return self._readable or (self._have_listeners and not (self.is_defunct or self.is_closed)) - - def send_msg(self, msg, cb): - if self.is_defunct: - raise ConnectionShutdown("Connection to %s is defunct" % self.host) - elif self.is_closed: - raise ConnectionShutdown("Connection to %s is closed" % self.host) - - try: - request_id = self._id_queue.get_nowait() - except Queue.EMPTY: - raise ConnectionBusy( - "Connection to %s is at the max number of requests" % self.host) - - self._callbacks[request_id] = cb - self.push(msg.to_string(request_id, compression=self.compressor)) - return request_id - - def wait_for_response(self, msg): - return self.wait_for_responses(msg)[0] - - def wait_for_responses(self, *msgs): - waiter = ResponseWaiter(len(msgs)) - for i, msg in enumerate(msgs): - self.send_msg(msg, partial(waiter.got_response, index=i)) - - return waiter.deliver() - - def register_watcher(self, event_type, callback): - self._push_watchers[event_type].add(callback) - self._have_listeners = True - self.wait_for_response(RegisterMessage(event_list=[event_type])) - - def register_watchers(self, type_callback_dict): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self._have_listeners = True - self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys())) + return self._readable or ((self.is_control_connection or self._continuous_paging_sessions) and not (self.is_defunct or self.is_closed)) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py new file mode 100644 index 0000000000..6be7738236 --- /dev/null +++ b/cassandra/io/eventletreactor.py @@ -0,0 +1,195 @@ +# Copyright 2014 Symantec Corporation +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Originally derived from MagnetoDB source: +# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py +import eventlet +from eventlet.green import socket +from eventlet.queue import Queue +from greenlet import GreenletExit +import logging +from threading import Event +import time +from deprecated import deprecated + +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager +try: + from eventlet.green.OpenSSL import SSL + _PYOPENSSL = True +except ImportError as e: + _PYOPENSSL = False + no_pyopenssl_error = e + + +log = logging.getLogger(__name__) + + +def _check_pyopenssl(): + if not _PYOPENSSL: + raise ImportError( + "{}, pyOpenSSL must be installed to enable " + "SSL support with the Eventlet event loop".format(str(no_pyopenssl_error)) + ) + + +@deprecated(version="3.30.0", reason="The eventlet event loop is deprecated and will be removed in 3.31.0. See CASSPYTHON-12.") +class EventletConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes ``eventlet``. + + This implementation assumes all eventlet monkey patching is active. It is not tested with partial patching. + """ + + _read_watcher = None + _write_watcher = None + + _socket_impl = eventlet.green.socket + _ssl_impl = eventlet.green.ssl + + _timers = None + _timeout_watcher = None + _new_timer = None + + @classmethod + def initialize_reactor(cls): + eventlet.monkey_patch() + if not cls._timers: + cls._timers = TimerManager() + cls._timeout_watcher = eventlet.spawn(cls.service_timeouts) + cls._new_timer = Event() + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._timers.add_timer(timer) + cls._new_timer.set() + return timer + + @classmethod + def service_timeouts(cls): + """ + cls._timeout_watcher runs in this loop forever. + It is usually waiting for the next timeout on the cls._new_timer Event. + When new timers are added, that event is set so that the watcher can + wake up and possibly set an earlier timeout. + """ + timer_manager = cls._timers + while True: + next_end = timer_manager.service_timeouts() + sleep_time = max(next_end - time.time(), 0) if next_end else 10000 + cls._new_timer.wait(sleep_time) + cls._new_timer.clear() + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + self.uses_legacy_ssl_options = self.ssl_options and not self.ssl_context + self._write_queue = Queue() + + self._connect_socket() + + self._read_watcher = eventlet.spawn(lambda: self.handle_read()) + self._write_watcher = eventlet.spawn(lambda: self.handle_write()) + self._send_options_message() + + def _wrap_socket_from_context(self): + _check_pyopenssl() + rv = SSL.Connection(self.ssl_context, self._socket) + rv.set_connect_state() + if self.ssl_options and 'server_hostname' in self.ssl_options: + # This is necessary for SNI + rv.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + return rv + + def _initiate_connection(self, sockaddr): + if self.uses_legacy_ssl_options: + super(EventletConnection, self)._initiate_connection(sockaddr) + else: + self._socket.connect(sockaddr) + if self.ssl_context or self.ssl_options: + self._socket.do_handshake() + + def _validate_hostname(self): + if not self.uses_legacy_ssl_options: + cert_name = self._socket.get_peer_certificate().get_subject().commonName + if cert_name != self.endpoint.address: + raise Exception("Hostname verification failed! Certificate name '{}' " + "doesn't match endpoint '{}'".format(cert_name, self.endpoint.address)) + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) + + cur_gthread = eventlet.getcurrent() + + if self._read_watcher and self._read_watcher != cur_gthread: + self._read_watcher.kill() + if self._write_watcher and self._write_watcher != cur_gthread: + self._write_watcher.kill() + if self._socket: + self._socket.close() + log.debug("Closed socket to %s" % (self.endpoint,)) + + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_close(self): + log.debug("connection closed by server") + self.close() + + def handle_write(self): + while True: + try: + next_msg = self._write_queue.get() + self._socket.sendall(next_msg) + except socket.error as err: + log.debug("Exception during socket send for %s: %s", self, err) + self.defunct(err) + return # Leave the write loop + except GreenletExit: # graceful greenthread exit + return + + def handle_read(self): + while True: + try: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + except socket.error as err: + log.debug("Exception during socket recv for %s: %s", + self, err) + self.defunct(err) + return # leave the read loop + except GreenletExit: # graceful greenthread exit + return + + if buf and self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return + + def push(self, data): + chunk_size = self.out_buffer_size + for i in range(0, len(data), chunk_size): + self._write_queue.put(data[i:i + chunk_size]) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py new file mode 100644 index 0000000000..eb1296d6f9 --- /dev/null +++ b/cassandra/io/geventreactor.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gevent +import gevent.event +from gevent.queue import Queue +from gevent import socket +import gevent.ssl + +from deprecated import deprecated +import logging +import time + +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager + +log = logging.getLogger(__name__) + +@deprecated(version="3.30.0", reason="The gevent event loop is deprecated and will be removed in 3.31.0. See CASSPYTHON-12.") +class GeventConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes ``gevent``. + + This implementation assumes all gevent monkey patching is active. It is not tested with partial patching. + """ + + _read_watcher = None + _write_watcher = None + + _socket_impl = gevent.socket + _ssl_impl = gevent.ssl + + _timers = None + _timeout_watcher = None + _new_timer = None + + @classmethod + def initialize_reactor(cls): + if not cls._timers: + cls._timers = TimerManager() + cls._timeout_watcher = gevent.spawn(cls.service_timeouts) + cls._new_timer = gevent.event.Event() + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._timers.add_timer(timer) + cls._new_timer.set() + return timer + + @classmethod + def service_timeouts(cls): + timer_manager = cls._timers + timer_event = cls._new_timer + while True: + next_end = timer_manager.service_timeouts() + sleep_time = max(next_end - time.time(), 0) if next_end else 10000 + timer_event.wait(sleep_time) + timer_event.clear() + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self._write_queue = Queue() + + self._connect_socket() + + self._read_watcher = gevent.spawn(self.handle_read) + self._write_watcher = gevent.spawn(self.handle_write) + self._send_options_message() + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) + if self._read_watcher: + self._read_watcher.kill(block=False) + if self._write_watcher: + self._write_watcher.kill(block=False) + if self._socket: + self._socket.close() + log.debug("Closed socket to %s" % (self.endpoint,)) + + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_close(self): + log.debug("connection closed by server") + self.close() + + def handle_write(self): + while True: + try: + next_msg = self._write_queue.get() + self._socket.sendall(next_msg) + except socket.error as err: + log.debug("Exception in send for %s: %s", self, err) + self.defunct(err) + return + + def handle_read(self): + while True: + try: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + except socket.error as err: + log.debug("Exception in read for %s: %s", self, err) + self.defunct(err) + return # leave the read loop + + if buf and self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return + + def push(self, data): + chunk_size = self.out_buffer_size + for i in range(0, len(data), chunk_size): + self._write_queue.put(data[i:i + chunk_size]) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index a5325f7459..76a53b9bdd 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -1,130 +1,284 @@ -from collections import defaultdict, deque -from functools import partial, wraps +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import atexit +from collections import deque +from functools import partial import logging import os import socket -from threading import Event, Lock, Thread -import traceback -from Queue import Queue - -from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown, - ConnectionBusy, NONBLOCKING) -from cassandra.decoder import RegisterMessage -from cassandra.marshal import int32_unpack -import cassandra.io.libevwrapper as libev +import ssl +from threading import Lock, Thread +import time +from cassandra import DependencyException try: - from cStringIO import StringIO + import cassandra.io.libevwrapper as libev except ImportError: - from StringIO import StringIO # ignore flake8 warning: # NOQA + raise DependencyException( + "The C extension needed to use libev was not found. This " + "probably means that you didn't have the required build dependencies " + "when installing the driver. See " + "https://docs.datastax.com/en/developer/python-driver/latest/installation/index.html#c-extensions " + "for instructions on installing build dependencies and building " + "the C extension.") + +from cassandra.connection import (Connection, ConnectionShutdown, + NONBLOCKING, Timer, TimerManager) + log = logging.getLogger(__name__) -_loop = libev.Loop() -_loop_notifier = libev.Async(_loop) -_loop_notifier.start() -# prevent _loop_notifier from keeping the loop from returning -_loop.unref() +def _cleanup(loop): + if loop: + loop._cleanup() + + +class LibevLoop(object): + + def __init__(self): + self._pid = os.getpid() + self._loop = libev.Loop() + self._notifier = libev.Async(self._loop) + self._notifier.start() + + # prevent _notifier from keeping the loop from returning + self._loop.unref() + + self._started = False + self._shutdown = False + self._lock = Lock() + self._lock_thread = Lock() + + self._thread = None + + # set of all connections; only replaced with a new copy + # while holding _conn_set_lock, never modified in place + self._live_conns = set() + # newly created connections that need their write/read watcher started + self._new_conns = set() + # recently closed connections that need their write/read watcher stopped + self._closed_conns = set() + self._conn_set_lock = Lock() + + self._preparer = libev.Prepare(self._loop, self._loop_will_run) + # prevent _preparer from keeping the loop from returning + self._loop.unref() + self._preparer.start() + + self._timers = TimerManager() + self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) + + def maybe_start(self): + should_start = False + with self._lock: + if not self._started: + log.debug("Starting libev event loop") + self._started = True + should_start = True + + if should_start: + with self._lock_thread: + if not self._shutdown: + self._thread = Thread(target=self._run_loop, name="event_loop") + self._thread.daemon = True + self._thread.start() + + self._notifier.send() + + def _run_loop(self): + while True: + self._loop.start() + # there are still active watchers, no deadlock + with self._lock: + if not self._shutdown and self._live_conns: + log.debug("Restarting event loop") + continue + else: + # all Connections have been closed, no active watchers + log.debug("All Connections currently closed, event loop ended") + self._started = False + break -_loop_started = None -_loop_lock = Lock() + def _cleanup(self): + self._shutdown = True + if not self._thread: + return -def _run_loop(): - while True: - end_condition = _loop.start() - # there are still active watchers, no deadlock - with _loop_lock: - if end_condition: - log.debug("Restarting event loop") - continue - else: - # all Connections have been closed, no active watchers - log.debug("All Connections currently closed, event loop ended") - global _loop_started - _loop_started = False - break + for conn in self._live_conns | self._new_conns | self._closed_conns: + conn.close() + for watcher in (conn._write_watcher, conn._read_watcher): + if watcher: + watcher.stop() -def _start_loop(): - global _loop_started - should_start = False - with _loop_lock: - if not _loop_started: - log.debug("Starting libev event loop") - _loop_started = True - should_start = True + self.notify() # wake the timer watcher - if should_start: - t = Thread(target=_run_loop, name="event_loop") - t.daemon = True - t.start() + # PYTHON-752 Thread might have just been created and not started + with self._lock_thread: + self._thread.join(timeout=1.0) - return should_start + if self._thread.is_alive(): + log.warning( + "Event loop thread could not be joined, so shutdown may not be clean. " + "Please call Cluster.shutdown() to avoid this.") + log.debug("Event loop thread was joined") -def defunct_on_error(f): + def add_timer(self, timer): + self._timers.add_timer(timer) + self._notifier.send() # wake up in case this timer is earlier - @wraps(f) - def wrapper(self, *args, **kwargs): - try: - return f(self, *args, **kwargs) - except Exception as exc: - self.defunct(exc) + def _update_timer(self): + if not self._shutdown: + next_end = self._timers.service_timeouts() + if next_end: + self._loop_timer.start(next_end - time.time()) # timer handles negative values + else: + self._loop_timer.stop() + + def _on_loop_timer(self): + self._timers.service_timeouts() + + def notify(self): + self._notifier.send() + + def connection_created(self, conn): + with self._conn_set_lock: + new_live_conns = self._live_conns.copy() + new_live_conns.add(conn) + self._live_conns = new_live_conns + + new_new_conns = self._new_conns.copy() + new_new_conns.add(conn) + self._new_conns = new_new_conns + + def connection_destroyed(self, conn): + with self._conn_set_lock: + new_live_conns = self._live_conns.copy() + new_live_conns.discard(conn) + self._live_conns = new_live_conns + + new_closed_conns = self._closed_conns.copy() + new_closed_conns.add(conn) + self._closed_conns = new_closed_conns + + self._notifier.send() + + def _loop_will_run(self, prepare): + changed = False + for conn in self._live_conns: + if not conn.deque and conn._write_watcher_is_active: + if conn._write_watcher: + conn._write_watcher.stop() + conn._write_watcher_is_active = False + changed = True + elif conn.deque and not conn._write_watcher_is_active: + conn._write_watcher.start() + conn._write_watcher_is_active = True + changed = True + + if self._new_conns: + with self._conn_set_lock: + to_start = self._new_conns + self._new_conns = set() + + for conn in to_start: + conn._read_watcher.start() + + changed = True - return wrapper + if self._closed_conns: + with self._conn_set_lock: + to_stop = self._closed_conns + self._closed_conns = set() + + for conn in to_stop: + if conn._write_watcher: + conn._write_watcher.stop() + # clear reference cycles from IO callback + del conn._write_watcher + if conn._read_watcher: + conn._read_watcher.stop() + # clear reference cycles from IO callback + del conn._read_watcher + + changed = True + + # TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks + self._update_timer() + + if changed: + self._notifier.send() + + +_global_loop = None +atexit.register(lambda: _cleanup(_global_loop)) class LibevConnection(Connection): """ - An implementation of :class:`.Connection` that utilizes libev. + An implementation of :class:`.Connection` that uses libev for its event loop. """ - - _total_reqd_bytes = 0 + _write_watcher_is_active = False _read_watcher = None _write_watcher = None _socket = None @classmethod - def factory(cls, *args, **kwargs): - conn = cls(*args, **kwargs) - conn.connected_event.wait() - if conn.last_error: - raise conn.last_error + def initialize_reactor(cls): + global _global_loop + if not _global_loop: + _global_loop = LibevLoop() else: - return conn + if _global_loop._pid != os.getpid(): + log.debug("Detected fork, clearing and reinitializing reactor state") + cls.handle_fork() + _global_loop = LibevLoop() + + @classmethod + def handle_fork(cls): + global _global_loop + if _global_loop: + _global_loop._cleanup() + _global_loop = None + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + _global_loop.add_timer(timer) + return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - self.connected_event = Event() - self._iobuf = StringIO() - - self._callbacks = {} - self._push_watchers = defaultdict(set) self.deque = deque() - - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.settimeout(1.0) # TODO potentially make this value configurable - self._socket.connect((self.host, self.port)) + self._deque_lock = Lock() + self._connect_socket() self._socket.setblocking(0) - if self.sockopts: - for args in self.sockopts: - self._socket.setsockopt(*args) - - self._read_watcher = libev.IO(self._socket._sock, libev.EV_READ, _loop, self.handle_read) - self._write_watcher = libev.IO(self._socket._sock, libev.EV_WRITE, _loop, self.handle_write) - with _loop_lock: - self._read_watcher.start() - self._write_watcher.start() + with _global_loop._lock: + self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, _global_loop._loop, self.handle_read) + self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, _global_loop._loop, self.handle_write) self._send_options_message() + _global_loop.connection_created(self) + # start the global event loop if needed - if not _start_loop(): - # if the loop was already started, notify it - with _loop_lock: - _loop_notifier.send() + _global_loop.maybe_start() def close(self): with self.lock: @@ -132,170 +286,103 @@ def close(self): return self.is_closed = True - log.debug("Closing connection to %s" % (self.host,)) - if self._read_watcher: - self._read_watcher.stop() - if self._write_watcher: - self._write_watcher.stop() + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) + + _global_loop.connection_destroyed(self) self._socket.close() - with _loop_lock: - _loop_notifier.send() + log.debug("Closed socket to %s", self.endpoint) # don't leave in-progress operations hanging if not self.is_defunct: - self._error_all_callbacks( - ConnectionShutdown("Connection to %s was closed" % self.host)) - - def __del__(self): - self.close() + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) - def defunct(self, exc): - with self.lock: - if self.is_defunct: - return - self.is_defunct = True + def handle_write(self, watcher, revents, errno=None): + if revents & libev.EV_ERROR: + if errno: + exc = IOError(errno, os.strerror(errno)) + else: + exc = Exception("libev reported an error") - trace = traceback.format_exc(exc) - if trace != "None": - log.debug("Defuncting connection to %s: %s\n%s", - self.host, exc, traceback.format_exc(exc)) - else: - log.debug("Defuncting connection to %s: %s", self.host, exc) + self.defunct(exc) + return - self.last_error = exc - self._error_all_callbacks(exc) - self.connected_event.set() - return exc + while True: + try: + with self._deque_lock: + next_msg = self.deque.popleft() + except IndexError: + if not self._socket_writable: + self._socket_writable = True + return - def _error_all_callbacks(self, exc): - new_exc = ConnectionShutdown(str(exc)) - for cb in self._callbacks.values(): - cb(new_exc) + try: + sent = self._socket.send(next_msg) + except socket.error as err: + if (err.args[0] in NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): + if err.args[0] in NONBLOCKING: + self._socket_writable = False + with self._deque_lock: + self.deque.appendleft(next_msg) + else: + self.defunct(err) + return + else: + if sent < len(next_msg): + with self._deque_lock: + self.deque.appendleft(next_msg[sent:]) + # we've seen some cases that 0 is returned instead of NONBLOCKING. But usually, + # we don't expect this to happen. https://bugs.python.org/issue20951 + if sent == 0: + self._socket_writable = False + return + + def handle_read(self, watcher, revents, errno=None): + if revents & libev.EV_ERROR: + if errno: + exc = IOError(errno, os.strerror(errno)) + else: + exc = Exception("libev reported an error") - def handle_write(self, watcher, revents): - try: - next_msg = self.deque.popleft() - except IndexError: - self._write_watcher.stop() + self.defunct(exc) return - try: - sent = self._socket.send(next_msg) + while True: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + if len(buf) < self.in_buffer_size: + break except socket.error as err: - if (err.args[0] in NONBLOCKING): - self.deque.appendleft(next_msg) + if isinstance(err, ssl.SSLError): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if not self._iobuf.tell(): + return + else: + self.defunct(err) + return + elif err.args[0] in NONBLOCKING: + if not self._iobuf.tell(): + return else: self.defunct(err) - return - else: - if sent < len(next_msg): - self.deque.appendleft(next_msg[sent:]) - - if not self.deque: - self._write_watcher.stop() - - def handle_read(self, watcher, revents): - try: - buf = self._socket.recv(self.in_buffer_size) - except socket.error as err: - if err.args[0] not in NONBLOCKING: - self.defunct(err) - return + return - if buf: - self._iobuf.write(buf) - while True: - pos = self._iobuf.tell() - if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): - # we don't have a complete header yet or we - # already saw a header, but we don't have a - # complete message yet - break - else: - # have enough for header, read body len from header - self._iobuf.seek(4) - body_len_bytes = self._iobuf.read(4) - body_len = int32_unpack(body_len_bytes) - - # seek to end to get length of current buffer - self._iobuf.seek(0, os.SEEK_END) - pos = self._iobuf.tell() - - if pos - 8 >= body_len: - # read message header and body - self._iobuf.seek(0) - msg = self._iobuf.read(8 + body_len) - - # leave leftover in current buffer - leftover = self._iobuf.read() - self._iobuf = StringIO() - self._iobuf.write(leftover) - - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + 8 - break + if self._iobuf.tell(): + self.process_io_buffer() else: - log.debug("connection closed by server") + log.debug("Connection %s closed by server", self) self.close() - def handle_pushed(self, response): - log.debug("Message pushed from server: %r", response) - for cb in self._push_watchers.get(response.event_type, []): - try: - cb(response.event_args) - except Exception: - log.exception("Pushed event handler errored, ignoring:") - def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] - for i in xrange(0, len(data), sabs): + for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] - with self.lock: + with self._deque_lock: self.deque.extend(chunks) - - if not self._write_watcher.is_active(): - with _loop_lock: - self._write_watcher.start() - _loop_notifier.send() - - def send_msg(self, msg, cb): - if self.is_defunct: - raise ConnectionShutdown("Connection to %s is defunct" % self.host) - elif self.is_closed: - raise ConnectionShutdown("Connection to %s is closed" % self.host) - - try: - request_id = self._id_queue.get_nowait() - except Queue.EMPTY: - raise ConnectionBusy( - "Connection to %s is at the max number of requests" % self.host) - - self._callbacks[request_id] = cb - self.push(msg.to_string(request_id, compression=self.compressor)) - return request_id - - def wait_for_response(self, msg): - return self.wait_for_responses(msg)[0] - - def wait_for_responses(self, *msgs): - waiter = ResponseWaiter(len(msgs)) - for i, msg in enumerate(msgs): - self.send_msg(msg, partial(waiter.got_response, index=i)) - - return waiter.deliver() - - def register_watcher(self, event_type, callback): - self._push_watchers[event_type].add(callback) - self.wait_for_response(RegisterMessage(event_list=[event_type])) - - def register_watchers(self, type_callback_dict): - for event_type, callback in type_callback_dict.items(): - self._push_watchers[event_type].add(callback) - self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys())) + _global_loop.notify() diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c index 334cfbbe8c..84d3d16bb2 100644 --- a/cassandra/io/libevwrapper.c +++ b/cassandra/io/libevwrapper.c @@ -8,7 +8,8 @@ typedef struct libevwrapper_Loop { static void Loop_dealloc(libevwrapper_Loop *self) { - self->ob_type->tp_free((PyObject *)self); + ev_loop_destroy(self->loop); + Py_TYPE(self)->tp_free((PyObject *)self); }; static PyObject* @@ -17,9 +18,10 @@ Loop_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { self = (libevwrapper_Loop *)type->tp_alloc(type, 0); if (self != NULL) { - self->loop = ev_default_loop(0); + self->loop = ev_loop_new(EVBACKEND_SELECT); if (!self->loop) { - PyErr_SetString(PyExc_Exception, "Error getting default ev loop"); + PyErr_SetString(PyExc_Exception, "Error getting new ev loop"); + Py_DECREF(self); return NULL; } } @@ -36,7 +38,7 @@ Loop_init(libevwrapper_Loop *self, PyObject *args, PyObject *kwds) { }; static PyObject * -Loop_start(libevwrapper_Loop *self) { +Loop_start(libevwrapper_Loop *self, PyObject *args) { Py_BEGIN_ALLOW_THREADS ev_run(self->loop, 0); Py_END_ALLOW_THREADS @@ -44,7 +46,7 @@ Loop_start(libevwrapper_Loop *self) { }; static PyObject * -Loop_unref(libevwrapper_Loop *self) { +Loop_unref(libevwrapper_Loop *self, PyObject *args) { ev_unref(self->loop); Py_RETURN_NONE; } @@ -55,9 +57,9 @@ static PyMethodDef Loop_methods[] = { {NULL} /* Sentinel */ }; -static PyTypeObject libevwrapper_LoopType = { - PyObject_HEAD_INIT(NULL) - 0, /*ob_size*/ +static +PyTypeObject libevwrapper_LoopType = { + PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Loop",/*tp_name*/ sizeof(libevwrapper_Loop), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -108,29 +110,26 @@ static void IO_dealloc(libevwrapper_IO *self) { Py_XDECREF(self->loop); Py_XDECREF(self->callback); - self->ob_type->tp_free((PyObject *)self); + Py_TYPE(self)->tp_free((PyObject *)self); }; static void io_callback(struct ev_loop *loop, ev_io *watcher, int revents) { - if (revents & EV_ERROR) { - if (errno) { - PyErr_SetFromErrno(PyExc_IOError); - } else { - PyErr_SetString(PyExc_IOError, "libev errored"); - } + if (!Py_IsInitialized()) { + return; } libevwrapper_IO *self = watcher->data; - - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); - - PyObject *result = PyObject_CallFunction(self->callback, "Ob", self, revents); + PyObject *result; + PyGILState_STATE gstate = PyGILState_Ensure(); + if (revents & EV_ERROR && errno) { + result = PyObject_CallFunction(self->callback, "Obi", self, revents, errno); + } else { + result = PyObject_CallFunction(self->callback, "Ob", self, revents); + } if (!result) { PyErr_WriteUnraisable(self->callback); } Py_XDECREF(result); - PyGILState_Release(gstate); }; @@ -138,10 +137,11 @@ static int IO_init(libevwrapper_IO *self, PyObject *args, PyObject *kwds) { PyObject *socket; PyObject *callback; - libevwrapper_Loop *loop; - int io_flags = 0; + PyObject *loop; + int io_flags = 0, fd = -1; + struct ev_io *io = NULL; - if (!PyArg_ParseTuple(args, "ObOO", &socket, &io_flags, &loop, &callback)) { + if (!PyArg_ParseTuple(args, "OiOO", &socket, &io_flags, &loop, &callback)) { return -1; } @@ -160,45 +160,53 @@ IO_init(libevwrapper_IO *self, PyObject *args, PyObject *kwds) { self->callback = callback; } - int fd = PyObject_AsFileDescriptor(socket); + fd = PyObject_AsFileDescriptor(socket); if (fd == -1) { PyErr_SetString(PyExc_TypeError, "unable to get file descriptor from socket"); Py_XDECREF(callback); Py_XDECREF(loop); return -1; } - ev_io_init(&self->io, io_callback, fd, io_flags); + io = &(self->io); + ev_io_init(io, io_callback, fd, io_flags); self->io.data = self; return 0; } static PyObject* -IO_start(libevwrapper_IO *self) { +IO_start(libevwrapper_IO *self, PyObject *args) { ev_io_start(self->loop->loop, &self->io); Py_RETURN_NONE; } static PyObject* -IO_stop(libevwrapper_IO *self) { +IO_stop(libevwrapper_IO *self, PyObject *args) { ev_io_stop(self->loop->loop, &self->io); Py_RETURN_NONE; } static PyObject* -IO_is_active(libevwrapper_IO *self) { - return PyBool_FromLong(ev_is_active(&self->io)); +IO_is_active(libevwrapper_IO *self, PyObject *args) { + struct ev_io *io = &(self->io); + return PyBool_FromLong(ev_is_active(io)); +} + +static PyObject* +IO_is_pending(libevwrapper_IO *self, PyObject *args) { + struct ev_io *io = &(self->io); + return PyBool_FromLong(ev_is_pending(io)); } static PyMethodDef IO_methods[] = { {"start", (PyCFunction)IO_start, METH_NOARGS, "Start the watcher"}, {"stop", (PyCFunction)IO_stop, METH_NOARGS, "Stop the watcher"}, {"is_active", (PyCFunction)IO_is_active, METH_NOARGS, "Is the watcher active?"}, + {"is_pending", (PyCFunction)IO_is_pending, METH_NOARGS, "Is the watcher pending?"}, {NULL} /* Sentinal */ }; static PyTypeObject libevwrapper_IOType = { - PyObject_HEAD_INIT(NULL) - 0, /*ob_size*/ + PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.IO", /*tp_name*/ sizeof(libevwrapper_IO), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -244,16 +252,18 @@ typedef struct libevwrapper_Async { static void Async_dealloc(libevwrapper_Async *self) { - self->ob_type->tp_free((PyObject *)self); + Py_XDECREF(self->loop); + Py_TYPE(self)->tp_free((PyObject *)self); }; static void async_callback(EV_P_ ev_async *watcher, int revents) {}; static int Async_init(libevwrapper_Async *self, PyObject *args, PyObject *kwds) { - libevwrapper_Loop *loop; - + PyObject *loop; static char *kwlist[] = {"loop", NULL}; + struct ev_async *async = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &loop)) { PyErr_SetString(PyExc_TypeError, "unable to get file descriptor from socket"); return -1; @@ -261,22 +271,23 @@ Async_init(libevwrapper_Async *self, PyObject *args, PyObject *kwds) { if (loop) { Py_INCREF(loop); - self->loop = loop; + self->loop = (libevwrapper_Loop *)loop; } else { return -1; } - ev_async_init(&self->async, async_callback); + async = &(self->async); + ev_async_init(async, async_callback); return 0; }; static PyObject * -Async_start(libevwrapper_Async *self) { +Async_start(libevwrapper_Async *self, PyObject *args) { ev_async_start(self->loop->loop, &self->async); Py_RETURN_NONE; } static PyObject * -Async_send(libevwrapper_Async *self) { +Async_send(libevwrapper_Async *self, PyObject *args) { ev_async_send(self->loop->loop, &self->async); Py_RETURN_NONE; }; @@ -288,8 +299,7 @@ static PyMethodDef Async_methods[] = { }; static PyTypeObject libevwrapper_AsyncType = { - PyObject_HEAD_INIT(NULL) - 0, /*ob_size*/ + PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Async", /*tp_name*/ sizeof(libevwrapper_Async), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -327,44 +337,342 @@ static PyTypeObject libevwrapper_AsyncType = { (initproc)Async_init, /* tp_init */ }; +typedef struct libevwrapper_Prepare { + PyObject_HEAD + struct ev_prepare prepare; + struct libevwrapper_Loop *loop; + PyObject *callback; +} libevwrapper_Prepare; + +static void +Prepare_dealloc(libevwrapper_Prepare *self) { + Py_XDECREF(self->loop); + Py_XDECREF(self->callback); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static void prepare_callback(struct ev_loop *loop, ev_prepare *watcher, int revents) { + if (!Py_IsInitialized()) { + return; + } + + libevwrapper_Prepare *self = watcher->data; + PyObject *result = NULL; + PyGILState_STATE gstate; + + gstate = PyGILState_Ensure(); + result = PyObject_CallFunction(self->callback, "O", self); + if (!result) { + PyErr_WriteUnraisable(self->callback); + } + Py_XDECREF(result); + + PyGILState_Release(gstate); +} + +static int +Prepare_init(libevwrapper_Prepare *self, PyObject *args, PyObject *kwds) { + PyObject *callback; + PyObject *loop; + struct ev_prepare *prepare = NULL; + + if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } else { + return -1; + } + + if (callback) { + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); + Py_XDECREF(loop); + return -1; + } + Py_INCREF(callback); + self->callback = callback; + } + prepare = &(self->prepare); + ev_prepare_init(prepare, prepare_callback); + self->prepare.data = self; + return 0; +} + +static PyObject * +Prepare_start(libevwrapper_Prepare *self, PyObject *args) { + ev_prepare_start(self->loop->loop, &self->prepare); + Py_RETURN_NONE; +} + +static PyObject * +Prepare_stop(libevwrapper_Prepare *self, PyObject *args) { + ev_prepare_stop(self->loop->loop, &self->prepare); + Py_RETURN_NONE; +} + +static PyMethodDef Prepare_methods[] = { + {"start", (PyCFunction)Prepare_start, METH_NOARGS, "Start the Prepare watcher"}, + {"stop", (PyCFunction)Prepare_stop, METH_NOARGS, "Stop the Prepare watcher"}, + {NULL} /* Sentinal */ +}; + +static PyTypeObject libevwrapper_PrepareType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Prepare", /*tp_name*/ + sizeof(libevwrapper_Prepare), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Prepare_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Prepare objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Prepare_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Prepare_init, /* tp_init */ +}; + +typedef struct libevwrapper_Timer { + PyObject_HEAD + struct ev_timer timer; + struct libevwrapper_Loop *loop; + PyObject *callback; +} libevwrapper_Timer; + +static void +Timer_dealloc(libevwrapper_Timer *self) { + Py_XDECREF(self->loop); + Py_XDECREF(self->callback); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents) { + if (!Py_IsInitialized()) { + return; + } + + libevwrapper_Timer *self = watcher->data; + + PyObject *result = NULL; + PyGILState_STATE gstate; + + gstate = PyGILState_Ensure(); + result = PyObject_CallFunction(self->callback, NULL); + if (!result) { + PyErr_WriteUnraisable(self->callback); + } + Py_XDECREF(result); + + PyGILState_Release(gstate); +} + +static int +Timer_init(libevwrapper_Timer *self, PyObject *args, PyObject *kwds) { + PyObject *callback; + PyObject *loop; + + if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } else { + return -1; + } + + if (callback) { + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); + Py_XDECREF(loop); + return -1; + } + Py_INCREF(callback); + self->callback = callback; + } + ev_init(&self->timer, timer_callback); + self->timer.data = self; + return 0; +} + +static PyObject * +Timer_start(libevwrapper_Timer *self, PyObject *args) { + double timeout; + if (!PyArg_ParseTuple(args, "d", &timeout)) { + return NULL; + } + /* some tiny non-zero number to avoid zero, and + make it run immediately for negative timeouts */ + self->timer.repeat = fmax(timeout, 0.000000001); + ev_timer_again(self->loop->loop, &self->timer); + Py_RETURN_NONE; +} + +static PyObject * +Timer_stop(libevwrapper_Timer *self, PyObject *args) { + ev_timer_stop(self->loop->loop, &self->timer); + Py_RETURN_NONE; +} + +static PyMethodDef Timer_methods[] = { + {"start", (PyCFunction)Timer_start, METH_VARARGS, "Start the Timer watcher"}, + {"stop", (PyCFunction)Timer_stop, METH_NOARGS, "Stop the Timer watcher"}, + {NULL} /* Sentinal */ +}; + +static PyTypeObject libevwrapper_TimerType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Timer", /*tp_name*/ + sizeof(libevwrapper_Timer), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Timer_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Timer objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Timer_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Timer_init, /* tp_init */ +}; + + static PyMethodDef module_methods[] = { {NULL} /* Sentinal */ }; +PyDoc_STRVAR(module_doc, +"libev wrapper methods"); + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "libevwrapper", + module_doc, + -1, + module_methods, + NULL, + NULL, + NULL, + NULL +}; + +#define INITERROR return NULL -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif -PyMODINIT_FUNC -initlibevwrapper(void) +PyObject * +PyInit_libevwrapper(void) { - PyObject *m; + PyObject *module = NULL; if (PyType_Ready(&libevwrapper_LoopType) < 0) - return; + INITERROR; libevwrapper_IOType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_IOType) < 0) - return; + INITERROR; + + libevwrapper_PrepareType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_PrepareType) < 0) + INITERROR; libevwrapper_AsyncType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_AsyncType) < 0) - return; + INITERROR; + + libevwrapper_TimerType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_TimerType) < 0) + INITERROR; - m = Py_InitModule3("libevwrapper", module_methods, "libev wrapper methods"); - PyModule_AddIntConstant(m, "EV_READ", EV_READ); - PyModule_AddIntConstant(m, "EV_WRITE", EV_WRITE); + module = PyModule_Create(&moduledef); + + if (module == NULL) + INITERROR; + + if (PyModule_AddIntConstant(module, "EV_READ", EV_READ) == -1) + INITERROR; + if (PyModule_AddIntConstant(module, "EV_WRITE", EV_WRITE) == -1) + INITERROR; + if (PyModule_AddIntConstant(module, "EV_ERROR", EV_ERROR) == -1) + INITERROR; Py_INCREF(&libevwrapper_LoopType); - PyModule_AddObject(m, "Loop", (PyObject *)&libevwrapper_LoopType); + if (PyModule_AddObject(module, "Loop", (PyObject *)&libevwrapper_LoopType) == -1) + INITERROR; Py_INCREF(&libevwrapper_IOType); - PyModule_AddObject(m, "IO", (PyObject *)&libevwrapper_IOType); + if (PyModule_AddObject(module, "IO", (PyObject *)&libevwrapper_IOType) == -1) + INITERROR; + + Py_INCREF(&libevwrapper_PrepareType); + if (PyModule_AddObject(module, "Prepare", (PyObject *)&libevwrapper_PrepareType) == -1) + INITERROR; Py_INCREF(&libevwrapper_AsyncType); - PyModule_AddObject(m, "Async", (PyObject *)&libevwrapper_AsyncType); + if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1) + INITERROR; + Py_INCREF(&libevwrapper_TimerType); + if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1) + INITERROR; + +#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 7) + // Since CPython 3.7, `Py_Initialize()` routing always initializes GIL. + // Routine `PyEval_ThreadsInitialized()` has been deprecated in CPython 3.7 + // and completely removed in CPython 3.13. if (!PyEval_ThreadsInitialized()) { PyEval_InitThreads(); } +#endif + + return module; } diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py new file mode 100644 index 0000000000..58e79e9ce9 --- /dev/null +++ b/cassandra/io/twistedreactor.py @@ -0,0 +1,309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module that implements an event loop based on twisted +( https://twistedmatrix.com ). +""" +import atexit +from deprecated import deprecated +import logging +import time +from functools import partial +from threading import Thread, Lock +import weakref + +from twisted.internet import reactor, protocol +from twisted.internet.endpoints import connectProtocol, TCP4ClientEndpoint, SSL4ClientEndpoint +from twisted.internet.interfaces import IOpenSSLClientConnectionCreator +from twisted.python.failure import Failure +from zope.interface import implementer + +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager, ConnectionException + +try: + from OpenSSL import SSL + _HAS_SSL = True +except ImportError as e: + _HAS_SSL = False + import_exception = e +log = logging.getLogger(__name__) + + +def _cleanup(cleanup_weakref): + try: + cleanup_weakref()._cleanup() + except ReferenceError: + return + + +class TwistedConnectionProtocol(protocol.Protocol): + """ + Twisted Protocol class for handling data received and connection + made events. + """ + + def __init__(self, connection): + self.connection = connection + + def dataReceived(self, data): + """ + Callback function that is called when data has been received + on the connection. + + Reaches back to the Connection object and queues the data for + processing. + """ + self.connection._iobuf.write(data) + self.connection.handle_read() + + def connectionMade(self): + """ + Callback function that is called when a connection has succeeded. + + Reaches back to the Connection object and confirms that the connection + is ready. + """ + self.connection.client_connection_made(self.transport) + + def connectionLost(self, reason): + # reason is a Failure instance + log.debug("Connect lost: %s", reason) + self.connection.defunct(reason.value) + + +class TwistedLoop(object): + + _lock = None + _thread = None + _timeout_task = None + _timeout = None + + def __init__(self): + self._lock = Lock() + self._timers = TimerManager() + + def maybe_start(self): + with self._lock: + if not reactor.running: + self._thread = Thread(target=reactor.run, + name="cassandra_driver_twisted_event_loop", + kwargs={'installSignalHandlers': False}) + self._thread.daemon = True + self._thread.start() + atexit.register(partial(_cleanup, weakref.ref(self))) + + def _reactor_stopped(self): + return reactor._stopped + + def _cleanup(self): + if self._thread: + reactor.callFromThread(reactor.stop) + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + log.warning("Event loop thread could not be joined, so " + "shutdown may not be clean. Please call " + "Cluster.shutdown() to avoid this.") + log.debug("Event loop thread was joined") + + def add_timer(self, timer): + self._timers.add_timer(timer) + # callFromThread to schedule from the loop thread, where + # the timeout task can safely be modified + reactor.callFromThread(self._schedule_timeout, timer.end) + + def _schedule_timeout(self, next_timeout): + if next_timeout: + delay = max(next_timeout - time.time(), 0) + if self._timeout_task and self._timeout_task.active(): + if next_timeout < self._timeout: + self._timeout_task.reset(delay) + self._timeout = next_timeout + else: + self._timeout_task = reactor.callLater(delay, self._on_loop_timer) + self._timeout = next_timeout + + def _on_loop_timer(self): + self._timers.service_timeouts() + self._schedule_timeout(self._timers.next_timeout) + + +@implementer(IOpenSSLClientConnectionCreator) +class _SSLCreator(object): + def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout): + self.endpoint = endpoint + self.ssl_options = ssl_options + self.check_hostname = check_hostname + self.timeout = timeout + + if ssl_context: + self.context = ssl_context + else: + self.context = SSL.Context(SSL.TLSv1_METHOD) + if "certfile" in self.ssl_options: + self.context.use_certificate_file(self.ssl_options["certfile"]) + if "keyfile" in self.ssl_options: + self.context.use_privatekey_file(self.ssl_options["keyfile"]) + if "ca_certs" in self.ssl_options: + self.context.load_verify_locations(self.ssl_options["ca_certs"]) + if "cert_reqs" in self.ssl_options: + self.context.set_verify( + self.ssl_options["cert_reqs"], + callback=self.verify_callback + ) + self.context.set_info_callback(self.info_callback) + + def verify_callback(self, connection, x509, errnum, errdepth, ok): + return ok + + def info_callback(self, connection, where, ret): + if where & SSL.SSL_CB_HANDSHAKE_DONE: + if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName: + transport = connection.get_app_data() + transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint))) + + def clientConnectionForTLS(self, tlsProtocol): + connection = SSL.Connection(self.context, None) + connection.set_app_data(tlsProtocol) + if self.ssl_options and "server_hostname" in self.ssl_options: + connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + return connection + +@deprecated(version="3.30.0", reason="The Twisted event loop is deprecated and will be removed in 3.31.0. See CASSPYTHON-12.") +class TwistedConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes the + Twisted event loop. + """ + + _loop = None + + @classmethod + def initialize_reactor(cls): + if not cls._loop: + cls._loop = TwistedLoop() + + @classmethod + def create_timer(cls, timeout, callback): + timer = Timer(timeout, callback) + cls._loop.add_timer(timer) + return timer + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + Note that we can't call reactor methods directly here because + it's not thread-safe, so we schedule the reactor/connection + stuff to be run from the event loop thread when it gets the + chance. + """ + Connection.__init__(self, *args, **kwargs) + + self.is_closed = True + self.connector = None + self.transport = None + + reactor.callFromThread(self.add_connection) + self._loop.maybe_start() + + def _check_pyopenssl(self): + if self.ssl_context or self.ssl_options: + if not _HAS_SSL: + raise ImportError( + str(import_exception) + + ', pyOpenSSL must be installed to enable SSL support with the Twisted event loop' + ) + + def add_connection(self): + """ + Convenience function to connect and store the resulting + connector. + """ + host, port = self.endpoint.resolve() + if self.ssl_context or self.ssl_options: + # Can't use optionsForClientTLS here because it *forces* hostname verification. + # Cool they enforce strong security, but we have to be able to turn it off + self._check_pyopenssl() + + ssl_connection_creator = _SSLCreator( + self.endpoint, + self.ssl_context if self.ssl_context else None, + self.ssl_options, + self._check_hostname, + self.connect_timeout, + ) + + endpoint = SSL4ClientEndpoint( + reactor, + host, + port, + sslContextFactory=ssl_connection_creator, + timeout=self.connect_timeout, + ) + else: + endpoint = TCP4ClientEndpoint( + reactor, + host, + port, + timeout=self.connect_timeout + ) + connectProtocol(endpoint, TwistedConnectionProtocol(self)) + + def client_connection_made(self, transport): + """ + Called by twisted protocol when a connection attempt has + succeeded. + """ + with self.lock: + self.is_closed = False + self.transport = transport + self._send_options_message() + + def close(self): + """ + Disconnect and error-out all requests. + """ + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) + reactor.callFromThread(self.transport.connector.disconnect) + log.debug("Closed socket to %s", self.endpoint) + + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_read(self): + """ + Process the incoming data buffer. + """ + self.process_io_buffer() + + def push(self, data): + """ + This function is called when outgoing data should be queued + for sending. + + Note that we can't call transport.write() directly because + it is not thread-safe, so we schedule it to run from within + the event loop when it gets the chance. + """ + reactor.callFromThread(self.transport.write, data) diff --git a/cassandra/ioutils.pyx b/cassandra/ioutils.pyx new file mode 100644 index 0000000000..91c2bf9542 --- /dev/null +++ b/cassandra/ioutils.pyx @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include 'cython_marshal.pyx' +from cassandra.buffer cimport Buffer, from_ptr_and_size + +from libc.stdint cimport int32_t +from cassandra.bytesio cimport BytesIOReader + + +cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1: + """ + Get a pointer into the buffer provided by BytesIOReader for the + next data item in the stream of values. + + BEWARE: + If the next item has a zero negative size, the pointer will be set to NULL. + A negative size happens when the value is NULL in the database, whereas a + zero size may happen either for legacy reasons, or for data types such as + strings (which may be empty). + """ + cdef Py_ssize_t raw_val_size = read_int(reader) + cdef char *ptr + if raw_val_size <= 0: + ptr = NULL + else: + ptr = reader.read(raw_val_size) + + from_ptr_and_size(ptr, raw_val_size, buf_out) + return 0 + +cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD: + cdef Buffer buf + buf.ptr = reader.read(4) + buf.size = 4 + return unpack_num[int32_t](&buf) diff --git a/cassandra/marshal.py b/cassandra/marshal.py index cf9c765662..e8733f0544 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -1,14 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import struct + def _make_packer(format_string): - try: - packer = struct.Struct(format_string) # new in Python 2.5 - except AttributeError: - pack = lambda x: struct.pack(format_string, x) - unpack = lambda s: struct.unpack(format_string, s) - else: - pack = packer.pack - unpack = lambda s: packer.unpack(s)[0] + packer = struct.Struct(format_string) + pack = packer.pack + unpack = lambda s: packer.unpack(s)[0] return pack, unpack int64_pack, int64_unpack = _make_packer('>q') @@ -17,37 +29,163 @@ def _make_packer(format_string): int8_pack, int8_unpack = _make_packer('>b') uint64_pack, uint64_unpack = _make_packer('>Q') uint32_pack, uint32_unpack = _make_packer('>I') +uint32_le_pack, uint32_le_unpack = _make_packer('H') uint8_pack, uint8_unpack = _make_packer('>B') float_pack, float_unpack = _make_packer('>f') double_pack, double_unpack = _make_packer('>d') +# Special case for cassandra header +header_struct = struct.Struct('>BBbB') +header_pack = header_struct.pack +header_unpack = header_struct.unpack + +# in protocol version 3 and higher, the stream ID is two bytes +v3_header_struct = struct.Struct('>BBhB') +v3_header_pack = v3_header_struct.pack +v3_header_unpack = v3_header_struct.unpack + + def varint_unpack(term): - val = int(term.encode('hex'), 16) - if (ord(term[0]) & 128) != 0: - val = val - (1 << (len(term) * 8)) + val = int(''.join("%02x" % i for i in term), 16) + if (term[0] & 128) != 0: + len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code + val -= 1 << (len_term * 8) return val -def bitlength(n): - bitlen = 0 - while n > 0: - n >>= 1 - bitlen += 1 - return bitlen + +def bit_length(n): + return int.bit_length(n) + def varint_pack(big): pos = True if big == 0: - return '\x00' + return b'\x00' if big < 0: - bytelength = bitlength(abs(big) - 1) / 8 + 1 + bytelength = bit_length(abs(big) - 1) // 8 + 1 big = (1 << bytelength * 8) + big pos = False - revbytes = [] + revbytes = bytearray() while big > 0: - revbytes.append(chr(big & 0xff)) + revbytes.append(big & 0xff) big >>= 8 - if pos and ord(revbytes[-1]) & 0x80: - revbytes.append('\x00') + if pos and revbytes[-1] & 0x80: + revbytes.append(0) + revbytes.reverse() + return bytes(revbytes) + + +point_be = struct.Struct('>dd') +point_le = struct.Struct('ddd') +circle_le = struct.Struct('> 63) + + +def decode_zig_zag(n): + return (n >> 1) ^ -(n & 1) + + +def vints_unpack(term): # noqa + values = [] + n = 0 + while n < len(term): + first_byte = term[n] + + if (first_byte & 128) == 0: + val = first_byte + else: + num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() + val = first_byte & (0xff >> num_extra_bytes) + end = n + num_extra_bytes + while n < end: + n += 1 + val <<= 8 + val |= term[n] & 0xff + + n += 1 + values.append(decode_zig_zag(val)) + + return tuple(values) + +def vints_pack(values): + revbytes = bytearray() + values = [int(v) for v in values[::-1]] + for value in values: + v = encode_zig_zag(value) + if v < 128: + revbytes.append(v) + else: + num_extra_bytes = 0 + num_bits = v.bit_length() + # We need to reserve (num_extra_bytes+1) bits in the first byte + # i.e. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved + # i.e. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved + reserved_bits = num_extra_bytes + 1 + while num_bits > (8-(reserved_bits)): + num_extra_bytes += 1 + num_bits -= 8 + reserved_bits = min(num_extra_bytes + 1, 8) + revbytes.append(v & 0xff) + v >>= 8 + + if num_extra_bytes > 8: + raise ValueError('Value %d is too big and cannot be encoded as vint' % value) + + # We can now store the last bits in the first byte + n = 8 - num_extra_bytes + v |= (0xff >> n << n) + revbytes.append(abs(v)) + revbytes.reverse() - return ''.join(revbytes) + return bytes(revbytes) + +def uvint_unpack(bytes): + first_byte = bytes[0] + + if (first_byte & 128) == 0: + return (first_byte,1) + + num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() + rv = first_byte & (0xff >> num_extra_bytes) + for idx in range(1,num_extra_bytes + 1): + new_byte = bytes[idx] + rv <<= 8 + rv |= new_byte & 0xff + + return (rv, num_extra_bytes + 1) + +def uvint_pack(val): + rv = bytearray() + if val < 128: + rv.append(val) + else: + v = val + num_extra_bytes = 0 + num_bits = v.bit_length() + # We need to reserve (num_extra_bytes+1) bits in the first byte + # i.e. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved + # i.e. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved + reserved_bits = num_extra_bytes + 1 + while num_bits > (8-(reserved_bits)): + num_extra_bytes += 1 + num_bits -= 8 + reserved_bits = min(num_extra_bytes + 1, 8) + rv.append(v & 0xff) + v >>= 8 + + if num_extra_bytes > 8: + raise ValueError('Value %d is too big and cannot be encoded as vint' % val) + + # We can now store the last bits in the first byte + n = 8 - num_extra_bytes + v |= (0xff >> n << n) + rv.append(abs(v)) + + rv.reverse() + return bytes(rv) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index a365e86ac4..4c1be285b8 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,41 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from binascii import unhexlify from bisect import bisect_left from collections import defaultdict -try: - from collections import OrderedDict -except ImportError: # Python <2.7 - from cassandra.util import OrderedDict # NOQA +from collections.abc import Mapping +from functools import total_ordering from hashlib import md5 import json import logging import re +import sys from threading import RLock -import weakref +import struct +import random murmur3 = None try: - from murmur3 import murmur3 -except ImportError: + from cassandra.murmur3 import murmur3 +except ImportError as e: pass +from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized import cassandra.cqltypes as types +from cassandra.encoder import Encoder from cassandra.marshal import varint_unpack -from cassandra.pool import Host +from cassandra.protocol import QueryMessage +from cassandra.query import dict_factory, bind_params +from cassandra.util import OrderedDict, Version +from cassandra.pool import HostDistance +from cassandra.connection import EndPoint log = logging.getLogger(__name__) -_keywords = set(( - 'select', 'from', 'where', 'and', 'key', 'insert', 'update', 'with', - 'limit', 'using', 'use', 'count', 'set', - 'begin', 'apply', 'batch', 'truncate', 'delete', 'in', 'create', - 'keyspace', 'schema', 'columnfamily', 'table', 'index', 'on', 'drop', - 'primary', 'into', 'values', 'timestamp', 'ttl', 'alter', 'add', 'type', - 'compact', 'storage', 'order', 'by', 'asc', 'desc', 'clustering', - 'token', 'writetime', 'map', 'list', 'to' +cql_keywords = set(( + 'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin', + 'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count', + 'counter', 'create', 'custom', 'date', 'decimal', 'default', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop', + 'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function', + 'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json', + 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'mbean', 'mbeans', 'modify', 'monotonic', + 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', + 'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set', + 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid', + 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'unset', 'update', 'use', 'user', + 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime', + + # DSE specifics + "node", "nodes", "plan", "active", "application", "applications", "java", "executor", "executors", "std_out", "std_err", + "renew", "delegation", "no", "redact", "token", "lowercasestring", "cluster", "authentication", "schemes", "scheme", + "internal", "ldap", "kerberos", "remote", "object", "method", "call", "calls", "search", "schema", "config", "rows", + "columns", "profiles", "commit", "reload", "rebuild", "field", "workpool", "any", "submission", "indices", + "restrict", "unrestrict" )) - -_unreserved_keywords = set(( - 'key', 'clustering', 'ttl', 'compact', 'storage', 'type', 'values' +""" +Set of keywords in CQL. + +Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g +""" + +cql_keywords_unreserved = set(( + 'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains', + 'count', 'counter', 'custom', 'date', 'decimal', 'deterministic', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', + 'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces', + 'language', 'list', 'login', 'map', 'monotonic', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', + 'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time', + 'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar', + 'varint', 'writetime' )) +""" +Set of unreserved keywords in CQL. + +Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g +""" + +cql_keywords_reserved = cql_keywords - cql_keywords_unreserved +""" +Set of reserved keywords in CQL. +""" + +_encoder = Encoder() class Metadata(object): @@ -51,15 +108,20 @@ class Metadata(object): A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances. """ + partitioner = None + """ + The string name of the partitioner for the cluster. + """ + token_map = None """ A :class:`~.TokenMap` instance describing the ring topology. """ - def __init__(self, cluster): - # use a weak reference so that the Cluster object can be GC'ed. - # Normally the cycle detector would handle this, but implementing - # __del__ disables that. - self.cluster_ref = weakref.ref(cluster) + dbaas = False + """ A boolean indicating if connected to a DBaaS cluster """ + + def __init__(self): self.keyspaces = {} + self.dbaas = False self._hosts = {} self._hosts_lock = RLock() @@ -68,180 +130,148 @@ def export_schema_as_string(self): Returns a string that can be executed as a query in order to recreate the entire schema. The string is formatted to be human readable. """ - return "\n".join(ks.export_as_string() for ks in self.keyspaces.values()) - - def rebuild_schema(self, keyspace, table, ks_results, cf_results, col_results): - """ - Rebuild the view of the current schema from a fresh set of rows from - the system schema tables. - - For internal use only. - """ - cf_def_rows = defaultdict(list) - col_def_rows = defaultdict(lambda: defaultdict(list)) + return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values()) - for row in cf_results: - cf_def_rows[row["keyspace_name"]].append(row) - - for row in col_results: - ksname = row["keyspace_name"] - cfname = row["columnfamily_name"] - col_def_rows[ksname][cfname].append(row) - - # either table or ks_results must be None - if not table: - # ks_results is not None - added_keyspaces = set() - for row in ks_results: - keyspace_meta = self._build_keyspace_metadata(row) - for table_row in cf_def_rows.get(keyspace_meta.name, []): - table_meta = self._build_table_metadata( - keyspace_meta, table_row, col_def_rows[keyspace_meta.name]) - keyspace_meta.tables[table_meta.name] = table_meta - - added_keyspaces.add(keyspace_meta.name) - self.keyspaces[keyspace_meta.name] = keyspace_meta - - if not keyspace: - # remove not-just-added keyspaces - self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() - if name in added_keyspaces) - else: - # keyspace is not None, table is not None - try: - keyspace_meta = self.keyspaces[keyspace] - except KeyError: - # we're trying to update a table in a keyspace we don't know - # about, something went wrong. - # TODO log error, submit schema refresh - pass - if keyspace in cf_def_rows: - for table_row in cf_def_rows[keyspace]: - table_meta = self._build_table_metadata( - keyspace_meta, table_row, col_def_rows[keyspace]) - keyspace_meta.tables[table_meta.name] = table_meta + def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs): - def _build_keyspace_metadata(self, row): - name = row["keyspace_name"] - durable_writes = row["durable_writes"] - strategy_class = row["strategy_class"] - strategy_options = json.loads(row["strategy_options"]) - return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + server_version = self.get_host(connection.endpoint).release_version + dse_version = self.get_host(connection.endpoint).dse_version + parser = get_schema_parser(connection, server_version, dse_version, timeout) - def _build_table_metadata(self, keyspace_metadata, row, col_rows): - cfname = row["columnfamily_name"] + if not target_type: + self._rebuild_all(parser) + return - comparator = types.lookup_casstype(row["comparator"]) - if issubclass(comparator, types.CompositeType): - column_name_types = comparator.subtypes - is_composite = True - else: - column_name_types = (comparator,) - is_composite = False - - num_column_name_components = len(column_name_types) - last_col = column_name_types[-1] - - column_aliases = json.loads(row["column_aliases"]) - if is_composite: - if issubclass(last_col, types.ColumnToCollectionType): - # collections - is_compact = False - has_value = False - clustering_size = num_column_name_components - 2 - elif (len(column_aliases) == num_column_name_components - 1 - and issubclass(last_col, types.UTF8Type)): - # aliases? - is_compact = False - has_value = False - clustering_size = num_column_name_components - 1 + tt_lower = target_type.lower() + try: + parse_method = getattr(parser, 'get_' + tt_lower) + meta = parse_method(self.keyspaces, **kwargs) + if meta: + update_method = getattr(self, '_update_' + tt_lower) + if tt_lower == 'keyspace' and connection.protocol_version < 3: + # we didn't have 'type' target in legacy protocol versions, so we need to query those too + user_types = parser.get_types_map(self.keyspaces, **kwargs) + self._update_keyspace(meta, user_types) + else: + update_method(meta) else: - # compact table - is_compact = True - has_value = True - clustering_size = num_column_name_components - else: - is_compact = True - if column_aliases or not col_rows.get(cfname): - has_value = True - clustering_size = num_column_name_components + drop_method = getattr(self, '_drop_' + tt_lower) + drop_method(**kwargs) + except AttributeError: + raise ValueError("Unknown schema target_type: '%s'" % target_type) + + def _rebuild_all(self, parser): + current_keyspaces = set() + for keyspace_meta in parser.get_all_keyspaces(): + current_keyspaces.add(keyspace_meta.name) + old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None) + self.keyspaces[keyspace_meta.name] = keyspace_meta + if old_keyspace_meta: + self._keyspace_updated(keyspace_meta.name) else: - has_value = False - clustering_size = 0 + self._keyspace_added(keyspace_meta.name) + + # remove not-just-added keyspaces + removed_keyspaces = [name for name in self.keyspaces.keys() + if name not in current_keyspaces] + self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() + if name in current_keyspaces) + for ksname in removed_keyspaces: + self._keyspace_removed(ksname) + + def _update_keyspace(self, keyspace_meta, new_user_types=None): + ks_name = keyspace_meta.name + old_keyspace_meta = self.keyspaces.get(ks_name, None) + self.keyspaces[ks_name] = keyspace_meta + if old_keyspace_meta: + keyspace_meta.tables = old_keyspace_meta.tables + keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types + keyspace_meta.indexes = old_keyspace_meta.indexes + keyspace_meta.functions = old_keyspace_meta.functions + keyspace_meta.aggregates = old_keyspace_meta.aggregates + keyspace_meta.views = old_keyspace_meta.views + if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): + self._keyspace_updated(ks_name) + else: + self._keyspace_added(ks_name) - table_meta = TableMetadata(keyspace_metadata, cfname) - table_meta.comparator = comparator + def _drop_keyspace(self, keyspace): + if self.keyspaces.pop(keyspace, None): + self._keyspace_removed(keyspace) - # partition key - key_aliases = row.get("key_aliases") - key_aliases = json.loads(key_aliases) if key_aliases else [] - - key_type = types.lookup_casstype(row["key_validator"]) - key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] - for i, col_type in enumerate(key_types): - if len(key_aliases) > i: - column_name = key_aliases[i] - elif i == 0: - column_name = "key" + def _update_table(self, meta): + try: + keyspace_meta = self.keyspaces[meta.keyspace_name] + # this is unfortunate, but protocol v4 does not differentiate + # between events for tables and views. .get_table will + # return one or the other based on the query results. + # Here we deal with that. + if isinstance(meta, TableMetadata): + keyspace_meta._add_table_metadata(meta) else: - column_name = "key%d" % i + keyspace_meta._add_view_metadata(meta) + except KeyError: + # can happen if keyspace disappears while processing async event + pass - col = ColumnMetadata(table_meta, column_name, col_type) - table_meta.columns[column_name] = col - table_meta.partition_key.append(col) + def _drop_table(self, keyspace, table): + try: + keyspace_meta = self.keyspaces[keyspace] + keyspace_meta._drop_table_metadata(table) # handles either table or view + except KeyError: + # can happen if keyspace disappears while processing async event + pass - # clustering key - for i in range(clustering_size): - if len(column_aliases) > i: - column_name = column_aliases[i] - else: - column_name = "column%d" % i + def _update_type(self, type_meta): + try: + self.keyspaces[type_meta.keyspace].user_types[type_meta.name] = type_meta + except KeyError: + # can happen if keyspace disappears while processing async event + pass - col = ColumnMetadata(table_meta, column_name, column_name_types[i]) - table_meta.columns[column_name] = col - table_meta.clustering_key.append(col) + def _drop_type(self, keyspace, type): + try: + self.keyspaces[keyspace].user_types.pop(type, None) + except KeyError: + # can happen if keyspace disappears while processing async event + pass - # value alias (if present) - if has_value: - validator = types.lookup_casstype(row["default_validator"]) - if not key_aliases: # TODO are we checking the right thing here? - value_alias = "value" - else: - value_alias = row["value_alias"] + def _update_function(self, function_meta): + try: + self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta + except KeyError: + # can happen if keyspace disappears while processing async event + pass - col = ColumnMetadata(table_meta, value_alias, validator) - table_meta.columns[value_alias] = col + def _drop_function(self, keyspace, function): + try: + self.keyspaces[keyspace].functions.pop(function.signature, None) + except KeyError: + pass - # other normal columns - if col_rows: - for col_row in col_rows[cfname]: - column_meta = self._build_column_metadata(table_meta, col_row) - table_meta.columns[column_meta.name] = column_meta + def _update_aggregate(self, aggregate_meta): + try: + self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta + except KeyError: + pass - table_meta.options = self._build_table_options(row, is_compact) - return table_meta + def _drop_aggregate(self, keyspace, aggregate): + try: + self.keyspaces[keyspace].aggregates.pop(aggregate.signature, None) + except KeyError: + pass - def _build_table_options(self, row, is_compact_storage): - """ Setup the mostly-non-schema table options, like caching settings """ - options = dict((o, row.get(o)) for o in TableMetadata.recognized_options) - options["is_compact_storage"] = is_compact_storage - return options + def _keyspace_added(self, ksname): + if self.token_map: + self.token_map.rebuild_keyspace(ksname, build_if_absent=False) - def _build_column_metadata(self, table_metadata, row): - name = row["column_name"] - data_type = types.lookup_casstype(row["validator"]) - column_meta = ColumnMetadata(table_metadata, name, data_type) - index_meta = self._build_index_metadata(column_meta, row) - column_meta.index = index_meta - return column_meta + def _keyspace_updated(self, ksname): + if self.token_map: + self.token_map.rebuild_keyspace(ksname, build_if_absent=False) - def _build_index_metadata(self, column_metadata, row): - index_name = row.get("index_name") - index_type = row.get("index_type") - if index_name or index_type: - return IndexMetadata(column_metadata, index_name, index_type) - else: - return None + def _keyspace_removed(self, ksname): + if self.token_map: + self.token_map.remove_keyspace(ksname) def rebuild_token_map(self, partitioner, token_map): """ @@ -249,32 +279,28 @@ def rebuild_token_map(self, partitioner, token_map): system topology tables. For internal use only. """ + self.partitioner = partitioner if partitioner.endswith('RandomPartitioner'): token_class = MD5Token elif partitioner.endswith('Murmur3Partitioner'): token_class = Murmur3Token - if murmur3 is None: - log.warning( - "The murmur3 C extension is not available, token awareness " - "cannot be supported for the Murmur3Partitioner") elif partitioner.endswith('ByteOrderedPartitioner'): token_class = BytesToken else: self.token_map = None return - token_to_primary_replica = {} + token_to_host_owner = {} ring = [] - for host, token_strings in token_map.iteritems(): + for host, token_strings in token_map.items(): for token_string in token_strings: - token = token_class(token_string) + token = token_class.from_string(token_string) ring.append(token) - token_to_primary_replica[token] = host + token_to_host_owner[token] = host all_tokens = sorted(ring) self.token_map = TokenMap( - token_class, token_to_primary_replica, all_tokens, - self.keyspaces.values()) + token_class, token_to_host_owner, all_tokens, self) def get_replicas(self, keyspace, key): """ @@ -289,120 +315,382 @@ def get_replicas(self, keyspace, key): except NoMurmur3: return [] - def add_host(self, address): - cluster = self.cluster_ref() - with self._hosts_lock: - if address not in self._hosts: - new_host = Host(address, cluster.conviction_policy_factory) - self._hosts[address] = new_host - else: - return None + def can_support_partitioner(self): + if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None: + return False + else: + return True - new_host.monitor.register(cluster) - return new_host + def add_or_return_host(self, host): + """ + Returns a tuple (host, new), where ``host`` is a Host + instance, and ``new`` is a bool indicating whether + the host was newly added. + """ + with self._hosts_lock: + try: + return self._hosts[host.endpoint], False + except KeyError: + self._hosts[host.endpoint] = host + return host, True def remove_host(self, host): with self._hosts_lock: - return bool(self._hosts.pop(host.address, False)) + return bool(self._hosts.pop(host.endpoint, False)) + + def get_host(self, endpoint_or_address, port=None): + """ + Find a host in the metadata for a specific endpoint. If a string inet address and port are passed, + iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` and + :attr:`~.pool.Host.broadcast_rpc_port`attributes. + """ + if not isinstance(endpoint_or_address, EndPoint): + return self._get_host_by_address(endpoint_or_address, port) - def get_host(self, address): - return self._hosts.get(address) + return self._hosts.get(endpoint_or_address) + + def _get_host_by_address(self, address, port=None): + for host in self._hosts.values(): + if (host.broadcast_rpc_address == address and + (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)): + return host + + return None def all_hosts(self): """ Returns a list of all known :class:`.Host` instances in the cluster. """ with self._hosts_lock: - return self._hosts.values() + return list(self._hosts.values()) + + +REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator." + + +def trim_if_startswith(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + return s + + +_replication_strategies = {} -class ReplicationStrategy(object): +class ReplicationStrategyTypeType(type): + def __new__(metacls, name, bases, dct): + dct.setdefault('name', name) + cls = type.__new__(metacls, name, bases, dct) + if not name.startswith('_'): + _replication_strategies[name] = cls + return cls + + + +class _ReplicationStrategy(object, metaclass=ReplicationStrategyTypeType): + options_map = None @classmethod def create(cls, strategy_class, options_map): if not strategy_class: return None - if strategy_class.endswith("OldNetworkTopologyStrategy"): + strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) + + rs_class = _replication_strategies.get(strategy_name, None) + if rs_class is None: + rs_class = _UnknownStrategyBuilder(strategy_name) + _replication_strategies[strategy_name] = rs_class + + try: + rs_instance = rs_class(options_map) + except Exception as exc: + log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) return None - elif strategy_class.endswith("NetworkTopologyStrategy"): - return NetworkTopologyStrategy(options_map) - elif strategy_class.endswith("SimpleStrategy"): - repl_factor = options_map.get('replication_factor', None) - if not repl_factor: - return None - return SimpleStrategy(repl_factor) - - def make_token_replica_map(token_to_primary_replica, ring): + + return rs_instance + + def make_token_replica_map(self, token_to_host_owner, ring): raise NotImplementedError() def export_for_schema(self): raise NotImplementedError() +ReplicationStrategy = _ReplicationStrategy + + +class _UnknownStrategyBuilder(object): + def __init__(self, name): + self.name = name + + def __call__(self, options_map): + strategy_instance = _UnknownStrategy(self.name, options_map) + return strategy_instance + + +class _UnknownStrategy(ReplicationStrategy): + def __init__(self, name, options_map): + self.name = name + self.options_map = options_map.copy() if options_map is not None else dict() + self.options_map['class'] = self.name + + def __eq__(self, other): + return (isinstance(other, _UnknownStrategy) and + self.name == other.name and + self.options_map == other.options_map) + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + if self.options_map: + return dict((str(key), str(value)) for key, value in self.options_map.items()) + return "{'class': '%s'}" % (self.name, ) + + def make_token_replica_map(self, token_to_host_owner, ring): + return {} + + +class ReplicationFactor(object): + """ + Represent the replication factor of a keyspace. + """ + + all_replicas = None + """ + The number of total replicas. + """ + + full_replicas = None + """ + The number of replicas that own a full copy of the data. This is the same + than `all_replicas` when transient replication is not enabled. + """ + + transient_replicas = None + """ + The number of transient replicas. + + Only set if the keyspace has transient replication enabled. + """ + + def __init__(self, all_replicas, transient_replicas=None): + self.all_replicas = all_replicas + self.transient_replicas = transient_replicas + self.full_replicas = (all_replicas - transient_replicas) if transient_replicas else all_replicas + + @staticmethod + def create(rf): + """ + Given the inputted replication factor string, parse and return the ReplicationFactor instance. + """ + transient_replicas = None + try: + all_replicas = int(rf) + except ValueError: + try: + rf = rf.split('/') + all_replicas, transient_replicas = int(rf[0]), int(rf[1]) + except Exception: + raise ValueError("Unable to determine replication factor from: {}".format(rf)) + + return ReplicationFactor(all_replicas, transient_replicas) + + def __str__(self): + return ("%d/%d" % (self.all_replicas, self.transient_replicas) if self.transient_replicas + else "%d" % self.all_replicas) + + def __eq__(self, other): + if not isinstance(other, ReplicationFactor): + return False + + return self.all_replicas == other.all_replicas and self.full_replicas == other.full_replicas + + class SimpleStrategy(ReplicationStrategy): - name = "SimpleStrategy" - replication_factor = None + replication_factor_info = None + """ + A :class:`cassandra.metadata.ReplicationFactor` instance. + """ + + @property + def replication_factor(self): + """ + The replication factor for this keyspace. + + For backward compatibility, this returns the + :attr:`cassandra.metadata.ReplicationFactor.full_replicas` value of + :attr:`cassandra.metadata.SimpleStrategy.replication_factor_info`. + """ + return self.replication_factor_info.full_replicas - def __init__(self, replication_factor): - self.replication_factor = int(replication_factor) + def __init__(self, options_map): + self.replication_factor_info = ReplicationFactor.create(options_map['replication_factor']) - def make_token_replica_map(self, token_to_primary_replica, ring): + def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} for i in range(len(ring)): - j, hosts = 0, set() + j, hosts = 0, list() while len(hosts) < self.replication_factor and j < len(ring): token = ring[(i + j) % len(ring)] - hosts.add(token_to_primary_replica[token]) + host = token_to_host_owner[token] + if host not in hosts: + hosts.append(host) j += 1 replica_map[ring[i]] = hosts - return replica_map def export_for_schema(self): - return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ - % (self.replication_factor,) + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" \ + % (str(self.replication_factor_info),) + + def __eq__(self, other): + if not isinstance(other, SimpleStrategy): + return False + + return str(self.replication_factor_info) == str(other.replication_factor_info) + class NetworkTopologyStrategy(ReplicationStrategy): - name = "NetworkTopologyStrategy" + dc_replication_factors_info = None + """ + A map of datacenter names to the :class:`cassandra.metadata.ReplicationFactor` instance for that DC. + """ + dc_replication_factors = None + """ + A map of datacenter names to the replication factor for that DC. - def __init__(self, dc_replication_factors): - self.dc_replication_factors = dc_replication_factors + For backward compatibility, this maps to the :attr:`cassandra.metadata.ReplicationFactor.full_replicas` + value of the :attr:`cassandra.metadata.NetworkTopologyStrategy.dc_replication_factors_info` dict. + """ - def make_token_replica_map(self, token_to_primary_replica, ring): - # note: this does not account for hosts having different racks - replica_map = {} + def __init__(self, dc_replication_factors): + self.dc_replication_factors_info = dict( + (str(k), ReplicationFactor.create(v)) for k, v in dc_replication_factors.items()) + self.dc_replication_factors = dict( + (dc, rf.full_replicas) for dc, rf in self.dc_replication_factors_info.items()) + + def make_token_replica_map(self, token_to_host_owner, ring): + dc_rf_map = dict( + (dc, full_replicas) for dc, full_replicas in self.dc_replication_factors.items() + if full_replicas > 0) + + # build a map of DCs to lists of indexes into `ring` for tokens that + # belong to that DC + dc_to_token_offset = defaultdict(list) + dc_racks = defaultdict(set) + hosts_per_dc = defaultdict(set) + for i, token in enumerate(ring): + host = token_to_host_owner[token] + dc_to_token_offset[host.datacenter].append(i) + if host.datacenter and host.rack: + dc_racks[host.datacenter].add(host.rack) + hosts_per_dc[host.datacenter].add(host) + + # A map of DCs to an index into the dc_to_token_offset value for that dc. + # This is how we keep track of advancing around the ring for each DC. + dc_to_current_index = defaultdict(int) + + replica_map = defaultdict(list) for i in range(len(ring)): - remaining = self.dc_replication_factors.copy() - for j in range(len(ring)): - host = token_to_primary_replica[ring[(i + j) % len(ring)]] - if not host.datacenter: - continue + replicas = replica_map[ring[i]] - if not remaining[host.datacenter]: - # we already have all replicas for this DC + # go through each DC and find the replicas in that DC + for dc in dc_to_token_offset.keys(): + if dc not in dc_rf_map: continue - replica_map[ring[i]].add(host) - remaining[host.datacenter] -= 1 - if remaining[host.datacenter] == 0: - del remaining[host.datacenter] - - if not remaining: - break + # advance our per-DC index until we're up to at least the + # current token in the ring + token_offsets = dc_to_token_offset[dc] + index = dc_to_current_index[dc] + num_tokens = len(token_offsets) + while index < num_tokens and token_offsets[index] < i: + index += 1 + dc_to_current_index[dc] = index + + replicas_remaining = dc_rf_map[dc] + replicas_this_dc = 0 + skipped_hosts = [] + racks_placed = set() + racks_this_dc = dc_racks[dc] + hosts_this_dc = len(hosts_per_dc[dc]) + + for token_offset_index in range(index, index+num_tokens): + if token_offset_index >= len(token_offsets): + token_offset_index = token_offset_index - len(token_offsets) + + token_offset = token_offsets[token_offset_index] + host = token_to_host_owner[ring[token_offset]] + if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc: + break + + if host in replicas: + continue + + if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc): + skipped_hosts.append(host) + continue + + replicas.append(host) + replicas_this_dc += 1 + replicas_remaining -= 1 + racks_placed.add(host.rack) + + if len(racks_placed) == len(racks_this_dc): + for host in skipped_hosts: + if replicas_remaining == 0: + break + replicas.append(host) + replicas_remaining -= 1 + del skipped_hosts[:] return replica_map def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ ret = "{'class': 'NetworkTopologyStrategy'" - for dc, repl_factor in self.dc_replication_factors: - ret += ", '%s': '%d'" % (dc, repl_factor) + for dc, rf in sorted(self.dc_replication_factors_info.items()): + ret += ", '%s': '%s'" % (dc, str(rf)) return ret + "}" + def __eq__(self, other): + if not isinstance(other, NetworkTopologyStrategy): + return False + + return self.dc_replication_factors_info == other.dc_replication_factors_info + + +class LocalStrategy(ReplicationStrategy): + def __init__(self, options_map): + pass + + def make_token_replica_map(self, token_to_host_owner, ring): + return {} + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + return "{'class': 'LocalStrategy'}" + + def __eq__(self, other): + return isinstance(other, LocalStrategy) + class KeyspaceMetadata(object): """ @@ -410,12 +698,12 @@ class KeyspaceMetadata(object): """ name = None - """ The string name of the keyspace """ + """ The string name of the keyspace. """ durable_writes = True """ A boolean indicating whether durable writes are enabled for this keyspace - or not + or not. """ replication_strategy = None @@ -428,418 +716,2699 @@ class KeyspaceMetadata(object): A map from table names to instances of :class:`~.TableMetadata`. """ - def __init__(self, name, durable_writes, strategy_class, strategy_options): - self.name = name - self.durable_writes = durable_writes - self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) - self.tables = {} - - def export_as_string(self): - return "\n".join([self.as_cql_query()] + [t.as_cql_query() for t in self.tables.values()]) + indexes = None + """ + A dict mapping index names to :class:`.IndexMetadata` instances. + """ - def as_cql_query(self): - ret = "CREATE KEYSPACE %s WITH REPLICATION = %s " % \ - (self.name, self.replication_strategy.export_for_schema()) - return ret + (' AND DURABLE_WRITES = %s;' % ("true" if self.durable_writes else "false")) + user_types = None + """ + A map from user-defined type names to instances of :class:`~cassandra.metadata.UserType`. + .. versionadded:: 2.1.0 + """ -class TableMetadata(object): + functions = None """ - A representation of the schema for a single table. + A map from user-defined function signatures to instances of :class:`~cassandra.metadata.Function`. + + .. versionadded:: 2.6.0 """ - keyspace = None - """ An instance of :class:`~.KeyspaceMetadata` """ + aggregates = None + """ + A map from user-defined aggregate signatures to instances of :class:`~cassandra.metadata.Aggregate`. - name = None - """ The string name of the table """ + .. versionadded:: 2.6.0 + """ - partition_key = None + views = None """ - A list of :class:`.ColumnMetadata` instances representing the columns in - the partition key for this table. This will always hold at least one - column. + A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ - clustering_key = None + virtual = False """ - A list of :class:`.ColumnMetadata` instances representing the columns - in the clustering key for this table. These are all of the - :attr:`.primary_key` columns that are not in the :attr:`.partition_key`. + A boolean indicating if this is a virtual keyspace or not. Always ``False`` + for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions. - Note that a table may have no clustering keys, in which case this will - be an empty list. + .. versionadded:: 3.15 + """ + + graph_engine = None """ + A string indicating whether a graph engine is enabled for this keyspace (Core/Classic). + """ + + _exc_info = None + """ set if metadata parsing failed """ + + def __init__(self, name, durable_writes, strategy_class, strategy_options, graph_engine=None): + self.name = name + self.durable_writes = durable_writes + self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) + self.tables = {} + self.indexes = {} + self.user_types = {} + self.functions = {} + self.aggregates = {} + self.views = {} + self.graph_engine = graph_engine @property - def primary_key(self): + def is_graph_enabled(self): + return self.graph_engine is not None + + def export_as_string(self): """ - A list of :class:`.ColumnMetadata` representing the components of - the primary key for this table. + Returns a CQL query string that can be used to recreate the entire keyspace, + including user-defined types and tables. """ - return self.partition_key + self.clustering_key + # Make sure tables with vertex are exported before tables with edges + tables_with_vertex = [t for t in self.tables.values() if hasattr(t, 'vertex') and t.vertex] + other_tables = [t for t in self.tables.values() if t not in tables_with_vertex] + + cql = "\n\n".join( + [self.as_cql_query() + ';'] + + self.user_type_strings() + + [f.export_as_string() for f in self.functions.values()] + + [a.export_as_string() for a in self.aggregates.values()] + + [t.export_as_string() for t in tables_with_vertex + other_tables]) + + if self._exc_info: + import traceback + ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \ + (self.name) + for line in traceback.format_exception(*self._exc_info): + ret += line + ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql + return ret + if self.virtual: + return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" + "Structure, for reference:*/\n" + "{cql}\n" + "").format(ks=self.name, cql=cql) + return cql - columns = None + def as_cql_query(self): + """ + Returns a CQL query string that can be used to recreate just this keyspace, + not including user-defined types and tables. + """ + if self.virtual: + return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name)) + ret = "CREATE KEYSPACE %s WITH replication = %s " % ( + protect_name(self.name), + self.replication_strategy.export_for_schema()) + ret = ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) + if self.graph_engine is not None: + ret = ret + (" AND graph_engine = '%s'" % self.graph_engine) + return ret + + def user_type_strings(self): + user_type_strings = [] + user_types = self.user_types.copy() + keys = sorted(user_types.keys()) + for k in keys: + if k in user_types: + self.resolve_user_types(k, user_types, user_type_strings) + return user_type_strings + + def resolve_user_types(self, key, user_types, user_type_strings): + user_type = user_types.pop(key) + for type_name in user_type.field_types: + for sub_type in types.cql_types_from_string(type_name): + if sub_type in user_types: + self.resolve_user_types(sub_type, user_types, user_type_strings) + user_type_strings.append(user_type.export_as_string()) + + def _add_table_metadata(self, table_metadata): + old_indexes = {} + old_meta = self.tables.get(table_metadata.name, None) + if old_meta: + # views are not queried with table, so they must be transferred to new + table_metadata.views = old_meta.views + # indexes will be updated with what is on the new metadata + old_indexes = old_meta.indexes + + # note the intentional order of add before remove + # this makes sure the maps are never absent something that existed before this update + for index_name, index_metadata in table_metadata.indexes.items(): + self.indexes[index_name] = index_metadata + + for index_name in (n for n in old_indexes if n not in table_metadata.indexes): + self.indexes.pop(index_name, None) + + self.tables[table_metadata.name] = table_metadata + + def _drop_table_metadata(self, table_name): + table_meta = self.tables.pop(table_name, None) + if table_meta: + for index_name in table_meta.indexes: + self.indexes.pop(index_name, None) + for view_name in table_meta.views: + self.views.pop(view_name, None) + return + # we can't tell table drops from views, so drop both + # (name is unique among them, within a keyspace) + view_meta = self.views.pop(table_name, None) + if view_meta: + try: + self.tables[view_meta.base_table_name].views.pop(table_name, None) + except KeyError: + pass + + def _add_view_metadata(self, view_metadata): + try: + self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata + self.views[view_metadata.name] = view_metadata + except KeyError: + pass + + +class UserType(object): """ - A dict mapping column names to :class:`.ColumnMetadata` instances. + A user defined type, as created by ``CREATE TYPE`` statements. + + User-defined types were introduced in Cassandra 2.1. + + .. versionadded:: 2.1.0 """ - options = None + keyspace = None """ - A dict mapping table option names to their specific settings for this - table. + The string name of the keyspace in which this type is defined. """ - recognized_options = ( - "comment", "read_repair_chance", # "local_read_repair_chance", - "replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance", - "caching", "compaction_strategy_class", "compaction_strategy_options", - "min_compaction_threshold", "max_compression_threshold", - "compression_parameters") - - def __init__(self, keyspace_metadata, name, partition_key=None, clustering_key=None, columns=None, options=None): - self.keyspace = keyspace_metadata - self.name = name - self.partition_key = [] if partition_key is None else partition_key - self.clustering_key = [] if clustering_key is None else clustering_key - self.columns = OrderedDict() if columns is None else columns - self.options = options - self.comparator = None + name = None + """ + The name of this type. + """ - def export_as_string(self): - """ - Returns a string of CQL queries that can be used to recreate this table - along with all indexes on it. The returned string is formatted to - be human readable. - """ - ret = self.as_cql_query(formatted=True) - ret += ";" + field_names = None + """ + An ordered list of the names for each field in this user-defined type. + """ - for col_meta in self.columns.values(): - if col_meta.index: - ret += "\n%s;" % (col_meta.index.as_cql_query(),) + field_types = None + """ + An ordered list of the types for each field in this user-defined type. + """ - return ret + def __init__(self, keyspace, name, field_names, field_types): + self.keyspace = keyspace + self.name = name + # non-frozen collections can return None + self.field_names = field_names or [] + self.field_types = field_types or [] def as_cql_query(self, formatted=False): """ - Returns a CQL query that can be used to recreate this table (index - creations are not included). If `formatted` is set to :const:`True`, - extra whitespace will be added to make the query human readable. + Returns a CQL query that can be used to recreate this type. + If `formatted` is set to :const:`True`, extra whitespace will + be added to make the query more readable. """ - ret = "CREATE TABLE %s.%s (%s" % (self.keyspace.name, self.name, "\n" if formatted else "") + ret = "CREATE TYPE %s.%s (%s" % ( + protect_name(self.keyspace), + protect_name(self.name), + "\n" if formatted else "") if formatted: - column_join = ",\n" + field_join = ",\n" padding = " " else: - column_join = ", " + field_join = ", " padding = "" - columns = [] - for col in self.columns.values(): - columns.append("%s %s" % (col.name, col.typestring)) + fields = [] + for field_name, field_type in zip(self.field_names, self.field_types): + fields.append("%s %s" % (protect_name(field_name), field_type)) - if len(self.partition_key) == 1 and not self.clustering_key: - columns[0] += " PRIMARY KEY" + ret += field_join.join("%s%s" % (padding, field) for field in fields) + ret += "\n)" if formatted else ")" + return ret - ret += column_join.join("%s%s" % (padding, col) for col in columns) + def export_as_string(self): + return self.as_cql_query(formatted=True) + ';' - # primary key - if len(self.partition_key) > 1 or self.clustering_key: - ret += "%s%sPRIMARY KEY (" % (column_join, padding) - if len(self.partition_key) > 1: - ret += "(%s)" % ", ".join(col.name for col in self.partition_key) - else: - ret += self.partition_key[0].name +class Aggregate(object): + """ + A user defined aggregate function, as created by ``CREATE AGGREGATE`` statements. - if self.clustering_key: - ret += ", %s" % ", ".join(col.name for col in self.clustering_key) + Aggregate functions were introduced in Cassandra 2.2 - ret += ")" + .. versionadded:: 2.6.0 + """ - # options - ret += "%s) WITH " % ("\n" if formatted else "") + keyspace = None + """ + The string name of the keyspace in which this aggregate is defined + """ - option_strings = [] - if self.options.get("is_compact_storage"): - option_strings.append("COMPACT STORAGE") + name = None + """ + The name of this aggregate + """ - if self.clustering_key: - cluster_str = "CLUSTERING ORDER BY " + argument_types = None + """ + An ordered list of the types for each argument to the aggregate + """ - clustering_names = self.protect_names([c.name for c in self.clustering_key]) + final_func = None + """ + Name of a final function + """ - if self.options.get("is_compact_storage") and \ - not issubclass(self.comparator, types.CompositeType): - subtypes = [self.comparator] - else: - subtypes = self.comparator.subtypes + initial_condition = None + """ + Initial condition of the aggregate + """ - inner = [] - for colname, coltype in zip(clustering_names, subtypes): - ordering = "DESC" if issubclass(coltype, types.ReversedType) else "ASC" - inner.append("%s %s" % (colname, ordering)) + return_type = None + """ + Return type of the aggregate + """ - cluster_str += "(%s)" % ", ".join(inner) - option_strings.append(cluster_str) + state_func = None + """ + Name of a state function + """ - option_strings.extend(map(self._make_option_str, self.recognized_options)) - option_strings = filter(lambda x: x is not None, option_strings) + state_type = None + """ + Type of the aggregate state + """ - join_str = "\n AND " if formatted else " AND " - ret += join_str.join(option_strings) + deterministic = None + """ + Flag indicating if this function is guaranteed to produce the same result + for a particular input and state. This is available only with DSE >=6.0. + """ + + def __init__(self, keyspace, name, argument_types, state_func, + state_type, final_func, initial_condition, return_type, + deterministic): + self.keyspace = keyspace + self.name = name + self.argument_types = argument_types + self.state_func = state_func + self.state_type = state_type + self.final_func = final_func + self.initial_condition = initial_condition + self.return_type = return_type + self.deterministic = deterministic + + def as_cql_query(self, formatted=False): + """ + Returns a CQL query that can be used to recreate this aggregate. + If `formatted` is set to :const:`True`, extra whitespace will + be added to make the query more readable. + """ + sep = '\n ' if formatted else ' ' + keyspace = protect_name(self.keyspace) + name = protect_name(self.name) + type_list = ', '.join([types.strip_frozen(arg_type) for arg_type in self.argument_types]) + state_func = protect_name(self.state_func) + state_type = types.strip_frozen(self.state_type) + + ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \ + "SFUNC %(state_func)s%(sep)s" \ + "STYPE %(state_type)s" % locals() + + ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else '' + ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else '' + ret += '{}DETERMINISTIC'.format(sep) if self.deterministic else '' return ret - def _make_option_str(self, name): - value = self.options.get(name) - if value is not None: - if name == "comment": - value = value or "" - return "%s = %s" % (name, self.protect_value(value)) + def export_as_string(self): + return self.as_cql_query(formatted=True) + ';' - def protect_name(self, name): - if isinstance(name, unicode): - name = name.encode('utf8') - return self.maybe_escape_name(name) + @property + def signature(self): + return SignatureDescriptor.format_signature(self.name, self.argument_types) - def protect_names(self, names): - return map(self.protect_name, names) - def protect_value(self, value): - if value is None: - return 'NULL' - if isinstance(value, (int, float, bool)): - return str(value) - return "'%s'" % value.replace("'", "''") +class Function(object): + """ + A user defined function, as created by ``CREATE FUNCTION`` statements. - valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') + User-defined functions were introduced in Cassandra 2.2 - def is_valid_name(self, name): - if name is None: - return False - if name.lower() in _keywords - _unreserved_keywords: - return False - return self.valid_cql3_word_re.match(name) is not None + .. versionadded:: 2.6.0 + """ - def maybe_escape_name(self, name): - if self.is_valid_name(name): - return name - return self.escape_name(name) + keyspace = None + """ + The string name of the keyspace in which this function is defined + """ - def escape_name(self, name): - return '"%s"' % (name.replace('"', '""'),) + name = None + """ + The name of this function + """ + argument_types = None + """ + An ordered list of the types for each argument to the function + """ -class ColumnMetadata(object): + argument_names = None """ - A representation of a single column in a table. + An ordered list of the names of each argument to the function """ - table = None - """ The :class:`.TableMetadata` this column belongs to. """ + return_type = None + """ + Return type of the function + """ - name = None - """ The string name of this column. """ + language = None + """ + Language of the function body + """ + + body = None + """ + Function body string + """ - data_type = None + called_on_null_input = None + """ + Flag indicating whether this function should be called for rows with null values + (convenience function to avoid handling nulls explicitly if the result will just be null) + """ - index = None + deterministic = None """ - If an index exists on this column, this is an instance of - :class:`.IndexMetadata`, otherwise :const:`None`. + Flag indicating if this function is guaranteed to produce the same result + for a particular input. This is available only for DSE >=6.0. """ - def __init__(self, table_metadata, column_name, data_type, index_metadata=None): - self.table = table_metadata - self.name = column_name - self.data_type = data_type - self.index = index_metadata + monotonic = None + """ + Flag indicating if this function is guaranteed to increase or decrease + monotonically on any of its arguments. This is available only for DSE >=6.0. + """ - @property - def typestring(self): + monotonic_on = None + """ + A list containing the argument or arguments over which this function is + monotonic. This is available only for DSE >=6.0. + """ + + def __init__(self, keyspace, name, argument_types, argument_names, + return_type, language, body, called_on_null_input, + deterministic, monotonic, monotonic_on): + self.keyspace = keyspace + self.name = name + self.argument_types = argument_types + # argument_types (frozen>) will always be a list + # argument_name is not frozen in C* < 3.0 and may return None + self.argument_names = argument_names or [] + self.return_type = return_type + self.language = language + self.body = body + self.called_on_null_input = called_on_null_input + self.deterministic = deterministic + self.monotonic = monotonic + self.monotonic_on = monotonic_on + + def as_cql_query(self, formatted=False): """ - A string representation of the type for this column, such as "varchar" - or "map". + Returns a CQL query that can be used to recreate this function. + If `formatted` is set to :const:`True`, extra whitespace will + be added to make the query more readable. """ - if issubclass(self.data_type, types.ReversedType): - return self.data_type.subtypes[0].cql_parameterized_type() - else: - return self.data_type.cql_parameterized_type() + sep = '\n ' if formatted else ' ' + keyspace = protect_name(self.keyspace) + name = protect_name(self.name) + arg_list = ', '.join(["%s %s" % (protect_name(n), types.strip_frozen(t)) + for n, t in zip(self.argument_names, self.argument_types)]) + typ = self.return_type + lang = self.language + body = self.body + on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL" + deterministic_token = ('DETERMINISTIC{}'.format(sep) + if self.deterministic else + '') + monotonic_tokens = '' # default for nonmonotonic function + if self.monotonic: + # monotonic on all arguments; ignore self.monotonic_on + monotonic_tokens = 'MONOTONIC{}'.format(sep) + elif self.monotonic_on: + # if monotonic == False and monotonic_on is nonempty, we know that + # monotonicity was specified with MONOTONIC ON , so there's + # exactly 1 value there + monotonic_tokens = 'MONOTONIC ON {}{}'.format(self.monotonic_on[0], + sep) + + return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \ + "%(on_null)s ON NULL INPUT%(sep)s" \ + "RETURNS %(typ)s%(sep)s" \ + "%(deterministic_token)s" \ + "%(monotonic_tokens)s" \ + "LANGUAGE %(lang)s%(sep)s" \ + "AS $$%(body)s$$" % locals() - def __str__(self): - return "%s %s" % (self.name, self.data_type) + def export_as_string(self): + return self.as_cql_query(formatted=True) + ';' + @property + def signature(self): + return SignatureDescriptor.format_signature(self.name, self.argument_types) -class IndexMetadata(object): - """ - A representation of a secondary index on a column. - """ - column = None +class TableMetadata(object): """ - The column (:class:`.ColumnMetadata`) this index is on. + A representation of the schema for a single table. """ + keyspace_name = None + """ String name of this Table's keyspace """ + name = None - """ A string name for the index. """ + """ The string name of the table. """ - index_type = None - """ A string representing the type of index. """ + partition_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns in + the partition key for this table. This will always hold at least one + column. + """ - def __init__(self, column_metadata, index_name=None, index_type=None): - self.column = column_metadata - self.name = index_name - self.index_type = index_type + clustering_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns + in the clustering key for this table. These are all of the + :attr:`.primary_key` columns that are not in the :attr:`.partition_key`. - def as_cql_query(self): + Note that a table may have no clustering keys, in which case this will + be an empty list. + """ + + @property + def primary_key(self): """ - Returns a CQL query that can be used to recreate this index. + A list of :class:`.ColumnMetadata` representing the components of + the primary key for this table. """ - table = self.column.table - return "CREATE INDEX %s ON %s.%s (%s)" % (self.name, table.keyspace.name, table.name, self.column.name) + return self.partition_key + self.clustering_key + columns = None + """ + A dict mapping column names to :class:`.ColumnMetadata` instances. + """ -class TokenMap(object): + indexes = None """ - Information about the layout of the ring. + A dict mapping index names to :class:`.IndexMetadata` instances. """ - token_class = None + is_compact_storage = False + + options = None """ - A subclass of :class:`.Token`, depending on what partitioner the cluster uses. + A dict mapping table option names to their specific settings for this + table. """ - tokens_to_hosts_by_ks = None + compaction_options = { + "min_compaction_threshold": "min_threshold", + "max_compaction_threshold": "max_threshold", + "compaction_strategy_class": "class"} + + triggers = None """ - A map of keyspace names to a nested map of :class:`.Token` objects to - sets of :class:`.Host` objects. + A dict mapping trigger names to :class:`.TriggerMetadata` instances. """ - ring = None + views = None """ - An ordered list of :class:`.Token` instances in the ring. + A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ - def __init__(self, token_class, token_to_primary_replica, all_tokens, keyspaces): - self.token_class = token_class - self.ring = all_tokens + _exc_info = None + """ set if metadata parsing failed """ - self.tokens_to_hosts_by_ks = {} - for ks_metadata in keyspaces: - strategy = ks_metadata.replication_strategy - if strategy is None: - token_to_hosts = defaultdict(set) - for token, host in token_to_primary_replica.items(): - token_to_hosts[token].add(host) - self.tokens_to_hosts_by_ks[ks_metadata.name] = token_to_hosts - else: - self.tokens_to_hosts_by_ks[ks_metadata.name] = \ - strategy.make_token_replica_map( - token_to_primary_replica, all_tokens) + virtual = False + """ + A boolean indicating if this is a virtual table or not. Always ``False`` + for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions. - def get_replicas(self, keyspace, token): + .. versionadded:: 3.15 + """ + + @property + def is_cql_compatible(self): """ - Get a set of :class:`.Host` instances representing all of the - replica nodes for a given :class:`.Token`. + A boolean indicating if this table can be represented as CQL in export """ - tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) - if tokens_to_hosts is None: - return set() + if self.virtual: + return False + comparator = getattr(self, 'comparator', None) + if comparator: + # no compact storage with more than one column beyond PK if there + # are clustering columns + incompatible = (self.is_compact_storage and + len(self.columns) > len(self.primary_key) + 1 and + len(self.clustering_key) >= 1) + + return not incompatible + return True + + extensions = None + """ + Metadata describing configuration for table extensions + """ + + def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False): + self.keyspace_name = keyspace_name + self.name = name + self.partition_key = [] if partition_key is None else partition_key + self.clustering_key = [] if clustering_key is None else clustering_key + self.columns = OrderedDict() if columns is None else columns + self.indexes = {} + self.options = {} if options is None else options + self.comparator = None + self.triggers = OrderedDict() if triggers is None else triggers + self.views = {} + self.virtual = virtual + + def export_as_string(self): + """ + Returns a string of CQL queries that can be used to recreate this table + along with all indexes on it. The returned string is formatted to + be human readable. + """ + if self._exc_info: + import traceback + ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \ + (self.keyspace_name, self.name) + for line in traceback.format_exception(*self._exc_info): + ret += line + ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + elif not self.is_cql_compatible: + # If we can't produce this table with CQL, comment inline + ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ + (self.keyspace_name, self.name) + ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + elif self.virtual: + ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n' + 'Structure, for reference:\n' + '{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) - point = bisect_left(self.ring, token) - if point == 0 and token != self.ring[0]: - return tokens_to_hosts[self.ring[-1]] - elif point == len(self.ring): - return tokens_to_hosts[self.ring[0]] else: - return tokens_to_hosts[self.ring[point]] + ret = self._all_as_cql() + return ret -class Token(object): - """ - Abstract class representing a token. - """ + def _all_as_cql(self): + ret = self.as_cql_query(formatted=True) + ret += ";" - @classmethod - def hash_fn(cls, key): - return key + for index in self.indexes.values(): + ret += "\n%s;" % index.as_cql_query() - @classmethod - def from_key(cls, key): - return cls(cls.hash_fn(key)) + for trigger_meta in self.triggers.values(): + ret += "\n%s;" % (trigger_meta.as_cql_query(),) + + for view_meta in self.views.values(): + ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),) + + if self.extensions: + registry = _RegisteredExtensionType._extension_registry + for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + ext = registry[k] + cql = ext.after_table_cql(self, k, self.extensions[k]) + if cql: + ret += "\n\n%s" % (cql,) + + return ret + + def as_cql_query(self, formatted=False): + """ + Returns a CQL query that can be used to recreate this table (index + creations are not included). If `formatted` is set to :const:`True`, + extra whitespace will be added to make the query human readable. + """ + ret = "%s TABLE %s.%s (%s" % ( + ('VIRTUAL' if self.virtual else 'CREATE'), + protect_name(self.keyspace_name), + protect_name(self.name), + "\n" if formatted else "") - def __cmp__(self, other): - if self.value < other.value: - return -1 - elif self.value == other.value: - return 0 + if formatted: + column_join = ",\n" + padding = " " else: - return 1 + column_join = ", " + padding = "" -MIN_LONG = -(2 ** 63) -MAX_LONG = (2 ** 63) - 1 + columns = [] + for col in self.columns.values(): + columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else '')) + if len(self.partition_key) == 1 and not self.clustering_key: + columns[0] += " PRIMARY KEY" -class NoMurmur3(Exception): - pass + ret += column_join.join("%s%s" % (padding, col) for col in columns) + + # primary key + if len(self.partition_key) > 1 or self.clustering_key: + ret += "%s%sPRIMARY KEY (" % (column_join, padding) + if len(self.partition_key) > 1: + ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) + else: + ret += protect_name(self.partition_key[0].name) -class Murmur3Token(Token): - """ - A token for ``Murmur3Partitioner``. - """ + if self.clustering_key: + ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) + + ret += ")" + + # properties + ret += "%s) WITH " % ("\n" if formatted else "") + ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage) + + return ret @classmethod - def hash_fn(cls, key): - if murmur3 is not None: - h = murmur3(key) - return h if h != MIN_LONG else MAX_LONG - else: - raise NoMurmur3() + def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False): + properties = [] + if is_compact_storage: + properties.append("COMPACT STORAGE") - def __init__(self, token): - """ `token` should be an int or string representing the token """ - self.value = int(token) + if clustering_key: + cluster_str = "CLUSTERING ORDER BY " - def __repr__(self): - return "" % (self.__class__.__name__, self.value) + __str__ = __repr__ + + +MIN_LONG = -(2 ** 63) +MAX_LONG = (2 ** 63) - 1 + + +class NoMurmur3(Exception): + pass + + +class HashToken(Token): + + @classmethod + def from_string(cls, token_string): + """ `token_string` should be the string representation from the server. """ + # The hash partitioners just store the deciman value + return cls(int(token_string)) + + +class Murmur3Token(HashToken): + """ + A token for ``Murmur3Partitioner``. + """ + + @classmethod + def hash_fn(cls, key): + if murmur3 is not None: + h = int(murmur3(key)) + return h if h != MIN_LONG else MAX_LONG + else: + raise NoMurmur3() + + def __init__(self, token): + """ `token` is an int or string representing the token. """ + self.value = int(token) + + +class MD5Token(HashToken): + """ + A token for ``RandomPartitioner``. + """ + + @classmethod + def hash_fn(cls, key): + if isinstance(key, str): + key = key.encode('UTF-8') + return abs(varint_unpack(md5(key,usedforsecurity=False).digest())) + + +class BytesToken(Token): + """ + A token for ``ByteOrderedPartitioner``. + """ + + @classmethod + def from_string(cls, token_string): + """ `token_string` should be the string representation from the server. """ + # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" + if isinstance(token_string, str): + token_string = token_string.encode('ascii') + # The BOP stores a hex string + return cls(unhexlify(token_string)) + + +class TriggerMetadata(object): + """ + A representation of a trigger for a table. + """ + + table = None + """ The :class:`.TableMetadata` this trigger belongs to. """ + + name = None + """ The string name of this trigger. """ + + options = None + """ + A dict mapping trigger option names to their specific settings for this + table. + """ + def __init__(self, table_metadata, trigger_name, options=None): + self.table = table_metadata + self.name = trigger_name + self.options = options + + def as_cql_query(self): + ret = "CREATE TRIGGER %s ON %s.%s USING %s" % ( + protect_name(self.name), + protect_name(self.table.keyspace_name), + protect_name(self.table.name), + protect_value(self.options['class']) + ) + return ret + + def export_as_string(self): + return self.as_cql_query() + ';' + + +class _SchemaParser(object): + + def __init__(self, connection, timeout): + self.connection = connection + self.timeout = timeout + + def _handle_results(self, success, result, expected_failures=tuple()): + """ + Given a bool and a ResultSet (the form returned per result from + Connection.wait_for_responses), return a dictionary containing the + results. Used to process results from asynchronous queries to system + tables. + + ``expected_failures`` will usually be used to allow callers to ignore + ``InvalidRequest`` errors caused by a missing system keyspace. For + example, some DSE versions report a 4.X server version, but do not have + virtual tables. Thus, running against 4.X servers, SchemaParserV4 uses + expected_failures to make a best-effort attempt to read those + keyspaces, but treat them as empty if they're not found. + + :param success: A boolean representing whether or not the query + succeeded + :param result: The resultset in question. + :expected_failures: An Exception class or an iterable thereof. If the + query failed, but raised an instance of an expected failure class, this + will ignore the failure and return an empty list. + """ + if not success and isinstance(result, expected_failures): + return [] + elif success: + return dict_factory(result.column_names, result.parsed_rows) if result else [] + else: + raise result + + def _query_build_row(self, query_string, build_func): + result = self._query_build_rows(query_string, build_func) + return result[0] if result else None + + def _query_build_rows(self, query_string, build_func): + query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE) + responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False) + (success, response) = responses[0] + if success: + result = dict_factory(response.column_names, response.parsed_rows) + return [build_func(row) for row in result] + elif isinstance(response, InvalidRequest): + log.debug("user types table not found") + return [] + else: + raise response + + +class SchemaParserV22(_SchemaParser): + """ + For C* 2.2+ + """ + _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" + _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" + _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" + _SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers" + _SELECT_TYPES = "SELECT * FROM system.schema_usertypes" + _SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions" + _SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates" + + _table_name_col = 'columnfamily_name' + + _function_agg_arument_type_col = 'signature' + + recognized_table_options = ( + "comment", + "read_repair_chance", + "dclocal_read_repair_chance", # kept to be safe, but see _build_table_options() + "local_read_repair_chance", + "replicate_on_write", + "gc_grace_seconds", + "bloom_filter_fp_chance", + "caching", + "compaction_strategy_class", + "compaction_strategy_options", + "min_compaction_threshold", + "max_compaction_threshold", + "compression_parameters", + "min_index_interval", + "max_index_interval", + "index_interval", + "speculative_retry", + "rows_per_partition_to_cache", + "memtable_flush_period_in_ms", + "populate_io_cache_on_flush", + "compression", + "default_time_to_live") + + def __init__(self, connection, timeout): + super(SchemaParserV22, self).__init__(connection, timeout) + self.keyspaces_result = [] + self.tables_result = [] + self.columns_result = [] + self.triggers_result = [] + self.types_result = [] + self.functions_result = [] + self.aggregates_result = [] + + self.keyspace_table_rows = defaultdict(list) + self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list)) + self.keyspace_type_rows = defaultdict(list) + self.keyspace_func_rows = defaultdict(list) + self.keyspace_agg_rows = defaultdict(list) + self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list)) + + def get_all_keyspaces(self): + self._query_all() + + for row in self.keyspaces_result: + keyspace_meta = self._build_keyspace_metadata(row) + + try: + for table_row in self.keyspace_table_rows.get(keyspace_meta.name, []): + table_meta = self._build_table_metadata(table_row) + keyspace_meta._add_table_metadata(table_meta) + + for usertype_row in self.keyspace_type_rows.get(keyspace_meta.name, []): + usertype = self._build_user_type(usertype_row) + keyspace_meta.user_types[usertype.name] = usertype + + for fn_row in self.keyspace_func_rows.get(keyspace_meta.name, []): + fn = self._build_function(fn_row) + keyspace_meta.functions[fn.signature] = fn + + for agg_row in self.keyspace_agg_rows.get(keyspace_meta.name, []): + agg = self._build_aggregate(agg_row) + keyspace_meta.aggregates[agg.signature] = agg + except Exception: + log.exception("Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", keyspace_meta.name) + keyspace_meta._exc_info = sys.exc_info() + + yield keyspace_meta + + def get_table(self, keyspaces, keyspace, table): + cl = ConsistencyLevel.ONE + where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder) + cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) + col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) + triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) + (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \ + = self.connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self.timeout, fail_on_error=False) + table_result = self._handle_results(cf_success, cf_result) + col_result = self._handle_results(col_success, col_result) + + # the triggers table doesn't exist in C* 1.2 + triggers_result = self._handle_results(triggers_success, triggers_result, + expected_failures=InvalidRequest) + + if table_result: + return self._build_table_metadata(table_result[0], col_result, triggers_result) + + def get_type(self, keyspaces, keyspace, type): + where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder) + return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type) + + def get_types_map(self, keyspaces, keyspace): + where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) + types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type) + return dict((t.name, t) for t in types) + + def get_function(self, keyspaces, keyspace, function): + where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), + (keyspace, function.name, function.argument_types), _encoder) + return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function) + + def get_aggregate(self, keyspaces, keyspace, aggregate): + where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), + (keyspace, aggregate.name, aggregate.argument_types), _encoder) + + return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate) + + def get_keyspace(self, keyspaces, keyspace): + where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) + return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata) + + @classmethod + def _build_keyspace_metadata(cls, row): + try: + ksm = cls._build_keyspace_metadata_internal(row) + except Exception: + name = row["keyspace_name"] + ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {}) + ksm._exc_info = sys.exc_info() # capture exc_info before log because nose (test) logging clears it in certain circumstances + log.exception("Error while parsing metadata for keyspace %s row(%s)", name, row) + return ksm + + @staticmethod + def _build_keyspace_metadata_internal(row): + name = row["keyspace_name"] + durable_writes = row["durable_writes"] + strategy_class = row["strategy_class"] + strategy_options = json.loads(row["strategy_options"]) + return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + + @classmethod + def _build_user_type(cls, usertype_row): + field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types'])) + return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], + usertype_row['field_names'], field_types) + + @classmethod + def _build_function(cls, function_row): + return_type = cls._schema_type_to_cql(function_row['return_type']) + deterministic = function_row.get('deterministic', False) + monotonic = function_row.get('monotonic', False) + monotonic_on = function_row.get('monotonic_on', ()) + return Function(function_row['keyspace_name'], function_row['function_name'], + function_row[cls._function_agg_arument_type_col], function_row['argument_names'], + return_type, function_row['language'], function_row['body'], + function_row['called_on_null_input'], + deterministic, monotonic, monotonic_on) + + @classmethod + def _build_aggregate(cls, aggregate_row): + cass_state_type = types.lookup_casstype(aggregate_row['state_type']) + initial_condition = aggregate_row['initcond'] + if initial_condition is not None: + initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) + state_type = _cql_from_cass_type(cass_state_type) + return_type = cls._schema_type_to_cql(aggregate_row['return_type']) + return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], + aggregate_row['signature'], aggregate_row['state_func'], state_type, + aggregate_row['final_func'], initial_condition, return_type, + aggregate_row.get('deterministic', False)) + + def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): + keyspace_name = row["keyspace_name"] + cfname = row[self._table_name_col] + + col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname] + trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname] + + if not col_rows: # CASSANDRA-8487 + log.warning("Building table metadata with no column meta for %s.%s", + keyspace_name, cfname) + + table_meta = TableMetadata(keyspace_name, cfname) + + try: + comparator = types.lookup_casstype(row["comparator"]) + table_meta.comparator = comparator + + is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) + is_composite_comparator = issubclass(comparator, types.CompositeType) + column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) + + num_column_name_components = len(column_name_types) + last_col = column_name_types[-1] + + column_aliases = row.get("column_aliases", None) + + clustering_rows = [r for r in col_rows + if r.get('type', None) == "clustering_key"] + if len(clustering_rows) > 1: + clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) + + if column_aliases is not None: + column_aliases = json.loads(column_aliases) + + if not column_aliases: # json load failed or column_aliases empty PYTHON-562 + column_aliases = [r.get('column_name') for r in clustering_rows] + + if is_composite_comparator: + if issubclass(last_col, types.ColumnToCollectionType): + # collections + is_compact = False + has_value = False + clustering_size = num_column_name_components - 2 + elif (len(column_aliases) == num_column_name_components - 1 and + issubclass(last_col, types.UTF8Type)): + # aliases? + is_compact = False + has_value = False + clustering_size = num_column_name_components - 1 + else: + # compact table + is_compact = True + has_value = column_aliases or not col_rows + clustering_size = num_column_name_components + + # Some thrift tables define names in composite types (see PYTHON-192) + if not column_aliases and hasattr(comparator, 'fieldnames'): + column_aliases = filter(None, comparator.fieldnames) + else: + is_compact = True + if column_aliases or not col_rows or is_dct_comparator: + has_value = True + clustering_size = num_column_name_components + else: + has_value = False + clustering_size = 0 + + # partition key + partition_rows = [r for r in col_rows + if r.get('type', None) == "partition_key"] + + if len(partition_rows) > 1: + partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) + + key_aliases = row.get("key_aliases") + if key_aliases is not None: + key_aliases = json.loads(key_aliases) if key_aliases else [] + else: + # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. + key_aliases = [r.get('column_name') for r in partition_rows] + + key_validator = row.get("key_validator") + if key_validator is not None: + key_type = types.lookup_casstype(key_validator) + key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] + else: + key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] + + for i, col_type in enumerate(key_types): + if len(key_aliases) > i: + column_name = key_aliases[i] + elif i == 0: + column_name = "key" + else: + column_name = "key%d" % i + + col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type()) + table_meta.columns[column_name] = col + table_meta.partition_key.append(col) + + # clustering key + for i in range(clustering_size): + if len(column_aliases) > i: + column_name = column_aliases[i] + else: + column_name = "column%d" % (i + 1) + + data_type = column_name_types[i] + cql_type = _cql_from_cass_type(data_type) + is_reversed = types.is_reversed_casstype(data_type) + col = ColumnMetadata(table_meta, column_name, cql_type, is_reversed=is_reversed) + table_meta.columns[column_name] = col + table_meta.clustering_key.append(col) + + # value alias (if present) + if has_value: + value_alias_rows = [r for r in col_rows + if r.get('type', None) == "compact_value"] + + if not key_aliases: # TODO are we checking the right thing here? + value_alias = "value" + else: + value_alias = row.get("value_alias", None) + if value_alias is None and value_alias_rows: # CASSANDRA-8487 + # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. + value_alias = value_alias_rows[0].get('column_name') + + default_validator = row.get("default_validator") + if default_validator: + validator = types.lookup_casstype(default_validator) + else: + if value_alias_rows: # CASSANDRA-8487 + validator = types.lookup_casstype(value_alias_rows[0].get('validator')) + + cql_type = _cql_from_cass_type(validator) + col = ColumnMetadata(table_meta, value_alias, cql_type) + if value_alias: # CASSANDRA-8487 + table_meta.columns[value_alias] = col + + # other normal columns + for col_row in col_rows: + column_meta = self._build_column_metadata(table_meta, col_row) + if column_meta.name is not None: + table_meta.columns[column_meta.name] = column_meta + index_meta = self._build_index_metadata(column_meta, col_row) + if index_meta: + table_meta.indexes[index_meta.name] = index_meta + + for trigger_row in trigger_rows: + trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) + table_meta.triggers[trigger_meta.name] = trigger_meta + + table_meta.options = self._build_table_options(row) + table_meta.is_compact_storage = is_compact + except Exception: + table_meta._exc_info = sys.exc_info() + log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows) + + return table_meta + + def _build_table_options(self, row): + """ Setup the mostly-non-schema table options, like caching settings """ + options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row) + + # the option name when creating tables is "dclocal_read_repair_chance", + # but the column name in system.schema_columnfamilies is + # "local_read_repair_chance". We'll store this as dclocal_read_repair_chance, + # since that's probably what users are expecting (and we need it for the + # CREATE TABLE statement anyway). + if "local_read_repair_chance" in options: + val = options.pop("local_read_repair_chance") + options["dclocal_read_repair_chance"] = val + + return options + + @classmethod + def _build_column_metadata(cls, table_metadata, row): + name = row["column_name"] + type_string = row["validator"] + data_type = types.lookup_casstype(type_string) + cql_type = _cql_from_cass_type(data_type) + is_static = row.get("type", None) == "static" + is_reversed = types.is_reversed_casstype(data_type) + column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) + column_meta._cass_type = data_type + return column_meta + + @staticmethod + def _build_index_metadata(column_metadata, row): + index_name = row.get("index_name") + kind = row.get("index_type") + if index_name or kind: + options = row.get("index_options") + options = json.loads(options) if options else {} + options = options or {} # if the json parsed to None, init empty dict + + # generate a CQL index identity string + target = protect_name(column_metadata.name) + if kind != "CUSTOM": + if "index_keys" in options: + target = 'keys(%s)' % (target,) + elif "index_values" in options: + # don't use any "function" for collection values + pass + else: + # it might be a "full" index on a frozen collection, but + # we need to check the data type to verify that, because + # there is no special index option for full-collection + # indexes. + data_type = column_metadata._cass_type + collection_types = ('map', 'set', 'list') + if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: + # no index option for full-collection index + target = 'full(%s)' % (target,) + options['target'] = target + return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options) + + @staticmethod + def _build_trigger_metadata(table_metadata, row): + name = row["trigger_name"] + options = row["trigger_options"] + trigger_meta = TriggerMetadata(table_metadata, name, options) + return trigger_meta + + def _query_all(self): + cl = ConsistencyLevel.ONE + queries = [ + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) + ] + + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result)) = ( + self.connection.wait_for_responses(*queries, timeout=self.timeout, + fail_on_error=False) + ) + + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + + # if we're connected to Cassandra < 2.0, the triggers table will not exist + if triggers_success: + self.triggers_result = dict_factory(triggers_result.column_names, triggers_result.parsed_rows) + else: + if isinstance(triggers_result, InvalidRequest): + log.debug("triggers table not found") + elif isinstance(triggers_result, Unauthorized): + log.warning("this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " + "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") + else: + raise triggers_result + + # if we're connected to Cassandra < 2.1, the usertypes table will not exist + if types_success: + self.types_result = dict_factory(types_result.column_names, types_result.parsed_rows) + else: + if isinstance(types_result, InvalidRequest): + log.debug("user types table not found") + self.types_result = {} + else: + raise types_result + + # functions were introduced in Cassandra 2.2 + if functions_success: + self.functions_result = dict_factory(functions_result.column_names, functions_result.parsed_rows) + else: + if isinstance(functions_result, InvalidRequest): + log.debug("user functions table not found") + else: + raise functions_result + + # aggregates were introduced in Cassandra 2.2 + if aggregates_success: + self.aggregates_result = dict_factory(aggregates_result.column_names, aggregates_result.parsed_rows) + else: + if isinstance(aggregates_result, InvalidRequest): + log.debug("user aggregates table not found") + else: + raise aggregates_result + + self._aggregate_results() + + def _aggregate_results(self): + m = self.keyspace_table_rows + for row in self.tables_result: + m[row["keyspace_name"]].append(row) + + m = self.keyspace_table_col_rows + for row in self.columns_result: + ksname = row["keyspace_name"] + cfname = row[self._table_name_col] + m[ksname][cfname].append(row) + + m = self.keyspace_type_rows + for row in self.types_result: + m[row["keyspace_name"]].append(row) + + m = self.keyspace_func_rows + for row in self.functions_result: + m[row["keyspace_name"]].append(row) + + m = self.keyspace_agg_rows + for row in self.aggregates_result: + m[row["keyspace_name"]].append(row) + + m = self.keyspace_table_trigger_rows + for row in self.triggers_result: + ksname = row["keyspace_name"] + cfname = row[self._table_name_col] + m[ksname][cfname].append(row) + + @staticmethod + def _schema_type_to_cql(type_string): + cass_type = types.lookup_casstype(type_string) + return _cql_from_cass_type(cass_type) + + +class SchemaParserV3(SchemaParserV22): + """ + For C* 3.0+ + """ + _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" + _SELECT_TABLES = "SELECT * FROM system_schema.tables" + _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" + _SELECT_INDEXES = "SELECT * FROM system_schema.indexes" + _SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers" + _SELECT_TYPES = "SELECT * FROM system_schema.types" + _SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions" + _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" + _SELECT_VIEWS = "SELECT * FROM system_schema.views" + + _table_name_col = 'table_name' + + _function_agg_arument_type_col = 'argument_types' + + _table_metadata_class = TableMetadataV3 + + recognized_table_options = ( + 'bloom_filter_fp_chance', + 'caching', + 'cdc', + 'comment', + 'compaction', + 'compression', + 'crc_check_chance', + 'dclocal_read_repair_chance', + 'default_time_to_live', + 'gc_grace_seconds', + 'max_index_interval', + 'memtable_flush_period_in_ms', + 'min_index_interval', + 'read_repair_chance', + 'speculative_retry') + + def __init__(self, connection, timeout): + super(SchemaParserV3, self).__init__(connection, timeout) + self.indexes_result = [] + self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list)) + self.keyspace_view_rows = defaultdict(list) + + def get_all_keyspaces(self): + for keyspace_meta in super(SchemaParserV3, self).get_all_keyspaces(): + for row in self.keyspace_view_rows[keyspace_meta.name]: + view_meta = self._build_view_metadata(row) + keyspace_meta._add_view_metadata(view_meta) + yield keyspace_meta + + def get_table(self, keyspaces, keyspace, table): + cl = ConsistencyLevel.ONE + where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) + cf_query = QueryMessage(query=self._SELECT_TABLES + where_clause, consistency_level=cl) + col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) + indexes_query = QueryMessage(query=self._SELECT_INDEXES + where_clause, consistency_level=cl) + triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) + + # in protocol v4 we don't know if this event is a view or a table, so we look for both + where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) + view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause, + consistency_level=cl) + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), (triggers_success, triggers_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, triggers_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) + table_result = self._handle_results(cf_success, cf_result) + col_result = self._handle_results(col_success, col_result) + if table_result: + indexes_result = self._handle_results(indexes_sucess, indexes_result) + triggers_result = self._handle_results(triggers_success, triggers_result) + return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) + + view_result = self._handle_results(view_success, view_result) + if view_result: + return self._build_view_metadata(view_result[0], col_result) + + @staticmethod + def _build_keyspace_metadata_internal(row): + name = row["keyspace_name"] + durable_writes = row["durable_writes"] + strategy_options = dict(row["replication"]) + strategy_class = strategy_options.pop("class") + return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + + @staticmethod + def _build_aggregate(aggregate_row): + return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], + aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], + aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type'], + aggregate_row.get('deterministic', False)) + + def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False): + keyspace_name = row["keyspace_name"] + table_name = row[self._table_name_col] + + col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name] + trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] + index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] + + table_meta = self._table_metadata_class(keyspace_name, table_name, virtual=virtual) + try: + table_meta.options = self._build_table_options(row) + flags = row.get('flags', set()) + if flags: + is_dense = 'dense' in flags + compact_static = not is_dense and 'super' not in flags and 'compound' not in flags + table_meta.is_compact_storage = is_dense or 'super' in flags or 'compound' not in flags + elif virtual: + compact_static = False + table_meta.is_compact_storage = False + is_dense = False + else: + compact_static = True + table_meta.is_compact_storage = True + is_dense = False + + self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) + + for trigger_row in trigger_rows: + trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) + table_meta.triggers[trigger_meta.name] = trigger_meta + + for index_row in index_rows: + index_meta = self._build_index_metadata(table_meta, index_row) + if index_meta: + table_meta.indexes[index_meta.name] = index_meta + + table_meta.extensions = row.get('extensions', {}) + except Exception: + table_meta._exc_info = sys.exc_info() + log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) + + return table_meta + + def _build_table_options(self, row): + """ Setup the mostly-non-schema table options, like caching settings """ + return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) + + def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): + # partition key + partition_rows = [r for r in col_rows + if r.get('kind', None) == "partition_key"] + if len(partition_rows) > 1: + partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) + for r in partition_rows: + # we have to add meta here (and not in the later loop) because TableMetadata.columns is an + # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL + column_meta = self._build_column_metadata(meta, r) + meta.columns[column_meta.name] = column_meta + meta.partition_key.append(meta.columns[r.get('column_name')]) + + # clustering key + if not compact_static: + clustering_rows = [r for r in col_rows + if r.get('kind', None) == "clustering"] + if len(clustering_rows) > 1: + clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) + for r in clustering_rows: + column_meta = self._build_column_metadata(meta, r) + meta.columns[column_meta.name] = column_meta + meta.clustering_key.append(meta.columns[r.get('column_name')]) + + for col_row in (r for r in col_rows + if r.get('kind', None) not in ('partition_key', 'clustering_key')): + column_meta = self._build_column_metadata(meta, col_row) + if is_dense and column_meta.cql_type == types.cql_empty_type: + continue + if compact_static and not column_meta.is_static: + # for compact static tables, we omit the clustering key and value, and only add the logical columns. + # They are marked not static so that it generates appropriate CQL + continue + if compact_static: + column_meta.is_static = False + meta.columns[column_meta.name] = column_meta + + def _build_view_metadata(self, row, col_rows=None): + keyspace_name = row["keyspace_name"] + view_name = row["view_name"] + base_table_name = row["base_table_name"] + include_all_columns = row["include_all_columns"] + where_clause = row["where_clause"] + col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name] + view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, + include_all_columns, where_clause, self._build_table_options(row)) + self._build_table_columns(view_meta, col_rows) + view_meta.extensions = row.get('extensions', {}) + + return view_meta + + @staticmethod + def _build_column_metadata(table_metadata, row): + name = row["column_name"] + cql_type = row["type"] + is_static = row.get("kind", None) == "static" + is_reversed = row["clustering_order"].upper() == "DESC" + column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) + return column_meta + + @staticmethod + def _build_index_metadata(table_metadata, row): + index_name = row.get("index_name") + kind = row.get("kind") + if index_name or kind: + index_options = row.get("options") + return IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options) + else: + return None + + @staticmethod + def _build_trigger_metadata(table_metadata, row): + name = row["trigger_name"] + options = row["options"] + trigger_meta = TriggerMetadata(table_metadata, name, options) + return trigger_meta + + def _query_all(self): + cl = ConsistencyLevel.ONE + queries = [ + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl) + ] + + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result)) = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False + ) + + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserV3, self)._aggregate_results() + + m = self.keyspace_table_index_rows + for row in self.indexes_result: + ksname = row["keyspace_name"] + cfname = row[self._table_name_col] + m[ksname][cfname].append(row) + + m = self.keyspace_view_rows + for row in self.views_result: + m[row["keyspace_name"]].append(row) + + @staticmethod + def _schema_type_to_cql(type_string): + return type_string + + +class SchemaParserDSE60(SchemaParserV3): + """ + For DSE 6.0+ + """ + recognized_table_options = (SchemaParserV3.recognized_table_options + + ("nodesync",)) + + +class SchemaParserV4(SchemaParserV3): + + recognized_table_options = ( + 'additional_write_policy', + 'bloom_filter_fp_chance', + 'caching', + 'cdc', + 'comment', + 'compaction', + 'compression', + 'crc_check_chance', + 'default_time_to_live', + 'gc_grace_seconds', + 'max_index_interval', + 'memtable_flush_period_in_ms', + 'min_index_interval', + 'read_repair', + 'speculative_retry') + + _SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces' + _SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables' + _SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns' + + def __init__(self, connection, timeout): + super(SchemaParserV4, self).__init__(connection, timeout) + self.virtual_keyspaces_rows = defaultdict(list) + self.virtual_tables_rows = defaultdict(list) + self.virtual_columns_rows = defaultdict(lambda: defaultdict(list)) + + def _query_all(self): + cl = ConsistencyLevel.ONE + # todo: this duplicates V3; we should find a way for _query_all methods + # to extend each other. + queries = [ + # copied from V3 + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), + # V4-only queries + QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl) + ] + + responses = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False) + ( + # copied from V3 + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + # V4-only responses + (virtual_ks_success, virtual_ks_result), + (virtual_table_success, virtual_table_result), + (virtual_column_success, virtual_column_result) + ) = responses + + # copied from V3 + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + # V4-only results + # These tables don't exist in some DSE versions reporting 4.X so we can + # ignore them if we got an error + self.virtual_keyspaces_result = self._handle_results( + virtual_ks_success, virtual_ks_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_tables_result = self._handle_results( + virtual_table_success, virtual_table_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_columns_result = self._handle_results( + virtual_column_success, virtual_column_result, + expected_failures=(InvalidRequest,) + ) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserV4, self)._aggregate_results() + + m = self.virtual_tables_rows + for row in self.virtual_tables_result: + m[row["keyspace_name"]].append(row) + + m = self.virtual_columns_rows + for row in self.virtual_columns_result: + ks_name = row['keyspace_name'] + tab_name = row[self._table_name_col] + m[ks_name][tab_name].append(row) + + def get_all_keyspaces(self): + for x in super(SchemaParserV4, self).get_all_keyspaces(): + yield x + + for row in self.virtual_keyspaces_result: + ks_name = row['keyspace_name'] + keyspace_meta = self._build_keyspace_metadata(row) + keyspace_meta.virtual = True + + for table_row in self.virtual_tables_rows.get(ks_name, []): + table_name = table_row[self._table_name_col] + + col_rows = self.virtual_columns_rows[ks_name][table_name] + keyspace_meta._add_table_metadata( + self._build_table_metadata(table_row, + col_rows=col_rows, + virtual=True) + ) + yield keyspace_meta + + @staticmethod + def _build_keyspace_metadata_internal(row): + # necessary fields that aren't int virtual ks + row["durable_writes"] = row.get("durable_writes", None) + row["replication"] = row.get("replication", {}) + row["replication"]["class"] = row["replication"].get("class", None) + return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) + + +class SchemaParserDSE67(SchemaParserV4): + """ + For DSE 6.7+ + """ + recognized_table_options = (SchemaParserV4.recognized_table_options + + ("nodesync",)) + + +class SchemaParserDSE68(SchemaParserDSE67): + """ + For DSE 6.8+ + """ + + _SELECT_VERTICES = "SELECT * FROM system_schema.vertices" + _SELECT_EDGES = "SELECT * FROM system_schema.edges" + + _table_metadata_class = TableMetadataDSE68 + + def __init__(self, connection, timeout): + super(SchemaParserDSE68, self).__init__(connection, timeout) + self.keyspace_table_vertex_rows = defaultdict(lambda: defaultdict(list)) + self.keyspace_table_edge_rows = defaultdict(lambda: defaultdict(list)) + + def get_all_keyspaces(self): + for keyspace_meta in super(SchemaParserDSE68, self).get_all_keyspaces(): + self._build_graph_metadata(keyspace_meta) + yield keyspace_meta + + def get_table(self, keyspaces, keyspace, table): + table_meta = super(SchemaParserDSE68, self).get_table(keyspaces, keyspace, table) + cl = ConsistencyLevel.ONE + where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) + vertices_query = QueryMessage(query=self._SELECT_VERTICES + where_clause, consistency_level=cl) + edges_query = QueryMessage(query=self._SELECT_EDGES + where_clause, consistency_level=cl) + + (vertices_success, vertices_result), (edges_success, edges_result) \ + = self.connection.wait_for_responses(vertices_query, edges_query, timeout=self.timeout, fail_on_error=False) + vertices_result = self._handle_results(vertices_success, vertices_result) + edges_result = self._handle_results(edges_success, edges_result) + + try: + if vertices_result: + table_meta.vertex = self._build_table_vertex_metadata(vertices_result[0]) + elif edges_result: + table_meta.edge = self._build_table_edge_metadata(keyspaces[keyspace], edges_result[0]) + except Exception: + table_meta.vertex = None + table_meta.edge = None + table_meta._exc_info = sys.exc_info() + log.exception("Error while parsing graph metadata for table %s.%s.", keyspace, table) + + return table_meta + + @staticmethod + def _build_keyspace_metadata_internal(row): + name = row["keyspace_name"] + durable_writes = row.get("durable_writes", None) + replication = dict(row.get("replication")) if 'replication' in row else {} + replication_class = replication.pop("class") if 'class' in replication else None + graph_engine = row.get("graph_engine", None) + return KeyspaceMetadata(name, durable_writes, replication_class, replication, graph_engine) + + def _build_graph_metadata(self, keyspace_meta): + + def _build_table_graph_metadata(table_meta): + for row in self.keyspace_table_vertex_rows[keyspace_meta.name][table_meta.name]: + table_meta.vertex = self._build_table_vertex_metadata(row) + + for row in self.keyspace_table_edge_rows[keyspace_meta.name][table_meta.name]: + table_meta.edge = self._build_table_edge_metadata(keyspace_meta, row) + + try: + # Make sure we process vertices before edges + for table_meta in [t for t in keyspace_meta.tables.values() + if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + _build_table_graph_metadata(table_meta) + + # all other tables... + for table_meta in [t for t in keyspace_meta.tables.values() + if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + _build_table_graph_metadata(table_meta) + except Exception: + # schema error, remove all graph metadata for this keyspace + for t in keyspace_meta.tables.values(): + t.edge = t.vertex = None + keyspace_meta._exc_info = sys.exc_info() + log.exception("Error while parsing graph metadata for keyspace %s", keyspace_meta.name) + + @staticmethod + def _build_table_vertex_metadata(row): + return VertexMetadata(row.get("keyspace_name"), row.get("table_name"), + row.get("label_name")) + + @staticmethod + def _build_table_edge_metadata(keyspace_meta, row): + from_table = row.get("from_table") + from_table_meta = keyspace_meta.tables.get(from_table) + from_label = from_table_meta.vertex.label_name + to_table = row.get("to_table") + to_table_meta = keyspace_meta.tables.get(to_table) + to_label = to_table_meta.vertex.label_name + + return EdgeMetadata( + row.get("keyspace_name"), row.get("table_name"), + row.get("label_name"), from_table, from_label, + row.get("from_partition_key_columns"), + row.get("from_clustering_columns"), to_table, to_label, + row.get("to_partition_key_columns"), + row.get("to_clustering_columns")) + + def _query_all(self): + cl = ConsistencyLevel.ONE + queries = [ + # copied from v4 + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl), + # dse6.8 only + QueryMessage(query=self._SELECT_VERTICES, consistency_level=cl), + QueryMessage(query=self._SELECT_EDGES, consistency_level=cl) + ] + + responses = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False) + ( + # copied from V4 + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + (virtual_ks_success, virtual_ks_result), + (virtual_table_success, virtual_table_result), + (virtual_column_success, virtual_column_result), + # dse6.8 responses + (vertices_success, vertices_result), + (edges_success, edges_result) + ) = responses + + # copied from V4 + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + + # These tables don't exist in some DSE versions reporting 4.X so we can + # ignore them if we got an error + self.virtual_keyspaces_result = self._handle_results( + virtual_ks_success, virtual_ks_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_tables_result = self._handle_results( + virtual_table_success, virtual_table_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_columns_result = self._handle_results( + virtual_column_success, virtual_column_result, + expected_failures=(InvalidRequest,) + ) + + # dse6.8-only results + self.vertices_result = self._handle_results(vertices_success, vertices_result) + self.edges_result = self._handle_results(edges_success, edges_result) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserDSE68, self)._aggregate_results() + + m = self.keyspace_table_vertex_rows + for row in self.vertices_result: + ksname = row["keyspace_name"] + cfname = row['table_name'] + m[ksname][cfname].append(row) + + m = self.keyspace_table_edge_rows + for row in self.edges_result: + ksname = row["keyspace_name"] + cfname = row['table_name'] + m[ksname][cfname].append(row) + + +class MaterializedViewMetadata(object): + """ + A representation of a materialized view on a table + """ + + keyspace_name = None + """ A string name of the keyspace of this view.""" + + name = None + """ A string name of the view.""" + + base_table_name = None + """ A string name of the base table for this view.""" + + partition_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns in + the partition key for this view. This will always hold at least one + column. + """ + + clustering_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns + in the clustering key for this view. + + Note that a table may have no clustering keys, in which case this will + be an empty list. + """ + + columns = None + """ + A dict mapping column names to :class:`.ColumnMetadata` instances. + """ + + include_all_columns = None + """ A flag indicating whether the view was created AS SELECT * """ + + where_clause = None + """ String WHERE clause for the view select statement. From server metadata """ + + options = None + """ + A dict mapping table option names to their specific settings for this + view. + """ + + extensions = None + """ + Metadata describing configuration for table extensions + """ + + def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): + self.keyspace_name = keyspace_name + self.name = view_name + self.base_table_name = base_table_name + self.partition_key = [] + self.clustering_key = [] + self.columns = OrderedDict() + self.include_all_columns = include_all_columns + self.where_clause = where_clause + self.options = options or {} + + def as_cql_query(self, formatted=False): + """ + Returns a CQL query that can be used to recreate this function. + If `formatted` is set to :const:`True`, extra whitespace will + be added to make the query more readable. + """ + sep = '\n ' if formatted else ' ' + keyspace = protect_name(self.keyspace_name) + name = protect_name(self.name) + + selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values()) + base_table = protect_name(self.base_table_name) + where_clause = self.where_clause + + part_key = ', '.join(protect_name(col.name) for col in self.partition_key) + if len(self.partition_key) > 1: + pk = "((%s)" % part_key + else: + pk = "(%s" % part_key + if self.clustering_key: + pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key) + pk += ")" + + properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) + + ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" + "SELECT %(selected_cols)s%(sep)s" + "FROM %(keyspace)s.%(base_table)s%(sep)s" + "WHERE %(where_clause)s%(sep)s" + "PRIMARY KEY %(pk)s%(sep)s" + "WITH %(properties)s") % locals() + + if self.extensions: + registry = _RegisteredExtensionType._extension_registry + for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + ext = registry[k] + cql = ext.after_table_cql(self, k, self.extensions[k]) + if cql: + ret += "\n\n%s" % (cql,) + return ret + + def export_as_string(self): + return self.as_cql_query(formatted=True) + ";" + + +class VertexMetadata(object): + """ + A representation of a vertex on a table + """ + + keyspace_name = None + """ A string name of the keyspace. """ + + table_name = None + """ A string name of the table this vertex is on. """ + + label_name = None + """ A string name of the label of this vertex.""" + + def __init__(self, keyspace_name, table_name, label_name): + self.keyspace_name = keyspace_name + self.table_name = table_name + self.label_name = label_name + + +class EdgeMetadata(object): + """ + A representation of an edge on a table + """ + + keyspace_name = None + """A string name of the keyspace """ + + table_name = None + """A string name of the table this edge is on""" + + label_name = None + """A string name of the label of this edge""" + + from_table = None + """A string name of the from table of this edge (incoming vertex)""" + + from_label = None + """A string name of the from table label of this edge (incoming vertex)""" + + from_partition_key_columns = None + """The columns that match the partition key of the incoming vertex table.""" + + from_clustering_columns = None + """The columns that match the clustering columns of the incoming vertex table.""" + + to_table = None + """A string name of the to table of this edge (outgoing vertex)""" + + to_label = None + """A string name of the to table label of this edge (outgoing vertex)""" + + to_partition_key_columns = None + """The columns that match the partition key of the outgoing vertex table.""" + + to_clustering_columns = None + """The columns that match the clustering columns of the outgoing vertex table.""" + + def __init__( + self, keyspace_name, table_name, label_name, from_table, + from_label, from_partition_key_columns, from_clustering_columns, + to_table, to_label, to_partition_key_columns, + to_clustering_columns): + self.keyspace_name = keyspace_name + self.table_name = table_name + self.label_name = label_name + self.from_table = from_table + self.from_label = from_label + self.from_partition_key_columns = from_partition_key_columns + self.from_clustering_columns = from_clustering_columns + self.to_table = to_table + self.to_label = to_label + self.to_partition_key_columns = to_partition_key_columns + self.to_clustering_columns = to_clustering_columns + + +def get_schema_parser(connection, server_version, dse_version, timeout): + version = Version(server_version) + if dse_version: + v = Version(dse_version) + if v >= Version('6.8.0'): + return SchemaParserDSE68(connection, timeout) + elif v >= Version('6.7.0'): + return SchemaParserDSE67(connection, timeout) + elif v >= Version('6.0.0'): + return SchemaParserDSE60(connection, timeout) + + if version >= Version('4.0-alpha'): + return SchemaParserV4(connection, timeout) + elif version >= Version('3.0.0'): + return SchemaParserV3(connection, timeout) + else: + # we could further specialize by version. Right now just refactoring the + # multi-version parser we have as of C* 2.2.0rc1. + return SchemaParserV22(connection, timeout) + + +def _cql_from_cass_type(cass_type): + """ + A string representation of the type for this column, such as "varchar" + or "map". + """ + if issubclass(cass_type, types.ReversedType): + return cass_type.subtypes[0].cql_parameterized_type() + else: + return cass_type.cql_parameterized_type() + + +class RLACTableExtension(RegisteredTableExtension): + name = "DSE_RLACA" + + @classmethod + def after_table_cql(cls, table_meta, ext_key, ext_blob): + return "RESTRICT ROWS ON %s.%s USING %s;" % (protect_name(table_meta.keyspace_name), + protect_name(table_meta.name), + protect_name(ext_blob.decode('utf-8'))) +NO_VALID_REPLICA = object() + + +def group_keys_by_replica(session, keyspace, table, keys): + """ + Returns a :class:`dict` with the keys grouped per host. This can be + used to more accurately group by IN clause or to batch the keys per host. + + If a valid replica is not found for a particular key it will be grouped under + :class:`~.NO_VALID_REPLICA` + + Example usage:: + result = group_keys_by_replica( + session, "system", "peers", + (("127.0.0.1", ), ("127.0.0.2", )) + ) + """ + cluster = session.cluster + + partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key + + serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys) + keys_per_host = defaultdict(list) + distance = cluster._default_load_balancing_policy.distance + + for key in keys: + serialized_key = [serializer.serialize(pk, cluster.protocol_version) + for serializer, pk in zip(serializers, key)] + if len(serialized_key) == 1: + routing_key = serialized_key[0] + else: + routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key) + all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) + # First check if there are local replicas + valid_replicas = [host for host in all_replicas if + host.is_up and distance(host) == HostDistance.LOCAL] + if not valid_replicas: + valid_replicas = [host for host in all_replicas if host.is_up] + + if valid_replicas: + keys_per_host[random.choice(valid_replicas)].append(key) + else: + # We will group under this statement all the keys for which + # we haven't found a valid replica + keys_per_host[NO_VALID_REPLICA].append(key) + + return dict(keys_per_host) + + +# TODO next major reorg +class _NodeInfo(object): + """ + Internal utility functions to determine the different host addresses/ports + from a local or peers row. + """ + + @staticmethod + def get_broadcast_rpc_address(row): + # TODO next major, change the parsing logic to avoid any + # overriding of a non-null value + addr = row.get("rpc_address") + if "native_address" in row: + addr = row.get("native_address") + if "native_transport_address" in row: + addr = row.get("native_transport_address") + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + + return addr + + @staticmethod + def get_broadcast_rpc_port(row): + port = row.get("rpc_port") + if port is None or port == 0: + port = row.get("native_port") + + return port if port and port > 0 else None + + @staticmethod + def get_broadcast_address(row): + addr = row.get("broadcast_address") + if addr is None: + addr = row.get("peer") + + return addr + + @staticmethod + def get_broadcast_port(row): + port = row.get("broadcast_port") + if port is None or port == 0: + port = row.get("peer_port") + + return port if port and port > 0 else None diff --git a/cassandra/metrics.py b/cassandra/metrics.py index c5c5380127..a1eadc1fc4 100644 --- a/cassandra/metrics.py +++ b/cassandra/metrics.py @@ -1,28 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from itertools import chain import logging -from greplin import scales +try: + from greplin import scales +except ImportError: + raise ImportError( + "The scales library is required for metrics support: " + "https://pypi.org/project/scales/") log = logging.getLogger(__name__) + class Metrics(object): + """ + A collection of timers and counters for various performance metrics. + + Timer metrics are represented as floating point seconds. + """ request_timer = None """ A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like object with the following keys: - * count - number of requests that have been timed - * min - min latency - * max - max latency - * mean - mean latency - * stdev - standard deviation for latencies - * median - median latency - * 75percentile - 75th percentile latencies - * 97percentile - 97th percentile latencies - * 98percentile - 98th percentile latencies - * 99percentile - 99th percentile latencies - * 999percentile - 99.9th percentile latencies + * count - number of requests that have been timed + * min - min latency + * max - max latency + * mean - mean latency + * stddev - standard deviation for latencies + * median - median latency + * 75percentile - 75th percentile latencies + * 95percentile - 95th percentile latencies + * 98percentile - 98th percentile latencies + * 99percentile - 99th percentile latencies + * 999percentile - 99.9th percentile latencies """ connection_errors = None @@ -88,10 +115,14 @@ class Metrics(object): the driver currently has open. """ + _stats_counter = 0 + def __init__(self, cluster_proxy): log.debug("Starting metric capture") - self.stats = scales.collection('/cassandra', + self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter)) + Metrics._stats_counter += 1 + self.stats = scales.collection(self.stats_name, scales.PmfStat('request_timer'), scales.IntStat('connection_errors'), scales.IntStat('write_timeouts'), @@ -109,6 +140,11 @@ def __init__(self, cluster_proxy): scales.Stat('open_connections', lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions))) + # TODO, to be removed in 4.0 + # /cassandra contains the metrics of the first cluster registered + if 'cassandra' not in scales._Stats.stats: + scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name] + self.request_timer = self.stats.request_timer self.connection_errors = self.stats.connection_errors self.write_timeouts = self.stats.write_timeouts @@ -141,3 +177,27 @@ def on_ignore(self): def on_retry(self): self.stats.retries += 1 + + def get_stats(self): + """ + Returns the metrics for the registered cluster instance. + """ + return scales.getStats()[self.stats_name] + + def set_stats_name(self, stats_name): + """ + Set the metrics stats name. + The stats_name is a string used to access the metrics through scales: scales.getStats()[] + Default is 'cassandra-'. + """ + + if self.stats_name == stats_name: + return + + if stats_name in scales._Stats.stats: + raise ValueError('"{0}" already exists in stats.'.format(stats_name)) + + stats = scales._Stats.stats[self.stats_name] + del scales._Stats.stats[self.stats_name] + self.stats_name = stats_name + scales._Stats.stats[self.stats_name] = stats diff --git a/cassandra/murmur3.c b/cassandra/murmur3.c deleted file mode 100644 index 62b89d78cf..0000000000 --- a/cassandra/murmur3.c +++ /dev/null @@ -1,207 +0,0 @@ -/* - * The majority of this code was taken from the python-smhasher library, - * which can be found here: https://github.com/phensley/python-smhasher - * - * That library is under the MIT license with the following copyright: - * - * Copyright (c) 2011 Austin Appleby (Murmur3 routine) - * Copyright (c) 2011 Patrick Hensley (Python wrapper, packaging) - * - */ - -#include -#include - -#if PY_VERSION_HEX < 0x02050000 -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -#endif - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) - -typedef unsigned char uint8_t; -typedef unsigned long uint32_t; -typedef unsigned __int64 uint64_t; - -#define FORCE_INLINE __forceinline - -#include - -#define ROTL32(x,y) _rotl(x,y) -#define ROTL64(x,y) _rotl64(x,y) - -#define BIG_CONSTANT(x) (x) - -// Other compilers - -#else // defined(_MSC_VER) - -#include - -#define FORCE_INLINE inline __attribute__((always_inline)) - -inline uint32_t rotl32 ( uint32_t x, int8_t r ) -{ - return (x << r) | (x >> (32 - r)); -} - -inline uint64_t rotl64 ( uint64_t x, int8_t r ) -{ - return (x << r) | (x >> (64 - r)); -} - -#define ROTL32(x,y) rotl32(x,y) -#define ROTL64(x,y) rotl64(x,y) - -#define BIG_CONSTANT(x) (x##LLU) - -#endif // !defined(_MSC_VER) - -//----------------------------------------------------------------------------- -// Block read - if your platform needs to do endian-swapping or can only -// handle aligned reads, do the conversion here - -// TODO 32bit? - -FORCE_INLINE uint64_t getblock ( const uint64_t * p, int i ) -{ - return p[i]; -} - -//----------------------------------------------------------------------------- -// Finalization mix - force all bits of a hash block to avalanche - -FORCE_INLINE uint64_t fmix ( uint64_t k ) -{ - k ^= k >> 33; - k *= BIG_CONSTANT(0xff51afd7ed558ccd); - k ^= k >> 33; - k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); - k ^= k >> 33; - - return k; -} - -uint64_t MurmurHash3_x64_128 (const void * key, const int len, - const uint32_t seed) -{ - const uint8_t * data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint64_t h1 = seed; - uint64_t h2 = seed; - - uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); - uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); - - //---------- - // body - - const uint64_t * blocks = (const uint64_t *)(data); - - int i; - for(i = 0; i < nblocks; i++) - { - uint64_t k1 = getblock(blocks,i*2+0); - uint64_t k2 = getblock(blocks,i*2+1); - - k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; - - h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; - - k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; - - h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; - - } - - //---------- - // tail - - const uint8_t * tail = (const uint8_t*)(data + nblocks*16); - - uint64_t k1 = 0; - uint64_t k2 = 0; - - switch(len & 15) - { - case 15: k2 ^= (uint64_t)(tail[14]) << 48; - case 14: k2 ^= (uint64_t)(tail[13]) << 40; - case 13: k2 ^= (uint64_t)(tail[12]) << 32; - case 12: k2 ^= (uint64_t)(tail[11]) << 24; - case 11: k2 ^= (uint64_t)(tail[10]) << 16; - case 10: k2 ^= (uint64_t)(tail[ 9]) << 8; - case 9: k2 ^= (uint64_t)(tail[ 8]) << 0; - k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; - - case 8: k1 ^= (uint64_t)(tail[ 7]) << 56; - case 7: k1 ^= (uint64_t)(tail[ 6]) << 48; - case 6: k1 ^= (uint64_t)(tail[ 5]) << 40; - case 5: k1 ^= (uint64_t)(tail[ 4]) << 32; - case 4: k1 ^= (uint64_t)(tail[ 3]) << 24; - case 3: k1 ^= (uint64_t)(tail[ 2]) << 16; - case 2: k1 ^= (uint64_t)(tail[ 1]) << 8; - case 1: k1 ^= (uint64_t)(tail[ 0]) << 0; - k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; h2 ^= len; - - h1 += h2; - h2 += h1; - - h1 = fmix(h1); - h2 = fmix(h2); - - h1 += h2; - h2 += h1; - - return h1; -} - -static PyObject * -murmur3(PyObject *self, PyObject *args) -{ - const char *key; - Py_ssize_t len; - uint32_t seed = 0; - - if (!PyArg_ParseTuple(args, "s#|I", &key, &len, &seed)) { - return NULL; - } - - // TODO handle x86 version? - uint64_t result = MurmurHash3_x64_128((void *)key, len, seed); - return (PyObject *) PyLong_FromLong((long int)result); -} - -static PyMethodDef murmur3_methods[] = { - {"murmur3", murmur3, METH_VARARGS, - "Make an x64 murmur3 64-bit hash value"}, - - {NULL, NULL, 0, NULL} -}; - -#if PY_MAJOR_VERSION <= 2 - -PyMODINIT_FUNC -initmurmur3(void) -{ - (void) Py_InitModule("murmur3", murmur3_methods); -} - -#else - -/* Python 3.x */ -// TODO - -#endif diff --git a/cassandra/murmur3.py b/cassandra/murmur3.py new file mode 100644 index 0000000000..282c43578d --- /dev/null +++ b/cassandra/murmur3.py @@ -0,0 +1,114 @@ +import struct + + +def body_and_tail(data): + l = len(data) + nblocks = l // 16 + tail = l % 16 + if nblocks: + # we use '<', specifying little-endian byte order for data bigger than + # a byte so behavior is the same on little- and big-endian platforms + return struct.unpack_from('<' + ('qq' * nblocks), data), struct.unpack_from('b' * tail, data, -tail), l + else: + return tuple(), struct.unpack_from('b' * tail, data, -tail), l + + +def rotl64(x, r): + # note: not a general-purpose function because it leaves the high-order bits intact + # suitable for this use case without wasting cycles + mask = 2 ** r - 1 + rotated = (x << r) | ((x >> 64 - r) & mask) + return rotated + + +def fmix(k): + # masking off the 31s bits that would be leftover after >> 33 a 64-bit number + k ^= (k >> 33) & 0x7fffffff + k *= 0xff51afd7ed558ccd + k ^= (k >> 33) & 0x7fffffff + k *= 0xc4ceb9fe1a85ec53 + k ^= (k >> 33) & 0x7fffffff + return k + + +INT64_MAX = int(2 ** 63 - 1) +INT64_MIN = -INT64_MAX - 1 +INT64_OVF_OFFSET = INT64_MAX + 1 +INT64_OVF_DIV = 2 * INT64_OVF_OFFSET + + +def truncate_int64(x): + if not INT64_MIN <= x <= INT64_MAX: + x = (x + INT64_OVF_OFFSET) % INT64_OVF_DIV - INT64_OVF_OFFSET + return x + + +def _murmur3(data): + + h1 = h2 = 0 + + c1 = -8663945395140668459 # 0x87c37b91114253d5 + c2 = 0x4cf5ad432745937f + + body, tail, total_len = body_and_tail(data) + + # body + for i in range(0, len(body), 2): + k1 = body[i] + k2 = body[i + 1] + + k1 *= c1 + k1 = rotl64(k1, 31) + k1 *= c2 + h1 ^= k1 + + h1 = rotl64(h1, 27) + h1 += h2 + h1 = h1 * 5 + 0x52dce729 + + k2 *= c2 + k2 = rotl64(k2, 33) + k2 *= c1 + h2 ^= k2 + + h2 = rotl64(h2, 31) + h2 += h1 + h2 = h2 * 5 + 0x38495ab5 + + # tail + k1 = k2 = 0 + len_tail = len(tail) + if len_tail > 8: + for i in range(len_tail - 1, 7, -1): + k2 ^= tail[i] << (i - 8) * 8 + k2 *= c2 + k2 = rotl64(k2, 33) + k2 *= c1 + h2 ^= k2 + + if len_tail: + for i in range(min(7, len_tail - 1), -1, -1): + k1 ^= tail[i] << i * 8 + k1 *= c1 + k1 = rotl64(k1, 31) + k1 *= c2 + h1 ^= k1 + + # finalization + h1 ^= total_len + h2 ^= total_len + + h1 += h2 + h2 += h1 + + h1 = fmix(h1) + h2 = fmix(h2) + + h1 += h2 + + return truncate_int64(h1) + +try: + from cassandra.cmurmur3 import murmur3 +except ImportError: + murmur3 = _murmur3 diff --git a/cassandra/numpyFlags.h b/cassandra/numpyFlags.h new file mode 100644 index 0000000000..db464ed9ac --- /dev/null +++ b/cassandra/numpyFlags.h @@ -0,0 +1 @@ +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx new file mode 100644 index 0000000000..2377258b36 --- /dev/null +++ b/cassandra/numpy_parser.pyx @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides an optional protocol parser that returns +NumPy arrays. + +============================================================================= +This module should not be imported by any of the main python-driver modules, +as numpy is an optional dependency. +============================================================================= +""" + +include "ioutils.pyx" + +cimport cython +from libc.stdint cimport uint64_t, uint8_t +from cpython.ref cimport Py_INCREF, PyObject + +from cassandra.bytesio cimport BytesIOReader +from cassandra.deserializers cimport Deserializer, from_binary +from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser +from cassandra import cqltypes +from cassandra.util import is_little_endian + +import numpy as np + +cdef extern from "numpyFlags.h": + # Include 'numpyFlags.h' into the generated C code to disable the + # deprecated NumPy API + pass + +cdef extern from "Python.h": + # An integer type large enough to hold a pointer + ctypedef uint64_t Py_uintptr_t + + +# Simple array descriptor, useful to parse rows into a NumPy array +ctypedef struct ArrDesc: + Py_uintptr_t buf_ptr + int stride # should be large enough as we allocate contiguous arrays + int is_object + Py_uintptr_t mask_ptr + +arrDescDtype = np.dtype( + [ ('buf_ptr', np.uintp) + , ('stride', np.dtype('i')) + , ('is_object', np.dtype('i')) + , ('mask_ptr', np.uintp) + ], align=True) + +_cqltype_to_numpy = { + cqltypes.LongType: np.dtype('>i8'), + cqltypes.CounterColumnType: np.dtype('>i8'), + cqltypes.Int32Type: np.dtype('>i4'), + cqltypes.ShortType: np.dtype('>i2'), + cqltypes.FloatType: np.dtype('>f4'), + cqltypes.DoubleType: np.dtype('>f8'), +} + +obj_dtype = np.dtype('O') + +cdef uint8_t mask_true = 0x01 + +cdef class NumpyParser(ColumnParser): + """Decode a ResultMessage into a bunch of NumPy arrays""" + + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): + cdef Py_ssize_t rowcount + cdef ArrDesc[::1] array_descs + cdef ArrDesc *arrs + + rowcount = read_int(reader) + array_descs, arrays = make_arrays(desc, rowcount) + arrs = &array_descs[0] + + _parse_rows(reader, desc, arrs, rowcount) + + arrays = [make_native_byteorder(arr) for arr in arrays] + result = dict(zip(desc.colnames, arrays)) + return result + + +cdef _parse_rows(BytesIOReader reader, ParseDesc desc, + ArrDesc *arrs, Py_ssize_t rowcount): + cdef Py_ssize_t i + + for i in range(rowcount): + unpack_row(reader, desc, arrs) + + +### Helper functions to create NumPy arrays and array descriptors + +def make_arrays(ParseDesc desc, array_size): + """ + Allocate arrays for each result column. + + returns a tuple of (array_descs, arrays), where + 'array_descs' describe the arrays for NativeRowParser and + 'arrays' is a dict mapping column names to arrays + (e.g. this can be fed into pandas.DataFrame) + """ + array_descs = np.empty((desc.rowsize,), arrDescDtype) + arrays = [] + + for i, coltype in enumerate(desc.coltypes): + arr = make_array(coltype, array_size) + array_descs[i]['buf_ptr'] = arr.ctypes.data + array_descs[i]['stride'] = arr.strides[0] + array_descs[i]['is_object'] = arr.dtype is obj_dtype + try: + array_descs[i]['mask_ptr'] = arr.mask.ctypes.data + except AttributeError: + array_descs[i]['mask_ptr'] = 0 + arrays.append(arr) + + return array_descs, arrays + + +def make_array(coltype, array_size): + """ + Allocate a new NumPy array of the given column type and size. + """ + try: + a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype]) + a.mask = np.zeros((array_size,), dtype=bool) + except KeyError: + a = np.empty((array_size,), dtype=obj_dtype) + return a + + +#### Parse rows into NumPy arrays + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef inline int unpack_row( + BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1: + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef ArrDesc arr + cdef Deserializer deserializer + for i in range(rowsize): + get_buf(reader, &buf) + arr = arrays[i] + + if arr.is_object: + deserializer = desc.deserializers[i] + val = from_binary(deserializer, &buf, desc.protocol_version) + Py_INCREF(val) + ( arr.buf_ptr)[0] = val + elif buf.size >= 0: + memcpy( arr.buf_ptr, buf.ptr, buf.size) + else: + memcpy(arr.mask_ptr, &mask_true, 1) + + # Update the pointer into the array for the next time + arrays[i].buf_ptr += arr.stride + arrays[i].mask_ptr += 1 + + return 0 + + +def make_native_byteorder(arr): + """ + Make sure all values have a native endian in the NumPy arrays. + """ + if is_little_endian and not arr.dtype.kind == 'O': + # We have arrays in big-endian order. First swap the bytes + # into little endian order, and then update the numpy dtype + # accordingly (e.g. from '>i8' to '= 0 + + cdef Buffer buf + cdef Buffer newbuf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef Deserializer deserializer + cdef tuple res = tuple_new(desc.rowsize) + + ce_policy = desc.column_encryption_policy + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + coldesc = desc.coldescs[i] + uses_ce = ce_policy and ce_policy.contains_column(coldesc) + try: + if uses_ce: + col_type = ce_policy.column_type(coldesc) + decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) + PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size) + deserializer = find_deserializer(ce_policy.column_type(coldesc)) + val = from_binary(deserializer, &newbuf, desc.protocol_version) + else: + val = from_binary(deserializer, &buf, desc.protocol_version) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], + desc.coltypes[i].cql_parameterized_type(), + str(e))) + # Insert new object into tuple + tuple_set(res, i, val) + + return res diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd new file mode 100644 index 0000000000..1b3ed3dcbf --- /dev/null +++ b/cassandra/parsing.pxd @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.bytesio cimport BytesIOReader +from cassandra.deserializers cimport Deserializer + +cdef class ParseDesc: + cdef public object colnames + cdef public object coltypes + cdef public object column_encryption_policy + cdef public list coldescs + cdef Deserializer[::1] deserializers + cdef public int protocol_version + cdef Py_ssize_t rowsize + +cdef class ColumnParser: + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc) + +cdef class RowParser: + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc) + diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx new file mode 100644 index 0000000000..085544a362 --- /dev/null +++ b/cassandra/parsing.pyx @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module containing the definitions and declarations (parsing.pxd) for parsers. +""" + +cdef class ParseDesc: + """Description of what structure to parse""" + + def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version): + self.colnames = colnames + self.coltypes = coltypes + self.column_encryption_policy = column_encryption_policy + self.coldescs = coldescs + self.deserializers = deserializers + self.protocol_version = protocol_version + self.rowsize = len(colnames) + + +cdef class ColumnParser: + """Decode a ResultMessage into a set of columns""" + + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): + raise NotImplementedError + + +cdef class RowParser: + """Parser for a single row""" + + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + """ + Unpack a single row of data in a ResultMessage. + """ + raise NotImplementedError diff --git a/cassandra/policies.py b/cassandra/policies.py index 6b7bf5837a..d6f7063e7a 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1,11 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from functools import lru_cache from itertools import islice, cycle, groupby, repeat import logging -from random import randint - -from cassandra import ConsistencyLevel +from random import randint, shuffle +from threading import Lock +import socket +import warnings log = logging.getLogger(__name__) +from cassandra import WriteType as WT + +# This is done this way because WriteType was originally +# defined here and in order not to break the API. +# It may be removed in the next major. +WriteType = WT + +from cassandra import ConsistencyLevel, OperationTimedOut + class HostDistance(object): """ A measure of how "distant" a node is from the client, which @@ -42,7 +70,29 @@ class HostDistance(object): """ -class LoadBalancingPolicy(object): +class HostStateListener(object): + + def on_up(self, host): + """ Called when a node is marked up. """ + raise NotImplementedError() + + def on_down(self, host): + """ Called when a node is marked down. """ + raise NotImplementedError() + + def on_add(self, host): + """ + Called when a node is added to the cluster. The newly added node + should be considered up. + """ + raise NotImplementedError() + + def on_remove(self, host): + """ Called when a node is removed from the cluster. """ + raise NotImplementedError() + + +class LoadBalancingPolicy(HostStateListener): """ Load balancing policies are used to decide how to distribute requests among all possible coordinator nodes in the cluster. @@ -55,6 +105,11 @@ class LoadBalancingPolicy(object): custom behavior. """ + _hosts_lock = None + + def __init__(self): + self._hosts_lock = Lock() + def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in @@ -73,7 +128,7 @@ def populate(self, cluster, hosts): def make_query_plan(self, working_keyspace=None, query=None): """ - Given a :class:`~.query.Statement` instance, return a iterable + Given a :class:`~.query.Statement` instance, return an iterable of :class:`.Host` instances which should be queried in that order. A generator may work well for custom implementations of this method. @@ -87,35 +142,14 @@ def make_query_plan(self, working_keyspace=None, query=None): """ raise NotImplementedError() - def on_up(self, host): - """ - Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor` - marks the node up. - """ - raise NotImplementedError() - - def on_down(self, host): - """ - Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor` - marks the node down. - """ - raise NotImplementedError() - - def on_add(self, host): - """ - Called when a :class:`.Cluster` instance is first created and - the initial contact points are added as well as when a new - :class:`~.pool.Host` is discovered in the cluster, which may - happen the first time the ring topology is examined or when - a new node joins the cluster. - """ - raise NotImplementedError() - - def on_remove(self, host): + def check_supported(self): """ - Called when a :class:`~.pool.Host` leaves the cluster. + This will be called after the cluster Metadata has been initialized. + If the load balancing policy implementation cannot be supported for + some reason (such as a missing C extension), this is the point at + which it should raise an exception. """ - raise NotImplementedError() + pass class RoundRobinPolicy(LoadBalancingPolicy): @@ -123,15 +157,13 @@ class RoundRobinPolicy(LoadBalancingPolicy): A subclass of :class:`.LoadBalancingPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in. - - This load balancing policy is used by default. """ + _live_hosts = frozenset(()) + _position = 0 def populate(self, cluster, hosts): - self._live_hosts = set(hosts) - if len(hosts) <= 1: - self._position = 0 - else: + self._live_hosts = frozenset(hosts) + if len(hosts) > 1: self._position = randint(0, len(hosts) - 1) def distance(self, host): @@ -143,24 +175,29 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - length = len(self._live_hosts) + hosts = self._live_hosts + length = len(hosts) if length: pos %= length - return list(islice(cycle(self._live_hosts), pos, pos + length)) + return islice(cycle(hosts), pos, pos + length) else: return [] def on_up(self, host): - self._live_hosts.add(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) def on_down(self, host): - self._live_hosts.discard(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) def on_add(self, host): - self._live_hosts.add(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) def on_remove(self, host): - self._live_hosts.remove(host) + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) class DCAwareRoundRobinPolicy(LoadBalancingPolicy): @@ -173,11 +210,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): local_dc = None used_hosts_per_remote_dc = 0 - def __init__(self, local_dc, used_hosts_per_remote_dc=0): + def __init__(self, local_dc='', used_hosts_per_remote_dc=0): """ The `local_dc` parameter should be the name of the datacenter (such as is reported by ``nodetool ring``) that should - be considered local. + be considered local. If not specified, the driver will choose + a local_dc based on the first host among :attr:`.Cluster.contact_points` + having a valid DC. If relying on this mechanism, all specified + contact points should be nodes in a single, local DC. `used_hosts_per_remote_dc` controls how many nodes in each remote datacenter will have connections opened @@ -189,22 +229,23 @@ def __init__(self, local_dc, used_hosts_per_remote_dc=0): self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + self._position = 0 + self._endpoints = [] + LoadBalancingPolicy.__init__(self) def _dc(self, host): return host.datacenter or self.local_dc def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): - self._dc_live_hosts[dc] = set(dc_hosts) + self._dc_live_hosts[dc] = tuple(set(dc_hosts)) - # position is currently only used for local hosts - local_live = self._dc_live_hosts.get(self.local_dc) - if not local_live: - self._position = 0 - elif len(local_live) == 1: - self._position = 0 - else: - self._position = randint(0, len(local_live) - 1) + if not self.local_dc: + self._endpoints = [ + endpoint + for endpoint in cluster.endpoints_resolved] + + self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): dc = self._dc(host) @@ -229,29 +270,52 @@ def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 - local_live = list(self._dc_live_hosts.get(self.local_dc, ())) + local_live = self._dc_live_hosts.get(self.local_dc, ()) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host - for dc, current_dc_hosts in self._dc_live_hosts.iteritems(): - if dc == self.local_dc: - continue - - for host in list(current_dc_hosts)[:self.used_hosts_per_remote_dc]: + # the dict can change, so get candidate DCs iterating over keys of a copy + other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] + for dc in other_dcs: + remote_live = self._dc_live_hosts.get(dc, ()) + for host in remote_live[:self.used_hosts_per_remote_dc]: yield host def on_up(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) + # not worrying about threads because this will happen during + # control connection startup/refresh + if not self.local_dc and host.datacenter: + if host.endpoint in self._endpoints: + self.local_dc = host.datacenter + log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " + "if incorrect, please specify a local_dc to the constructor, " + "or limit contact points to local cluster nodes" % + (self.local_dc, host.endpoint)) + del self._endpoints + + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.get(dc, ()) + if host not in current_hosts: + self._dc_live_hosts[dc] = current_hosts + (host, ) def on_down(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.get(dc, ()) + if host in current_hosts: + hosts = tuple(h for h in current_hosts if h != host) + if hosts: + self._dc_live_hosts[dc] = hosts + else: + del self._dc_live_hosts[dc] def on_add(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) + self.on_up(host) def on_remove(self, host): - self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) + self.on_down(host) class TokenAwarePolicy(LoadBalancingPolicy): @@ -262,8 +326,10 @@ class TokenAwarePolicy(LoadBalancingPolicy): This alters the child policy's behavior so that it first attempts to send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined by the child policy) based on the :class:`.Statement`'s - :attr:`~.Statement.routing_key`. Once those hosts are exhausted, the - remaining hosts in the child policy's query plan will be used. + :attr:`~.Statement.routing_key`. If :attr:`.shuffle_replicas` is + truthy, these replicas will be yielded in a random order. Once those + hosts are exhausted, the remaining hosts in the child policy's query + plan will be used in the order provided by the child policy. If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. @@ -271,16 +337,30 @@ class TokenAwarePolicy(LoadBalancingPolicy): _child_policy = None _cluster_metadata = None + shuffle_replicas = False + """ + Yield local replicas in a random order. + """ - def __init__(self, child_policy): - self.child_policy = child_policy + def __init__(self, child_policy, shuffle_replicas=False): + self._child_policy = child_policy + self.shuffle_replicas = shuffle_replicas def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata - self.child_policy.populate(cluster, hosts) + self._child_policy.populate(cluster, hosts) + + def check_supported(self): + if not self._cluster_metadata.can_support_partitioner(): + raise RuntimeError( + '%s cannot be used with the cluster partitioner (%s) because ' + 'the relevant C extension for this driver was not compiled. ' + 'See the installation instructions for details on building ' + 'and installing the C extensions.' % + (self.__class__.__name__, self._cluster_metadata.partitioner)) def distance(self, *args, **kwargs): - return self.child_policy.distance(*args, **kwargs) + return self._child_policy.distance(*args, **kwargs) def make_query_plan(self, working_keyspace=None, query=None): if query and query.keyspace: @@ -288,19 +368,21 @@ def make_query_plan(self, working_keyspace=None, query=None): else: keyspace = working_keyspace - child = self.child_policy + child = self._child_policy if query is None: for host in child.make_query_plan(keyspace, query): yield host else: routing_key = query.routing_key - if routing_key is None: + if routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): yield host else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + if self.shuffle_replicas: + shuffle(replicas) for replica in replicas: - if replica.monitor.is_up and \ + if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: yield replica @@ -311,16 +393,175 @@ def make_query_plan(self, working_keyspace=None, query=None): yield host def on_up(self, *args, **kwargs): - return self.child_policy.on_up(*args, **kwargs) + return self._child_policy.on_up(*args, **kwargs) def on_down(self, *args, **kwargs): - return self.child_policy.on_down(*args, **kwargs) + return self._child_policy.on_down(*args, **kwargs) def on_add(self, *args, **kwargs): - return self.child_policy.on_add(*args, **kwargs) + return self._child_policy.on_add(*args, **kwargs) def on_remove(self, *args, **kwargs): - return self.child_policy.on_remove(*args, **kwargs) + return self._child_policy.on_remove(*args, **kwargs) + + +class WhiteListRoundRobinPolicy(RoundRobinPolicy): + """ + A subclass of :class:`.RoundRobinPolicy` which evenly + distributes queries across all nodes in the cluster, + regardless of what datacenter the nodes may be in, but + only if that node exists in the list of allowed nodes + + This policy is addresses the issue described in + https://datastax-oss.atlassian.net/browse/JAVA-145 + Where connection errors occur when connection + attempts are made to private IP addresses remotely + """ + + def __init__(self, hosts): + """ + The `hosts` parameter should be a sequence of hosts to permit + connections to. + """ + self._allowed_hosts = tuple(hosts) + self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts + for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)] + + RoundRobinPolicy.__init__(self) + + def populate(self, cluster, hosts): + self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts_resolved) + + if len(hosts) <= 1: + self._position = 0 + else: + self._position = randint(0, len(hosts) - 1) + + def distance(self, host): + if host.address in self._allowed_hosts_resolved: + return HostDistance.LOCAL + else: + return HostDistance.IGNORED + + def on_up(self, host): + if host.address in self._allowed_hosts_resolved: + RoundRobinPolicy.on_up(self, host) + + def on_add(self, host): + if host.address in self._allowed_hosts_resolved: + RoundRobinPolicy.on_add(self, host) + + +class HostFilterPolicy(LoadBalancingPolicy): + """ + A :class:`.LoadBalancingPolicy` subclass configured with a child policy, + and a single-argument predicate. This policy defers to the child policy for + hosts where ``predicate(host)`` is truthy. Hosts for which + ``predicate(host)`` is falsy will be considered :attr:`.IGNORED`, and will + not be used in a query plan. + + This can be used in the cases where you need a whitelist or blacklist + policy, e.g. to prepare for decommissioning nodes or for testing: + + .. code-block:: python + + def address_is_ignored(host): + return host.address in [ignored_address0, ignored_address1] + + blacklist_filter_policy = HostFilterPolicy( + child_policy=RoundRobinPolicy(), + predicate=address_is_ignored + ) + + cluster = Cluster( + primary_host, + load_balancing_policy=blacklist_filter_policy, + ) + + See the note in the :meth:`.make_query_plan` documentation for a caveat on + how wrapping ordering polices (e.g. :class:`.RoundRobinPolicy`) may break + desirable properties of the wrapped policy. + + Please note that whitelist and blacklist policies are not recommended for + general, day-to-day use. You probably want something like + :class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has + fallbacks, over a brute-force method like whitelisting or blacklisting. + """ + + def __init__(self, child_policy, predicate): + """ + :param child_policy: an instantiated :class:`.LoadBalancingPolicy` + that this one will defer to. + :param predicate: a one-parameter function that takes a :class:`.Host`. + If it returns a falsy value, the :class:`.Host` will + be :attr:`.IGNORED` and not returned in query plans. + """ + super(HostFilterPolicy, self).__init__() + self._child_policy = child_policy + self._predicate = predicate + + def on_up(self, host, *args, **kwargs): + return self._child_policy.on_up(host, *args, **kwargs) + + def on_down(self, host, *args, **kwargs): + return self._child_policy.on_down(host, *args, **kwargs) + + def on_add(self, host, *args, **kwargs): + return self._child_policy.on_add(host, *args, **kwargs) + + def on_remove(self, host, *args, **kwargs): + return self._child_policy.on_remove(host, *args, **kwargs) + + @property + def predicate(self): + """ + A predicate, set on object initialization, that takes a :class:`.Host` + and returns a value. If the value is falsy, the :class:`.Host` is + :class:`~HostDistance.IGNORED`. If the value is truthy, + :class:`.HostFilterPolicy` defers to the child policy to determine the + host's distance. + + This is a read-only value set in ``__init__``, implemented as a + ``property``. + """ + return self._predicate + + def distance(self, host): + """ + Checks if ``predicate(host)``, then returns + :attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy + otherwise. + """ + if self.predicate(host): + return self._child_policy.distance(host) + else: + return HostDistance.IGNORED + + def populate(self, cluster, hosts): + self._child_policy.populate(cluster=cluster, hosts=hosts) + + def make_query_plan(self, working_keyspace=None, query=None): + """ + Defers to the child policy's + :meth:`.LoadBalancingPolicy.make_query_plan` and filters the results. + + Note that this filtering may break desirable properties of the wrapped + policy in some cases. For instance, imagine if you configure this + policy to filter out ``host2``, and to wrap a round-robin policy that + rotates through three hosts in the order ``host1, host2, host3``, + ``host2, host3, host1``, ``host3, host1, host2``, repeating. This + policy will yield ``host1, host3``, ``host3, host1``, ``host3, host1``, + disproportionately favoring ``host3``. + """ + child_qp = self._child_policy.make_query_plan( + working_keyspace=working_keyspace, query=query + ) + for host in child_qp: + if self.predicate(host): + yield host + + def check_supported(self): + return self._child_policy.check_supported() class ConvictionPolicy(object): @@ -360,7 +601,7 @@ class SimpleConvictionPolicy(ConvictionPolicy): """ def add_failure(self, connection_exc): - return True + return not isinstance(connection_exc, OperationTimedOut) def reset(self): pass @@ -369,7 +610,7 @@ def reset(self): class ReconnectionPolicy(object): """ This class and its subclasses govern how frequently an attempt is made - to reconnect to nodes that are marked dead. + to reconnect to nodes that are marked as dead. If custom behavior is needed, this class may be subclassed. """ @@ -377,7 +618,7 @@ class ReconnectionPolicy(object): def new_schedule(self): """ This should return a finite or infinite iterable of delays (each as a - floating point number of seconds) inbetween each failed reconnection + floating point number of seconds) in-between each failed reconnection attempt. Note that if the iterable is finite, reconnection attempts will cease once the iterable is exhausted. """ @@ -387,12 +628,12 @@ def new_schedule(self): class ConstantReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay - inbetween each reconnection attempt. + in-between each reconnection attempt. """ def __init__(self, delay, max_attempts=64): """ - `delay` should be a floating point number of seconds to wait inbetween + `delay` should be a floating point number of seconds to wait in-between each attempt. `max_attempts` should be a total number of attempts to be made before @@ -401,27 +642,40 @@ def __init__(self, delay, max_attempts=64): """ if delay < 0: raise ValueError("delay must not be negative") - if max_attempts < 0: + if max_attempts is not None and max_attempts < 0: raise ValueError("max_attempts must not be negative") self.delay = delay self.max_attempts = max_attempts def new_schedule(self): - return repeat(self.delay, self.max_attempts) + if self.max_attempts: + return repeat(self.delay, self.max_attempts) + return repeat(self.delay) class ExponentialReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which exponentially increases - the length of the delay inbetween each reconnection attempt up to + the length of the delay in-between each reconnection attempt up to a set maximum delay. + + A random amount of jitter (+/- 15%) will be added to the pure exponential + delay value to avoid the situations where many reconnection handlers are + trying to reconnect at exactly the same time. """ - def __init__(self, base_delay, max_delay): + # TODO: max_attempts is 64 to preserve legacy default behavior + # consider changing to None in major release to prevent the policy + # giving up forever + def __init__(self, base_delay, max_delay, max_attempts=64): """ `base_delay` and `max_delay` should be in floating point units of seconds. + + `max_attempts` should be a total number of attempts to be made before + giving up, or :const:`None` to continue reconnection attempts forever. + The default is 64. """ if base_delay < 0 or max_delay < 0: raise ValueError("Delays may not be negative") @@ -429,54 +683,43 @@ def __init__(self, base_delay, max_delay): if max_delay < base_delay: raise ValueError("Max delay must be greater than base delay") + if max_attempts is not None and max_attempts < 0: + raise ValueError("max_attempts must not be negative") + self.base_delay = base_delay self.max_delay = max_delay + self.max_attempts = max_attempts def new_schedule(self): - return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64)) - - -class WriteType(object): - """ - For usage with :class:`.RetryPolicy`, this describe a type - of write operation. - """ - - SIMPLE = 0 - """ - A write to a single partition key. Such writes are guaranteed to be atomic - and isolated. - """ - - BATCH = 1 - """ - A write to multiple partition keys that used the distributed batch log to - ensure atomicity. - """ - - UNLOGGED_BATCH = 2 - """ - A write to multiple partition keys that did not use the distributed batch - log. Atomicity for such writes is not guaranteed. - """ + i, overflowed = 0, False + while self.max_attempts is None or i < self.max_attempts: + if overflowed: + yield self.max_delay + else: + try: + yield self._add_jitter(min(self.base_delay * (2 ** i), self.max_delay)) + except OverflowError: + overflowed = True + yield self.max_delay - COUNTER = 3 - """ - A counter write (for one or multiple partition keys). Such writes should - not be replayed in order to avoid overcount. - """ + i += 1 - BATCH_LOG = 4 - """ - The initial write to the distributed batch log that Cassandra performs - internally before a BATCH write. - """ + # Adds -+ 15% to the delay provided + def _add_jitter(self, value): + jitter = randint(85, 115) + delay = (jitter * value) / 100 + return min(max(self.base_delay, delay), self.max_delay) class RetryPolicy(object): """ - A policy that describes whether to retry, rethrow, or ignore timeout - and unavailable failures. + A policy that describes whether to retry, rethrow, or ignore coordinator + timeout and unavailable failures. These are failures reported from the + server side. Timeouts are configured by + `settings in cassandra.yaml `_. + Unavailable failures occur when the coordinator cannot achieve the consistency + level for a request. For further information see the method descriptions + below. To specify a default retry policy, set the :attr:`.Cluster.default_retry_policy` attribute to an instance of this @@ -507,6 +750,12 @@ class or one of its subclasses. should be ignored but no more retries should be attempted. """ + RETRY_NEXT_HOST = 3 + """ + This should be returned from the below methods if the operation + should be retried on another connection. + """ + def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): """ @@ -525,20 +774,20 @@ def on_read_timeout(self, query, consistency, required_responses, how many replicas needed to respond to meet the requested consistency level and how many actually did respond before the coordinator timed out the request. `data_retrieved` is a boolean indicating whether - any of those responses contained data (as opposed to just a checksum). + any of those responses contained data (as opposed to just a digest). `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. By default, operations will be retried at most once, and only if - a sufficient number of replicas responded (with checksums). + a sufficient number of replicas responded (with data digests). """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif received_responses >= required_responses and not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): @@ -562,23 +811,23 @@ def on_write_timeout(self, query, consistency, write_type, `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. - By default, failed write operations will retried at most once, and - they will only be retried if the `write_type` was + By default, a failed write operations will be retried at most once, and + will only be retried if the `write_type` was :attr:`~.WriteType.BATCH_LOG`. """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): """ This is called when the coordinator node determines that a read or write operation cannot be successful because the number of live replicas are too low to meet the requested :class:`.ConsistencyLevel`. - This means that the read or write operation was never forwared to + This means that the read or write operation was never forwarded to any replicas. `query` is the :class:`.Statement` that failed. @@ -594,9 +843,36 @@ def on_unavailable(self, query, consistency, required_replicas, alive_replicas, `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. - By default, no retries will be attempted and the error will be re-raised. + By default, if this is the first retry, it triggers a retry on the next + host in the query plan with the same consistency level. If this is not the + first retry, no retries will be attempted and the error will be re-raised. + """ + return (self.RETRY_NEXT_HOST, None) if retry_num == 0 else (self.RETHROW, None) + + def on_request_error(self, query, consistency, error, retry_num): + """ + This is called when an unexpected error happens. This can be in the + following situations: + + * On a connection error + * On server errors: overloaded, isBootstrapping, serverError, etc. + + `query` is the :class:`.Statement` that timed out. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + `error` the instance of the exception. + + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + By default, it triggers a retry on the next host in the query plan + with the same consistency level. """ - return (self.RETHROW, None) + # TODO revisit this for the next major + # To preserve the same behavior than before, we don't take retry_num into account + return self.RETRY_NEXT_HOST, None class FallthroughRetryPolicy(RetryPolicy): @@ -606,34 +882,39 @@ class FallthroughRetryPolicy(RetryPolicy): """ def on_read_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None + + def on_request_error(self, *args, **kwargs): + return self.RETHROW, None class DowngradingConsistencyRetryPolicy(RetryPolicy): """ + *Deprecated:* This retry policy will be removed in the next major release. + A retry policy that sometimes retries with a lower consistency level than the one initially requested. **BEWARE**: This policy may retry queries using a lower consistency level than the one initially requested. By doing so, it may break consistency guarantees. In other words, if you use this retry policy, - there is cases (documented below) where a read at :attr:`~.QUORUM` + there are cases (documented below) where a read at :attr:`~.QUORUM` *may not* see a preceding write at :attr:`~.QUORUM`. Do not use this policy unless you have understood the cases where this can happen and are ok with that. It is also recommended to subclass this class so that queries that required a consistency level downgrade can be - recorded (so that repairs can be made later, etc). + recorded (so that repairs can be made later, etc.). This policy implements the same retries as :class:`.RetryPolicy`, but on top of that, it also retries in the following cases: - * On a read timeout: if the number of replica that responded is + * On a read timeout: if the number of replicas that responded is greater than one but lower than is required by the requested consistency level, the operation is retried at a lower consistency level. @@ -645,7 +926,7 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): * On an unavailable exception: if at least one replica is alive, the operation is retried at a lower consistency level. - The reasoning being this retry policy is as follows:. If, based + The reasoning behind this retry policy is as follows: if, based on the information the Cassandra coordinator node returns, retrying the operation with the initially requested consistency has a chance to succeed, do it. Otherwise, if based on that information we know the @@ -662,42 +943,302 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): to make sure the data is persisted, and that reading something is better than reading nothing, even if there is a risk of reading stale data. """ + def __init__(self, *args, **kwargs): + super(DowngradingConsistencyRetryPolicy, self).__init__(*args, **kwargs) + warnings.warn('DowngradingConsistencyRetryPolicy is deprecated ' + 'and will be removed in the next major release.', + DeprecationWarning) + def _pick_consistency(self, num_responses): if num_responses >= 3: - return (self.RETRY, ConsistencyLevel.THREE) + return self.RETRY, ConsistencyLevel.THREE elif num_responses >= 2: - return (self.RETRY, ConsistencyLevel.TWO) + return self.RETRY, ConsistencyLevel.TWO elif num_responses >= 1: - return (self.RETRY, ConsistencyLevel.ONE) + return self.RETRY, ConsistencyLevel.ONE else: - return (self.RETHROW, None) + return self.RETHROW, None def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # Downgrading does not make sense for a CAS read query + return self.RETHROW, None elif received_responses < required_responses: return self._pick_consistency(received_responses) elif not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): if retry_num != 0: - return (self.RETHROW, None) - elif write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): - return (self.IGNORE, None) + return self.RETHROW, None + + if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): + if received_responses > 0: + # persisted on at least one replica + return self.IGNORE, None + else: + return self.RETHROW, None elif write_type == WriteType.UNLOGGED_BATCH: return self._pick_consistency(received_responses) elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) - else: - return (self.RETHROW, None) + return self.RETRY, consistency + + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # failed at the paxos phase of a LWT, retry on the next host + return self.RETRY_NEXT_HOST, None else: return self._pick_consistency(alive_replicas) + + +class AddressTranslator(object): + """ + Interface for translating cluster-defined endpoints. + + The driver discovers nodes using server metadata and topology change events. Normally, + the endpoint defined by the server is the right way to connect to a node. In some environments, + these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments, + suboptimal routing, etc.). This interface allows for translating from server defined endpoints to + preferred addresses for driver connections. + + *Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not + translated using this mechanism -- only addresses received from Cassandra nodes are. + """ + def translate(self, addr): + """ + Accepts the node ip address, and returns a translated address to be used connecting to this node. + """ + raise NotImplementedError() + + +class IdentityTranslator(AddressTranslator): + """ + Returns the endpoint with no translation + """ + def translate(self, addr): + return addr + + +class EC2MultiRegionTranslator(AddressTranslator): + """ + Resolves private ips of the hosts in the same datacenter as the client, and public ips of hosts in other datacenters. + """ + def translate(self, addr): + """ + Reverse DNS the public broadcast_address, then lookup that hostname to get the AWS-resolved IP, which + will point to the private IP address within the same datacenter. + """ + # get family of this address, so we translate to the same + family = socket.getaddrinfo(addr, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)[0][0] + host = socket.getfqdn(addr) + for a in socket.getaddrinfo(host, 0, family, socket.SOCK_STREAM): + try: + return a[4][0] + except Exception: + pass + return addr + + +class SpeculativeExecutionPolicy(object): + """ + Interface for specifying speculative execution plans + """ + + def new_plan(self, keyspace, statement): + """ + Returns + + :param keyspace: + :param statement: + :return: + """ + raise NotImplementedError() + + +class SpeculativeExecutionPlan(object): + def next_execution(self, host): + raise NotImplementedError() + + +class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def next_execution(self, host): + return -1 + + +class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + + def new_plan(self, keyspace, statement): + return NoSpeculativeExecutionPlan() + + +class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + """ + A speculative execution policy that sends a new query every X seconds (**delay**) for a maximum of Y attempts (**max_attempts**). + """ + + def __init__(self, delay, max_attempts): + self.delay = delay + self.max_attempts = max_attempts + + class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def __init__(self, delay, max_attempts): + self.delay = delay + self.remaining = max_attempts + + def next_execution(self, host): + if self.remaining > 0: + self.remaining -= 1 + return self.delay + else: + return -1 + + def new_plan(self, keyspace, statement): + return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts) + + +class WrapperPolicy(LoadBalancingPolicy): + + def __init__(self, child_policy): + self._child_policy = child_policy + + def distance(self, *args, **kwargs): + return self._child_policy.distance(*args, **kwargs) + + def populate(self, cluster, hosts): + self._child_policy.populate(cluster, hosts) + + def on_up(self, *args, **kwargs): + return self._child_policy.on_up(*args, **kwargs) + + def on_down(self, *args, **kwargs): + return self._child_policy.on_down(*args, **kwargs) + + def on_add(self, *args, **kwargs): + return self._child_policy.on_add(*args, **kwargs) + + def on_remove(self, *args, **kwargs): + return self._child_policy.on_remove(*args, **kwargs) + + +class DefaultLoadBalancingPolicy(WrapperPolicy): + """ + A :class:`.LoadBalancingPolicy` wrapper that adds the ability to target a specific host first. + + If no host is set on the query, the child policy's query plan will be used as is. + """ + + _cluster_metadata = None + + def populate(self, cluster, hosts): + self._cluster_metadata = cluster.metadata + self._child_policy.populate(cluster, hosts) + + def make_query_plan(self, working_keyspace=None, query=None): + if query and query.keyspace: + keyspace = query.keyspace + else: + keyspace = working_keyspace + + # TODO remove next major since execute(..., host=XXX) is now available + addr = getattr(query, 'target_host', None) if query else None + target_host = self._cluster_metadata.get_host(addr) + + child = self._child_policy + if target_host and target_host.is_up: + yield target_host + for h in child.make_query_plan(keyspace, query): + if h != target_host: + yield h + else: + for h in child.make_query_plan(keyspace, query): + yield h + + +# TODO for backward compatibility, remove in next major +class DSELoadBalancingPolicy(DefaultLoadBalancingPolicy): + """ + *Deprecated:* This will be removed in the next major release, + consider using :class:`.DefaultLoadBalancingPolicy`. + """ + def __init__(self, *args, **kwargs): + super(DSELoadBalancingPolicy, self).__init__(*args, **kwargs) + warnings.warn("DSELoadBalancingPolicy will be removed in 4.0. Consider using " + "DefaultLoadBalancingPolicy.", DeprecationWarning) + + +class NeverRetryPolicy(RetryPolicy): + def _rethrow(self, *args, **kwargs): + return self.RETHROW, None + + on_read_timeout = _rethrow + on_write_timeout = _rethrow + on_unavailable = _rethrow + + +ColDesc = namedtuple('ColDesc', ['ks', 'table', 'col']) + +class ColumnEncryptionPolicy(object): + """ + A policy enabling (mostly) transparent encryption and decryption of data before it is + sent to the cluster. + + Key materials and other configurations are specified on a per-column basis. This policy can + then be used by driver structures which are aware of the underlying columns involved in their + work. In practice this includes the following cases: + + * Prepared statements - data for columns specified by the cluster's policy will be transparently + encrypted before they are sent + * Rows returned from any query - data for columns specified by the cluster's policy will be + transparently decrypted before they are returned to the user + + To enable this functionality, create an instance of this class (or more likely a subclass) + before creating a cluster. This policy should then be configured and supplied to the Cluster + at creation time via the :attr:`.Cluster.column_encryption_policy` attribute. + """ + + def encrypt(self, coldesc, obj_bytes): + """ + Encrypt the specified bytes using the cryptography materials for the specified column. + Largely used internally, although this could also be used to encrypt values supplied + to non-prepared statements in a way that is consistent with this policy. + """ + raise NotImplementedError() + + def decrypt(self, coldesc, encrypted_bytes): + """ + Decrypt the specified (encrypted) bytes using the cryptography materials for the + specified column. Used internally; could be used externally as well but there's + not currently an obvious use case. + """ + raise NotImplementedError() + + def add_column(self, coldesc, key): + """ + Provide cryptography materials to be used when encrypted and/or decrypting data + for the specified column. + """ + raise NotImplementedError() + + def contains_column(self, coldesc): + """ + Predicate to determine if a specific column is supported by this policy. + Currently only used internally. + """ + raise NotImplementedError() + + def encode_and_encrypt(self, coldesc, obj): + """ + Helper function to enable use of this policy on simple (i.e. non-prepared) + statements. + """ + raise NotImplementedError() diff --git a/cassandra/pool.py b/cassandra/pool.py index 75a310eb2b..37fdaee96b 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -1,8 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Connection pooling and host management. """ +from functools import total_ordering import logging +import socket import time from threading import Lock, RLock, Condition import weakref @@ -12,7 +30,8 @@ from cassandra.util import WeakSet # NOQA from cassandra import AuthenticationFailed -from cassandra.connection import MAX_STREAM_PER_CONNECTION, ConnectionException +from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint +from cassandra.policies import HostDistance log = logging.getLogger(__name__) @@ -25,36 +44,145 @@ class NoConnectionsAvailable(Exception): pass +@total_ordering class Host(object): """ Represents a single Cassandra node. """ - address = None + endpoint = None + """ + The :class:`~.connection.EndPoint` to connect to the node. + """ + + broadcast_address = None + """ + broadcast address configured for the node, *if available*: + + 'system.local.broadcast_address' or 'system.peers.peer' (Cassandra 2-3) + 'system.local.broadcast_address' or 'system.peers_v2.peer' (Cassandra 4) + + This is not present in the ``system.local`` table for older versions of Cassandra. It + is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + broadcast_port = None + """ + broadcast port configured for the node, *if available*: + + 'system.local.broadcast_port' or 'system.peers_v2.peer_port' (Cassandra 4) + + It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. """ - The IP address or hostname of the node. + + broadcast_rpc_address = None + """ + The broadcast rpc address of the node: + + 'system.local.rpc_address' or 'system.peers.rpc_address' (Cassandra 3) + 'system.local.rpc_address' or 'system.peers.native_transport_address (DSE 6+)' + 'system.local.rpc_address' or 'system.peers_v2.native_address (Cassandra 4)' + """ + + broadcast_rpc_port = None + """ + The broadcast rpc port of the node, *if available*: + + 'system.local.rpc_port' or 'system.peers.native_transport_port' (DSE 6+) + 'system.local.rpc_port' or 'system.peers_v2.native_port' (Cassandra 4) + """ + + listen_address = None + """ + listen address configured for the node, *if available*: + + 'system.local.listen_address' + + This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not + queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. Usually the same as ``broadcast_address`` + unless configured differently in cassandra.yaml. """ - monitor = None + listen_port = None """ - A :class:`.HealthMonitor` instance that tracks whether this node is + listen port configured for the node, *if available*: + + 'system.local.listen_port' + + This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not + queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + conviction_policy = None + """ + A :class:`~.ConvictionPolicy` instance for determining when this node should + be marked up or down. + """ + + is_up = None + """ + :const:`True` if the node is considered up, :const:`False` if it is + considered down, and :const:`None` if it is not known if the node is up or down. """ + release_version = None + """ + release_version as queried from the control connection system tables + """ + + host_id = None + """ + The unique identifier of the cassandra node + """ + + dse_version = None + """ + dse_version as queried from the control connection system tables. Only populated when connecting to + DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + dse_workload = None + """ + DSE workload queried from the control connection system tables. Only populated when connecting to + DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + This is a legacy attribute that does not portray multiple workloads in a uniform fashion. + See also :attr:`~.Host.dse_workloads`. + """ + + dse_workloads = None + """ + DSE workloads set, queried from the control connection system tables. Only populated when connecting to + DSE with this property available (added in DSE 5.1). + Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + _datacenter = None _rack = None _reconnection_handler = None + lock = None + + _currently_handling_node_up = False - def __init__(self, inet_address, conviction_policy_factory): - if inet_address is None: - raise ValueError("inet_address may not be None") + def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): + if endpoint is None: + raise ValueError("endpoint may not be None") if conviction_policy_factory is None: raise ValueError("conviction_policy_factory may not be None") - self.address = inet_address - self.monitor = HealthMonitor(conviction_policy_factory(self)) + self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) + self.conviction_policy = conviction_policy_factory(self) + self.host_id = host_id + self.set_location_info(datacenter, rack) + self.lock = RLock() - self._reconnection_lock = Lock() + @property + def address(self): + """ + The IP address of the endpoint. This is the RPC address the driver uses when connecting to the node. + """ + # backward compatibility + return self.endpoint.address @property def datacenter(self): @@ -75,28 +203,49 @@ def set_location_info(self, datacenter, rack): self._datacenter = datacenter self._rack = rack + def set_up(self): + if not self.is_up: + log.debug("Host %s is now marked up", self.endpoint) + self.conviction_policy.reset() + self.is_up = True + + def set_down(self): + self.is_up = False + + def signal_connection_failure(self, connection_exc): + return self.conviction_policy.add_failure(connection_exc) + + def is_currently_reconnecting(self): + return self._reconnection_handler is not None + def get_and_set_reconnection_handler(self, new_handler): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ - with self._reconnection_lock: + with self.lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def __eq__(self, other): - if not isinstance(other, Host): - return False + if isinstance(other, Host): + return self.endpoint == other.endpoint + else: # TODO Backward compatibility, remove next major + return self.endpoint.address == other + + def __hash__(self): + return hash(self.endpoint) - return self.address == other.address + def __lt__(self, other): + return self.endpoint < other.endpoint def __str__(self): - return self.address + return str(self.endpoint) def __repr__(self): dc = (" %s" % (self._datacenter,)) if self._datacenter else "" - return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) + return "<%s: %s%s>" % (self.__class__.__name__, self.endpoint, dc) class _ReconnectionHandler(object): @@ -116,26 +265,41 @@ def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwa def start(self): if self._cancelled: + log.debug("Reconnection handler was cancelled before starting") return - # TODO cancel previous reconnection handlers? That's probably the job - # of whatever created this. - - first_delay = self.schedule.next() + first_delay = next(self.schedule) self.scheduler.schedule(first_delay, self.run) def run(self): if self._cancelled: - self.callback(*(self.callback_args), **(self.callback_kwargs)) + return + conn = None try: - self.on_reconnection(self.try_reconnect()) + conn = self.try_reconnect() except Exception as exc: - next_delay = self.schedule.next() + try: + next_delay = next(self.schedule) + except StopIteration: + # the schedule has been exhausted + next_delay = None + + # call on_exception for logging purposes even if next_delay is None if self.on_exception(exc, next_delay): - self.scheduler.schedule(next_delay, self.run) + if next_delay is None: + log.warning( + "Will not continue to retry reconnection attempts " + "due to an exhausted retry schedule") + else: + self.scheduler.schedule(next_delay, self.run) else: - self.callback(*(self.callback_args), **(self.callback_kwargs)) + if not self._cancelled: + self.on_reconnection(conn) + self.callback(*(self.callback_args), **(self.callback_kwargs)) + finally: + if conn: + conn.close() def cancel(self): self._cancelled = True @@ -175,8 +339,11 @@ def on_exception(self, exc, next_delay): class _HostReconnectionHandler(_ReconnectionHandler): - def __init__(self, host, connection_factory, *args, **kwargs): + def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) + self.is_host_addition = is_host_addition + self.on_add = on_add + self.on_up = on_up self.host = host self.connection_factory = connection_factory @@ -184,89 +351,242 @@ def try_reconnect(self): return self.connection_factory() def on_reconnection(self, connection): - self.host.monitor.reset() + log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) + if self.is_host_addition: + self.on_add(self.host) + else: + self.on_up(self.host) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: - log.warn("Error attempting to reconnect to %s: %s", self.host, exc) + log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", + self.host, next_delay, exc) log.debug("Reconnection error details", exc_info=True) return True -class HealthMonitor(object): - """ - Monitors whether a particular host is marked as up or down. - This class is primarily intended for internal use, although - applications may find it useful to check whether a given node - is up or down. - """ - - is_up = True +class HostConnection(object): """ - A boolean representing the current state of the node. + When using v3 of the native protocol, this is used instead of a connection + pool per host (HostConnectionPool) due to the increased in-flight capacity + of individual connections. """ - def __init__(self, conviction_policy): - self._conviction_policy = conviction_policy - self._host = conviction_policy.host - # self._listeners will hold, among other things, references to - # Cluster objects. To allow those to be GC'ed (and shutdown) even - # though we've implemented __del__, use weak references. - self._listeners = WeakSet() - self._lock = RLock() + host = None + host_distance = None + is_shutdown = False + shutdown_on_error = False - def register(self, listener): - with self._lock: - self._listeners.add(listener) + _session = None + _connection = None + _lock = None + _keyspace = None - def unregister(self, listener): - with self._lock: - self._listeners.remove(listener) + def __init__(self, host, host_distance, session): + self.host = host + self.host_distance = host_distance + self._session = weakref.proxy(session) + self._lock = Lock() + # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. + self._stream_available_condition = Condition(self._lock) + self._is_replacing = False + # Contains connections which shouldn't be used anymore + # and are waiting until all requests time out or complete + # so that we can dispose of them. + self._trash = set() - def set_up(self): - if self.is_up: + if host_distance == HostDistance.IGNORED: + log.debug("Not opening connection to ignored host %s", self.host) + return + elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts: + log.debug("Not opening connection to remote host %s", self.host) return - self._conviction_policy.reset() - log.info("Host %s is considered up", self._host) + log.debug("Initializing connection for host %s", self.host) + self._connection = session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + self._keyspace = session.keyspace + if self._keyspace: + self._connection.set_keyspace_blocking(self._keyspace) + log.debug("Finished initializing connection for host %s", self.host) - with self._lock: - listeners = self._listeners.copy() + def _get_connection(self): + if self.is_shutdown: + raise ConnectionException( + "Pool for %s is shutdown" % (self.host,), self.host) - for listener in listeners: - listener.on_up(self._host) + conn = self._connection + if not conn: + raise NoConnectionsAvailable() + return conn - self.is_up = True + def borrow_connection(self, timeout): + conn = self._get_connection() + if conn.orphaned_threshold_reached: + with self._lock: + if not self._is_replacing: + self._is_replacing = True + self._session.submit(self._replace, conn) + log.debug( + "Connection to host %s reached orphaned stream limit, replacing...", + self.host + ) - def set_down(self): - if not self.is_up: - return + start = time.time() + remaining = timeout + while True: + with conn.lock: + if not (conn.orphaned_threshold_reached and conn.is_closed) and conn.in_flight < conn.max_request_id: + conn.in_flight += 1 + return conn, conn.get_request_id() + if timeout is not None: + remaining = timeout - time.time() + start + if remaining < 0: + break + with self._stream_available_condition: + if conn.orphaned_threshold_reached and conn.is_closed: + conn = self._get_connection() + else: + self._stream_available_condition.wait(remaining) - self.is_up = False - log.info("Host %s is considered down", self._host) + raise NoConnectionsAvailable("All request IDs are currently in use") + def return_connection(self, connection, stream_was_orphaned=False): + if not stream_was_orphaned: + with connection.lock: + connection.in_flight -= 1 + with self._stream_available_condition: + self._stream_available_condition.notify() + + if connection.is_defunct or connection.is_closed: + if connection.signaled_error and not self.shutdown_on_error: + return + + is_down = False + if not connection.signaled_error: + log.debug("Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", id(connection), self.host) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) + connection.signaled_error = True + + if self.shutdown_on_error and not is_down: + is_down = True + self._session.cluster.on_down(self.host, is_host_addition=False) + + if is_down: + self.shutdown() + else: + self._connection = None + with self._lock: + if self._is_replacing: + return + self._is_replacing = True + self._session.submit(self._replace, connection) + else: + if connection in self._trash: + with connection.lock: + if connection.in_flight == len(connection.orphaned_request_ids): + with self._lock: + if connection in self._trash: + self._trash.remove(connection) + log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) + connection.close() + return + + def on_orphaned_stream_released(self): + """ + Called when a response for an orphaned stream (timed out on the client + side) was received. + """ + with self._stream_available_condition: + self._stream_available_condition.notify() + + def _replace(self, connection): + with self._lock: + if self.is_shutdown: + return + + log.debug("Replacing connection (%s) to %s", id(connection), self.host) + try: + conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + if self._keyspace: + conn.set_keyspace_blocking(self._keyspace) + self._connection = conn + except Exception: + log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) + self._session.submit(self._replace, connection) + else: + with connection.lock: + with self._lock: + if connection.orphaned_threshold_reached: + if connection.in_flight == len(connection.orphaned_request_ids): + connection.close() + else: + self._trash.add(connection) + self._is_replacing = False + self._stream_available_condition.notify() + + def shutdown(self): with self._lock: - listeners = self._listeners.copy() + if self.is_shutdown: + return + else: + self.is_shutdown = True + self._stream_available_condition.notify_all() - for listener in listeners: - listener.on_down(self._host) + if self._connection: + self._connection.close() + self._connection = None - def reset(self): - return self.set_up() + trash_conns = None + with self._lock: + if self._trash: + trash_conns = self._trash + self._trash = set() - def signal_connection_failure(self, connection_exc): - is_down = self._conviction_policy.add_failure(connection_exc) - if is_down: - self.set_down() - return is_down + if trash_conns is not None: + for conn in self._trash: + conn.close() + def _set_keyspace_for_all_conns(self, keyspace, callback): + if self.is_shutdown or not self._connection: + return + + def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) + errors = [] if not error else [error] + callback(self, errors) + + self._keyspace = keyspace + self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + + def get_connections(self): + c = self._connection + return [c] if c else [] + + def get_state(self): + connection = self._connection + open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 + in_flights = [connection.in_flight] if connection else [] + orphan_requests = [connection.orphaned_request_ids] if connection else [] + return {'shutdown': self.is_shutdown, 'open_count': open_count, \ + 'in_flights': in_flights, 'orphan_requests': orphan_requests} + + @property + def open_count(self): + connection = self._connection + return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 _MAX_SIMULTANEOUS_CREATION = 1 +_MIN_TRASH_INTERVAL = 10 class HostConnectionPool(object): + """ + Used to pool connections to a host for v1 and v2 native protocol. + """ host = None host_distance = None @@ -274,6 +594,8 @@ class HostConnectionPool(object): is_shutdown = False open_count = 0 _scheduled_for_creation = 0 + _next_trash_allowed_at = 0 + _keyspace = None def __init__(self, host, host_distance, session): self.host = host @@ -283,11 +605,20 @@ def __init__(self, host, host_distance, session): self._lock = RLock() self._conn_available_condition = Condition() + log.debug("Initializing new connection pool for host %s", self.host) core_conns = session.cluster.get_core_connections_per_host(host_distance) - self._connections = [session.cluster.connection_factory(host.address) + self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) for i in range(core_conns)] + + self._keyspace = session.keyspace + if self._keyspace: + for conn in self._connections: + conn.set_keyspace_blocking(self._keyspace) + self._trash = set() + self._next_trash_allowed_at = time.time() self.open_count = core_conns + log.debug("Finished initializing new connection pool for host %s", self.host) def borrow_connection(self, timeout): if self.is_shutdown: @@ -297,7 +628,7 @@ def borrow_connection(self, timeout): conns = self._connections if not conns: # handled specially just for simpler code - log.debug("Detected empty pool, opening core conns to %s" % (self.host,)) + log.debug("Detected empty pool, opening core conns to %s", self.host) core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) with self._lock: # we check the length of self._connections again @@ -310,7 +641,6 @@ def borrow_connection(self, timeout): # in_flight is incremented by wait_for_conn conn = self._wait_for_conn(timeout) - conn.set_keyspace(self._session.keyspace) return conn else: # note: it would be nice to push changes to these config settings @@ -320,44 +650,48 @@ def borrow_connection(self, timeout): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) least_busy = min(conns, key=lambda c: c.in_flight) + request_id = None # to avoid another thread closing this connection while # trashing it (through the return_connection process), hold # the connection lock from this point until we've incremented # its in_flight count + need_to_wait = False with least_busy.lock: - - # if we have too many requests on this connection but we still - # have space to open a new connection against this host, go ahead - # and schedule the creation of a new connection - if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: - self._maybe_spawn_new_connection() - - if least_busy.in_flight >= MAX_STREAM_PER_CONNECTION: + if least_busy.in_flight < least_busy.max_request_id: + least_busy.in_flight += 1 + request_id = least_busy.get_request_id() + else: # once we release the lock, wait for another connection need_to_wait = True - else: - need_to_wait = False - least_busy.in_flight += 1 if need_to_wait: # wait_for_conn will increment in_flight on the conn - least_busy = self._wait_for_conn(timeout) + least_busy, request_id = self._wait_for_conn(timeout) + + # if we have too many requests on this connection, but we still + # have space to open a new connection against this host, go ahead + # and schedule the creation of a new connection + if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: + self._maybe_spawn_new_connection() - least_busy.set_keyspace(self._session.keyspace) - return least_busy + return least_busy, request_id def _maybe_spawn_new_connection(self): with self._lock: if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION: return + if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance): + return self._scheduled_for_creation += 1 - log.debug("Submitting task for creation of new Connection to %s" % (self.host,)) + log.debug("Submitting task for creation of new Connection to %s", self.host) self._session.submit(self._create_new_connection) def _create_new_connection(self): try: self._add_conn_if_under_max() + except (ConnectionException, socket.error) as exc: + log.warning("Failed to create new connection to %s: %s", self.host, exc) except Exception: log.exception("Unexpectedly failed to create new connection") finally: @@ -368,25 +702,31 @@ def _add_conn_if_under_max(self): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) with self._lock: if self.is_shutdown: - return False + return True if self.open_count >= max_conns: - return False + return True self.open_count += 1 + log.debug("Going to open new connection to host %s", self.host) try: - conn = self._session.cluster.connection_factory(self.host.address) + conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + if self._keyspace: + conn.set_keyspace_blocking(self._session.keyspace) + self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] self._connections = new_connections + log.debug("Added new connection (%s) to pool for host %s, signaling availability", + id(conn), self.host) self._signal_available_conn() return True - except ConnectionException as exc: - log.exception("Failed to add new connection to pool for host %s" % (self.host,)) + except (ConnectionException, socket.error) as exc: + log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc) with self._lock: self.open_count -= 1 - if self.host.monitor.signal_connection_failure(exc): + if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False): self.shutdown() return False except AuthenticationFailed: @@ -423,31 +763,39 @@ def _wait_for_conn(self, timeout): if conns: least_busy = min(conns, key=lambda c: c.in_flight) with least_busy.lock: - if least_busy.in_flight < MAX_STREAM_PER_CONNECTION: + if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 - return least_busy + return least_busy, least_busy.get_request_id() remaining = timeout - (time.time() - start) raise NoConnectionsAvailable() - def return_connection(self, connection): + def return_connection(self, connection, stream_was_orphaned=False): with connection.lock: - connection.in_flight -= 1 + if not stream_was_orphaned: + connection.in_flight -= 1 in_flight = connection.in_flight if connection.is_defunct or connection.is_closed: - is_down = self.host.monitor.signal_connection_failure(connection.last_error) - if is_down: - self.shutdown() - else: - self._replace(connection) + if not connection.signaled_error: + log.debug("Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", id(connection), self.host) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) + connection.signaled_error = True + if is_down: + self.shutdown() + else: + self._replace(connection) else: if connection in self._trash: with connection.lock: - if in_flight == 0: + if connection.in_flight == 0: with self._lock: - self._trash.remove(connection) + if connection in self._trash: + self._trash.remove(connection) + log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) connection.close() return @@ -456,11 +804,19 @@ def return_connection(self, connection): # we can use in_flight here without holding the connection lock # because the fact that in_flight dipped below the min at some # point is enough to start the trashing procedure - if len(self._connections) > core_conns and in_flight <= min_reqs: + if len(self._connections) > core_conns and in_flight <= min_reqs and \ + time.time() >= self._next_trash_allowed_at: self._maybe_trash_connection(connection) else: self._signal_available_conn() + def on_orphaned_stream_released(self): + """ + Called when a response for an orphaned stream (timed out on the client + side) was received. + """ + self._signal_available_conn() + def _maybe_trash_connection(self, connection): core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) did_trash = False @@ -477,6 +833,7 @@ def _maybe_trash_connection(self, connection): with connection.lock: if connection.in_flight == 0: + log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host) connection.close() # skip adding it to the trash if we're already closing it @@ -485,7 +842,8 @@ def _maybe_trash_connection(self, connection): self._trash.add(connection) if did_trash: - log.debug("Trashed connection to %s" % (self.host,)) + self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL + log.debug("Trashed connection (%s) to %s", id(connection), self.host) def _replace(self, connection): should_replace = False @@ -498,18 +856,23 @@ def _replace(self, connection): should_replace = True if should_replace: - log.debug("Replacing connection to %s" % (self.host,)) - - def close_and_replace(): - connection.close() - self._add_conn_if_under_max() - - self._session.submit(close_and_replace) + log.debug("Replacing connection (%s) to %s", id(connection), self.host) + connection.close() + self._session.submit(self._retrying_replace) else: - # just close it - log.debug("Closing connection to %s" % (self.host,)) + log.debug("Closing connection (%s) to %s", id(connection), self.host) connection.close() + def _retrying_replace(self): + replaced = False + try: + replaced = self._add_conn_if_under_max() + except Exception: + log.exception("Failed replacing connection to %s", self.host) + if not replaced: + log.debug("Failed replacing connection to %s. Retrying.", self.host) + self._session.submit(self._retrying_replace) + def shutdown(self): with self._lock: if self.is_shutdown: @@ -522,9 +885,8 @@ def shutdown(self): conn.close() self.open_count -= 1 - reconnector = self.host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() + for conn in self._trash: + conn.close() def ensure_core_connections(self): if self.is_shutdown: @@ -536,3 +898,38 @@ def ensure_core_connections(self): for i in range(to_create): self._scheduled_for_creation += 1 self._session.submit(self._create_new_connection) + + def _set_keyspace_for_all_conns(self, keyspace, callback): + """ + Asynchronously sets the keyspace for all connections. When all + connections have been set, `callback` will be called with two + arguments: this pool, and a list of any errors that occurred. + """ + remaining_callbacks = set(self._connections) + errors = [] + + if not remaining_callbacks: + callback(self, errors) + return + + def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) + remaining_callbacks.remove(conn) + if error: + errors.append(error) + + if not remaining_callbacks: + callback(self, errors) + + self._keyspace = keyspace + for conn in self._connections: + conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + + def get_connections(self): + return self._connections + + def get_state(self): + in_flights = [c.in_flight for c in self._connections] + orphan_requests = [c.orphaned_request_ids for c in self._connections] + return {'shutdown': self.is_shutdown, 'open_count': self.open_count, \ + 'in_flights': in_flights, 'orphan_requests': orphan_requests} diff --git a/cassandra/protocol.py b/cassandra/protocol.py new file mode 100644 index 0000000000..69340a805d --- /dev/null +++ b/cassandra/protocol.py @@ -0,0 +1,1491 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import logging +import socket +from uuid import UUID + +import io + +from cassandra import ProtocolVersion +from cassandra import type_codes, DriverException +from cassandra import (Unavailable, WriteTimeout, ReadTimeout, + WriteFailure, ReadFailure, FunctionFailure, + AlreadyExists, InvalidRequest, Unauthorized, + UnsupportedOperation, UserFunctionDescriptor, + UserAggregateDescriptor, SchemaTargetType) +from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, + CounterColumnType, DateType, DecimalType, + DoubleType, FloatType, Int32Type, + InetAddressType, IntegerType, ListType, + LongType, MapType, SetType, TimeUUIDType, + UTF8Type, VarcharType, UUIDType, UserType, + TupleType, lookup_casstype, SimpleDateType, + TimeType, ByteType, ShortType, DurationType) +from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, + uint8_pack, int8_unpack, uint64_pack, header_pack, + v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack) +from cassandra.policies import ColDesc +from cassandra import WriteType +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +from cassandra import util + +log = logging.getLogger(__name__) + + +class NotSupportedError(Exception): + pass + + +class InternalError(Exception): + pass + +ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) + +HEADER_DIRECTION_TO_CLIENT = 0x80 +HEADER_DIRECTION_MASK = 0x80 + +COMPRESSED_FLAG = 0x01 +TRACING_FLAG = 0x02 +CUSTOM_PAYLOAD_FLAG = 0x04 +WARNING_FLAG = 0x08 +USE_BETA_FLAG = 0x10 +USE_BETA_MASK = ~USE_BETA_FLAG + +_message_types_by_opcode = {} + +_UNSET_VALUE = object() + + +def register_class(cls): + _message_types_by_opcode[cls.opcode] = cls + + +def get_registered_classes(): + return _message_types_by_opcode.copy() + + +class _RegisterMessageType(type): + def __init__(cls, name, bases, dct): + if not name.startswith('_'): + register_class(cls) + + +class _MessageType(object, metaclass=_RegisterMessageType): + + tracing = False + custom_payload = None + warnings = None + + def update_custom_payload(self, other): + if other: + if not self.custom_payload: + self.custom_payload = {} + self.custom_payload.update(other) + if len(self.custom_payload) > 65535: + raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)") + + def __repr__(self): + return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) + + +def _get_params(message_obj): + base_attrs = dir(_MessageType) + return ( + (n, a) for n, a in message_obj.__dict__.items() + if n not in base_attrs and not n.startswith('_') and not callable(a) + ) + + +error_classes = {} + + +class ErrorMessage(_MessageType, Exception): + opcode = 0x00 + name = 'ERROR' + summary = 'Unknown' + + def __init__(self, code, message, info): + self.code = code + self.message = message + self.info = info + + @classmethod + def recv_body(cls, f, protocol_version, *args): + code = read_int(f) + msg = read_string(f) + subcls = error_classes.get(code, cls) + extra_info = subcls.recv_error_info(f, protocol_version) + return subcls(code=code, message=msg, info=extra_info) + + def summary_msg(self): + msg = 'Error from server: code=%04x [%s] message="%s"' \ + % (self.code, self.summary, self.message) + return msg + + def __str__(self): + return '<%s>' % self.summary_msg() + __repr__ = __str__ + + @staticmethod + def recv_error_info(f, protocol_version): + pass + + def to_exception(self): + return self + + +class ErrorMessageSubclass(_RegisterMessageType): + def __init__(cls, name, bases, dct): + if cls.error_code is not None: # Server has an error code of 0. + error_classes[cls.error_code] = cls + + +class ErrorMessageSub(ErrorMessage, metaclass=ErrorMessageSubclass): + error_code = None + + +class RequestExecutionException(ErrorMessageSub): + pass + + +class RequestValidationException(ErrorMessageSub): + pass + + +class ServerError(ErrorMessageSub): + summary = 'Server error' + error_code = 0x0000 + + +class ProtocolException(ErrorMessageSub): + summary = 'Protocol error' + error_code = 0x000A + + @property + def is_beta_protocol_error(self): + return 'USE_BETA flag is unset' in str(self) + + +class BadCredentials(ErrorMessageSub): + summary = 'Bad credentials' + error_code = 0x0100 + + +class UnavailableErrorMessage(RequestExecutionException): + summary = 'Unavailable exception' + error_code = 0x1000 + + @staticmethod + def recv_error_info(f, protocol_version): + return { + 'consistency': read_consistency_level(f), + 'required_replicas': read_int(f), + 'alive_replicas': read_int(f), + } + + def to_exception(self): + return Unavailable(self.summary_msg(), **self.info) + + +class OverloadedErrorMessage(RequestExecutionException): + summary = 'Coordinator node overloaded' + error_code = 0x1001 + + +class IsBootstrappingErrorMessage(RequestExecutionException): + summary = 'Coordinator node is bootstrapping' + error_code = 0x1002 + + +class TruncateError(RequestExecutionException): + summary = 'Error during truncate' + error_code = 0x1003 + + +class WriteTimeoutErrorMessage(RequestExecutionException): + summary = "Coordinator node timed out waiting for replica nodes' responses" + error_code = 0x1100 + + @staticmethod + def recv_error_info(f, protocol_version): + return { + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'write_type': WriteType.name_to_value[read_string(f)], + } + + def to_exception(self): + return WriteTimeout(self.summary_msg(), **self.info) + + +class ReadTimeoutErrorMessage(RequestExecutionException): + summary = "Coordinator node timed out waiting for replica nodes' responses" + error_code = 0x1200 + + @staticmethod + def recv_error_info(f, protocol_version): + return { + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'data_retrieved': bool(read_byte(f)), + } + + def to_exception(self): + return ReadTimeout(self.summary_msg(), **self.info) + + +class ReadFailureMessage(RequestExecutionException): + summary = "Replica(s) failed to execute read" + error_code = 0x1300 + + @staticmethod + def recv_error_info(f, protocol_version): + consistency = read_consistency_level(f) + received_responses = read_int(f) + required_responses = read_int(f) + + if ProtocolVersion.uses_error_code_map(protocol_version): + error_code_map = read_error_code_map(f) + failures = len(error_code_map) + else: + error_code_map = None + failures = read_int(f) + + data_retrieved = bool(read_byte(f)) + + return { + 'consistency': consistency, + 'received_responses': received_responses, + 'required_responses': required_responses, + 'failures': failures, + 'error_code_map': error_code_map, + 'data_retrieved': data_retrieved + } + + def to_exception(self): + return ReadFailure(self.summary_msg(), **self.info) + + +class FunctionFailureMessage(RequestExecutionException): + summary = "User Defined Function failure" + error_code = 0x1400 + + @staticmethod + def recv_error_info(f, protocol_version): + return { + 'keyspace': read_string(f), + 'function': read_string(f), + 'arg_types': [read_string(f) for _ in range(read_short(f))], + } + + def to_exception(self): + return FunctionFailure(self.summary_msg(), **self.info) + + +class WriteFailureMessage(RequestExecutionException): + summary = "Replica(s) failed to execute write" + error_code = 0x1500 + + @staticmethod + def recv_error_info(f, protocol_version): + consistency = read_consistency_level(f) + received_responses = read_int(f) + required_responses = read_int(f) + + if ProtocolVersion.uses_error_code_map(protocol_version): + error_code_map = read_error_code_map(f) + failures = len(error_code_map) + else: + error_code_map = None + failures = read_int(f) + + write_type = WriteType.name_to_value[read_string(f)] + + return { + 'consistency': consistency, + 'received_responses': received_responses, + 'required_responses': required_responses, + 'failures': failures, + 'error_code_map': error_code_map, + 'write_type': write_type + } + + def to_exception(self): + return WriteFailure(self.summary_msg(), **self.info) + + +class CDCWriteException(RequestExecutionException): + summary = 'Failed to execute write due to CDC space exhaustion.' + error_code = 0x1600 + + +class SyntaxException(RequestValidationException): + summary = 'Syntax error in CQL query' + error_code = 0x2000 + + +class UnauthorizedErrorMessage(RequestValidationException): + summary = 'Unauthorized' + error_code = 0x2100 + + def to_exception(self): + return Unauthorized(self.summary_msg()) + + +class InvalidRequestException(RequestValidationException): + summary = 'Invalid query' + error_code = 0x2200 + + def to_exception(self): + return InvalidRequest(self.summary_msg()) + + +class ConfigurationException(RequestValidationException): + summary = 'Query invalid because of configuration issue' + error_code = 0x2300 + + +class PreparedQueryNotFound(RequestValidationException): + summary = 'Matching prepared statement not found on this node' + error_code = 0x2500 + + @staticmethod + def recv_error_info(f, protocol_version): + # return the query ID + return read_binary_string(f) + + +class AlreadyExistsException(ConfigurationException): + summary = 'Item already exists' + error_code = 0x2400 + + @staticmethod + def recv_error_info(f, protocol_version): + return { + 'keyspace': read_string(f), + 'table': read_string(f), + } + + def to_exception(self): + return AlreadyExists(**self.info) + + +class ClientWriteError(RequestExecutionException): + summary = 'Client write failure.' + error_code = 0x8000 + + +class StartupMessage(_MessageType): + opcode = 0x01 + name = 'STARTUP' + + KNOWN_OPTION_KEYS = set(( + 'CQL_VERSION', + 'COMPRESSION', + 'NO_COMPACT' + )) + + def __init__(self, cqlversion, options): + self.cqlversion = cqlversion + self.options = options + + def send_body(self, f, protocol_version): + optmap = self.options.copy() + optmap['CQL_VERSION'] = self.cqlversion + write_stringmap(f, optmap) + + +class ReadyMessage(_MessageType): + opcode = 0x02 + name = 'READY' + + @classmethod + def recv_body(cls, *args): + return cls() + + +class AuthenticateMessage(_MessageType): + opcode = 0x03 + name = 'AUTHENTICATE' + + def __init__(self, authenticator): + self.authenticator = authenticator + + @classmethod + def recv_body(cls, f, *args): + authname = read_string(f) + return cls(authenticator=authname) + + +class CredentialsMessage(_MessageType): + opcode = 0x04 + name = 'CREDENTIALS' + + def __init__(self, creds): + self.creds = creds + + def send_body(self, f, protocol_version): + if protocol_version > 1: + raise UnsupportedOperation( + "Credentials-based authentication is not supported with " + "protocol version 2 or higher. Use the SASL authentication " + "mechanism instead.") + write_short(f, len(self.creds)) + for credkey, credval in self.creds.items(): + write_string(f, credkey) + write_string(f, credval) + + +class AuthChallengeMessage(_MessageType): + opcode = 0x0E + name = 'AUTH_CHALLENGE' + + def __init__(self, challenge): + self.challenge = challenge + + @classmethod + def recv_body(cls, f, *args): + return cls(read_binary_longstring(f)) + + +class AuthResponseMessage(_MessageType): + opcode = 0x0F + name = 'AUTH_RESPONSE' + + def __init__(self, response): + self.response = response + + def send_body(self, f, protocol_version): + write_longstring(f, self.response) + + +class AuthSuccessMessage(_MessageType): + opcode = 0x10 + name = 'AUTH_SUCCESS' + + def __init__(self, token): + self.token = token + + @classmethod + def recv_body(cls, f, *args): + return cls(read_longstring(f)) + + +class OptionsMessage(_MessageType): + opcode = 0x05 + name = 'OPTIONS' + + def send_body(self, f, protocol_version): + pass + + +class SupportedMessage(_MessageType): + opcode = 0x06 + name = 'SUPPORTED' + + def __init__(self, cql_versions, options): + self.cql_versions = cql_versions + self.options = options + + @classmethod + def recv_body(cls, f, *args): + options = read_stringmultimap(f) + cql_versions = options.pop('CQL_VERSION') + return cls(cql_versions=cql_versions, options=options) + + +# used for QueryMessage and ExecuteMessage +_VALUES_FLAG = 0x01 +_SKIP_METADATA_FLAG = 0x02 +_PAGE_SIZE_FLAG = 0x04 +_WITH_PAGING_STATE_FLAG = 0x08 +_WITH_SERIAL_CONSISTENCY_FLAG = 0x10 +_PROTOCOL_TIMESTAMP_FLAG = 0x20 +_NAMES_FOR_VALUES_FLAG = 0x40 # not used here +_WITH_KEYSPACE_FLAG = 0x80 +_PREPARED_WITH_KEYSPACE_FLAG = 0x01 +_PAGE_SIZE_BYTES_FLAG = 0x40000000 +_PAGING_OPTIONS_FLAG = 0x80000000 + + +class _QueryMessage(_MessageType): + + def __init__(self, query_params, consistency_level, + serial_consistency_level=None, fetch_size=None, + paging_state=None, timestamp=None, skip_meta=False, + continuous_paging_options=None, keyspace=None): + self.query_params = query_params + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + self.fetch_size = fetch_size + self.paging_state = paging_state + self.timestamp = timestamp + self.skip_meta = skip_meta + self.continuous_paging_options = continuous_paging_options + self.keyspace = keyspace + + def _write_query_params(self, f, protocol_version): + write_consistency_level(f, self.consistency_level) + flags = 0x00 + if self.query_params is not None: + flags |= _VALUES_FLAG # also v2+, but we're only setting params internally right now + + if self.serial_consistency_level: + if protocol_version >= 2: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + else: + raise UnsupportedOperation( + "Serial consistency levels require the use of protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2 " + "to support serial consistency levels.") + + if self.fetch_size: + if protocol_version >= 2: + flags |= _PAGE_SIZE_FLAG + else: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2.") + + if self.paging_state: + if protocol_version >= 2: + flags |= _WITH_PAGING_STATE_FLAG + else: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2.") + + if self.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + + if self.continuous_paging_options: + if ProtocolVersion.has_continuous_paging_support(protocol_version): + flags |= _PAGING_OPTIONS_FLAG + else: + raise UnsupportedOperation( + "Continuous paging may only be used with protocol version " + "ProtocolVersion.DSE_V1 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V1.") + + if self.keyspace is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + else: + raise UnsupportedOperation( + "Keyspaces may only be set on queries with protocol version " + "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + + if self.query_params is not None: + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) + if self.fetch_size: + write_int(f, self.fetch_size) + if self.paging_state: + write_longstring(f, self.paging_state) + if self.serial_consistency_level: + write_consistency_level(f, self.serial_consistency_level) + if self.timestamp is not None: + write_long(f, self.timestamp) + if self.keyspace is not None: + write_string(f, self.keyspace) + if self.continuous_paging_options: + self._write_paging_options(f, self.continuous_paging_options, protocol_version) + + def _write_paging_options(self, f, paging_options, protocol_version): + write_int(f, paging_options.max_pages) + write_int(f, paging_options.max_pages_per_second) + if ProtocolVersion.has_continuous_paging_next_pages(protocol_version): + write_int(f, paging_options.max_queue_size) + + +class QueryMessage(_QueryMessage): + opcode = 0x07 + name = 'QUERY' + + def __init__(self, query, consistency_level, serial_consistency_level=None, + fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + self.query = query + super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, + paging_state, timestamp, False, continuous_paging_options, keyspace) + + def send_body(self, f, protocol_version): + write_longstring(f, self.query) + self._write_query_params(f, protocol_version) + + +class ExecuteMessage(_QueryMessage): + opcode = 0x0A + name = 'EXECUTE' + + def __init__(self, query_id, query_params, consistency_level, + serial_consistency_level=None, fetch_size=None, + paging_state=None, timestamp=None, skip_meta=False, + continuous_paging_options=None, result_metadata_id=None): + self.query_id = query_id + self.result_metadata_id = result_metadata_id + super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, + paging_state, timestamp, skip_meta, continuous_paging_options) + + def _write_query_params(self, f, protocol_version): + if protocol_version == 1: + if self.serial_consistency_level: + raise UnsupportedOperation( + "Serial consistency levels require the use of protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2 " + "to support serial consistency levels.") + if self.fetch_size or self.paging_state: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2.") + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) + write_consistency_level(f, self.consistency_level) + else: + super(ExecuteMessage, self)._write_query_params(f, protocol_version) + + def send_body(self, f, protocol_version): + write_string(f, self.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, self.result_metadata_id) + self._write_query_params(f, protocol_version) + + +CUSTOM_TYPE = object() + +RESULT_KIND_VOID = 0x0001 +RESULT_KIND_ROWS = 0x0002 +RESULT_KIND_SET_KEYSPACE = 0x0003 +RESULT_KIND_PREPARED = 0x0004 +RESULT_KIND_SCHEMA_CHANGE = 0x0005 + + +class ResultMessage(_MessageType): + opcode = 0x08 + name = 'RESULT' + + kind = None + results = None + paging_state = None + + # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE) + type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_')) + + _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + _HAS_MORE_PAGES_FLAG = 0x0002 + _NO_METADATA_FLAG = 0x0004 + _CONTINUOUS_PAGING_FLAG = 0x40000000 + _CONTINUOUS_PAGING_LAST_FLAG = 0x80000000 + _METADATA_ID_FLAG = 0x0008 + + kind = None + + # These are all the things a result message might contain. They are populated according to 'kind' + column_names = None + column_types = None + parsed_rows = None + paging_state = None + continuous_paging_seq = None + continuous_paging_last = None + new_keyspace = None + column_metadata = None + query_id = None + bind_metadata = None + pk_indexes = None + schema_change_event = None + + def __init__(self, kind): + self.kind = kind + + def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + if self.kind == RESULT_KIND_VOID: + return + elif self.kind == RESULT_KIND_ROWS: + self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + elif self.kind == RESULT_KIND_SET_KEYSPACE: + self.new_keyspace = read_string(f) + elif self.kind == RESULT_KIND_PREPARED: + self.recv_results_prepared(f, protocol_version, user_type_map) + elif self.kind == RESULT_KIND_SCHEMA_CHANGE: + self.recv_results_schema_change(f, protocol_version) + else: + raise DriverException("Unknown RESULT kind: %d" % self.kind) + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + kind = read_int(f) + msg = cls(kind) + msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + return msg + + def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + self.recv_results_metadata(f, user_type_map) + column_metadata = self.column_metadata or result_metadata + rowcount = read_int(f) + rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + self.column_names = [c[2] for c in column_metadata] + self.column_types = [c[3] for c in column_metadata] + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] + + def decode_val(val, col_md, col_desc): + uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) + col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] + raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val + return col_type.from_binary(raw_bytes, protocol_version) + + def decode_row(row): + return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + + try: + self.parsed_rows = [decode_row(row) for row in rows] + except Exception: + for row in rows: + for val, col_md, col_desc in zip(row, column_metadata, col_descs): + try: + decode_val(val, col_md, col_desc) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], + col_md[3].cql_parameterized_type(), + str(e))) + + def recv_results_prepared(self, f, protocol_version, user_type_map): + self.query_id = read_binary_string(f) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + self.result_metadata_id = read_binary_string(f) + else: + self.result_metadata_id = None + self.recv_prepared_metadata(f, protocol_version, user_type_map) + + def recv_results_metadata(self, f, user_type_map): + flags = read_int(f) + colcount = read_int(f) + + if flags & self._HAS_MORE_PAGES_FLAG: + self.paging_state = read_binary_longstring(f) + + no_meta = bool(flags & self._NO_METADATA_FLAG) + if no_meta: + return + + if flags & self._CONTINUOUS_PAGING_FLAG: + self.continuous_paging_seq = read_int(f) + self.continuous_paging_last = flags & self._CONTINUOUS_PAGING_LAST_FLAG + + if flags & self._METADATA_ID_FLAG: + self.result_metadata_id = read_binary_string(f) + + glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC) + if glob_tblspec: + ksname = read_string(f) + cfname = read_string(f) + column_metadata = [] + for _ in range(colcount): + if glob_tblspec: + colksname = ksname + colcfname = cfname + else: + colksname = read_string(f) + colcfname = read_string(f) + colname = read_string(f) + coltype = self.read_type(f, user_type_map) + column_metadata.append((colksname, colcfname, colname, coltype)) + + self.column_metadata = column_metadata + + def recv_prepared_metadata(self, f, protocol_version, user_type_map): + flags = read_int(f) + colcount = read_int(f) + pk_indexes = None + if protocol_version >= 4: + num_pk_indexes = read_int(f) + pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] + + glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC) + if glob_tblspec: + ksname = read_string(f) + cfname = read_string(f) + bind_metadata = [] + for _ in range(colcount): + if glob_tblspec: + colksname = ksname + colcfname = cfname + else: + colksname = read_string(f) + colcfname = read_string(f) + colname = read_string(f) + coltype = self.read_type(f, user_type_map) + bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) + + if protocol_version >= 2: + self.recv_results_metadata(f, user_type_map) + + self.bind_metadata = bind_metadata + self.pk_indexes = pk_indexes + + def recv_results_schema_change(self, f, protocol_version): + self.schema_change_event = EventMessage.recv_schema_change(f, protocol_version) + + @classmethod + def read_type(cls, f, user_type_map): + optid = read_short(f) + try: + typeclass = cls.type_codes[optid] + except KeyError: + raise NotSupportedError("Unknown data type code 0x%04x. Have to skip" + " entire result set." % (optid,)) + if typeclass in (ListType, SetType): + subtype = cls.read_type(f, user_type_map) + typeclass = typeclass.apply_parameters((subtype,)) + elif typeclass == MapType: + keysubtype = cls.read_type(f, user_type_map) + valsubtype = cls.read_type(f, user_type_map) + typeclass = typeclass.apply_parameters((keysubtype, valsubtype)) + elif typeclass == TupleType: + num_items = read_short(f) + types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items)) + typeclass = typeclass.apply_parameters(types) + elif typeclass == UserType: + ks = read_string(f) + udt_name = read_string(f) + num_fields = read_short(f) + names, types = zip(*((read_string(f), cls.read_type(f, user_type_map)) + for _ in range(num_fields))) + specialized_type = typeclass.make_udt_class(ks, udt_name, names, types) + specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name) + typeclass = specialized_type + elif typeclass == CUSTOM_TYPE: + classname = read_string(f) + typeclass = lookup_casstype(classname) + + return typeclass + + @staticmethod + def recv_row(f, colcount): + return [read_value(f) for _ in range(colcount)] + + +class PrepareMessage(_MessageType): + opcode = 0x09 + name = 'PREPARE' + + def __init__(self, query, keyspace=None): + self.query = query + self.keyspace = keyspace + + def send_body(self, f, protocol_version): + write_longstring(f, self.query) + + flags = 0x00 + + if self.keyspace is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _PREPARED_WITH_KEYSPACE_FLAG + else: + raise UnsupportedOperation( + "Keyspaces may only be set on queries with protocol version " + "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + + if ProtocolVersion.uses_prepare_flags(protocol_version): + write_uint(f, flags) + else: + # checks above should prevent this, but just to be safe... + if flags: + raise UnsupportedOperation( + "Attempted to set flags with value {flags:0=#8x} on" + "protocol version {pv}, which doesn't support flags" + "in prepared statements." + "Consider setting Cluster.protocol_version to 5 or DSE_V2." + "".format(flags=flags, pv=protocol_version)) + + if ProtocolVersion.uses_keyspace_flag(protocol_version): + if self.keyspace: + write_string(f, self.keyspace) + + +class BatchMessage(_MessageType): + opcode = 0x0D + name = 'BATCH' + + def __init__(self, batch_type, queries, consistency_level, + serial_consistency_level=None, timestamp=None, + keyspace=None): + self.batch_type = batch_type + self.queries = queries + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + self.timestamp = timestamp + self.keyspace = keyspace + + def send_body(self, f, protocol_version): + write_byte(f, self.batch_type.value) + write_short(f, len(self.queries)) + for prepared, string_or_query_id, params in self.queries: + if not prepared: + write_byte(f, 0) + write_longstring(f, string_or_query_id) + else: + write_byte(f, 1) + write_short(f, len(string_or_query_id)) + f.write(string_or_query_id) + write_short(f, len(params)) + for param in params: + write_value(f, param) + + write_consistency_level(f, self.consistency_level) + if protocol_version >= 3: + flags = 0 + if self.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if self.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if self.keyspace: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + else: + raise UnsupportedOperation( + "Keyspaces may only be set on queries with protocol version " + "5 or higher. Consider setting Cluster.protocol_version to 5.") + + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_int(f, flags) + else: + write_byte(f, flags) + + if self.serial_consistency_level: + write_consistency_level(f, self.serial_consistency_level) + if self.timestamp is not None: + write_long(f, self.timestamp) + + if ProtocolVersion.uses_keyspace_flag(protocol_version): + if self.keyspace is not None: + write_string(f, self.keyspace) + + +known_event_types = frozenset(( + 'TOPOLOGY_CHANGE', + 'STATUS_CHANGE', + 'SCHEMA_CHANGE' +)) + + +class RegisterMessage(_MessageType): + opcode = 0x0B + name = 'REGISTER' + + def __init__(self, event_list): + self.event_list = event_list + + def send_body(self, f, protocol_version): + write_stringlist(f, self.event_list) + + +class EventMessage(_MessageType): + opcode = 0x0C + name = 'EVENT' + + def __init__(self, event_type, event_args): + self.event_type = event_type + self.event_args = event_args + + @classmethod + def recv_body(cls, f, protocol_version, *args): + event_type = read_string(f).upper() + if event_type in known_event_types: + read_method = getattr(cls, 'recv_' + event_type.lower()) + return cls(event_type=event_type, event_args=read_method(f, protocol_version)) + raise NotSupportedError('Unknown event type %r' % event_type) + + @classmethod + def recv_topology_change(cls, f, protocol_version): + # "NEW_NODE" or "REMOVED_NODE" + change_type = read_string(f) + address = read_inet(f) + return dict(change_type=change_type, address=address) + + @classmethod + def recv_status_change(cls, f, protocol_version): + # "UP" or "DOWN" + change_type = read_string(f) + address = read_inet(f) + return dict(change_type=change_type, address=address) + + @classmethod + def recv_schema_change(cls, f, protocol_version): + # "CREATED", "DROPPED", or "UPDATED" + change_type = read_string(f) + if protocol_version >= 3: + target = read_string(f) + keyspace = read_string(f) + event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace} + if target != SchemaTargetType.KEYSPACE: + target_name = read_string(f) + if target == SchemaTargetType.FUNCTION: + event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))]) + elif target == SchemaTargetType.AGGREGATE: + event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))]) + else: + event[target.lower()] = target_name + else: + keyspace = read_string(f) + table = read_string(f) + if table: + event = {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table} + else: + event = {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace} + return event + + +class ReviseRequestMessage(_MessageType): + + class RevisionType(object): + PAGING_CANCEL = 1 + PAGING_BACKPRESSURE = 2 + + opcode = 0xFF + name = 'REVISE_REQUEST' + + def __init__(self, op_type, op_id, next_pages=0): + self.op_type = op_type + self.op_id = op_id + self.next_pages = next_pages + + def send_body(self, f, protocol_version): + write_int(f, self.op_type) + write_int(f, self.op_id) + if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE: + if self.next_pages <= 0: + raise UnsupportedOperation("Continuous paging backpressure requires next_pages > 0") + elif not ProtocolVersion.has_continuous_paging_next_pages(protocol_version): + raise UnsupportedOperation( + "Continuous paging backpressure may only be used with protocol version " + "ProtocolVersion.DSE_V2 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V2.") + else: + write_int(f, self.next_pages) + + +class _ProtocolHandler(object): + """ + _ProtocolHander handles encoding and decoding messages. + + This class can be specialized to compose Handlers which implement alternative + result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster` + on initialization. + + Contracted class methods are :meth:`_ProtocolHandler.encode_message` and :meth:`_ProtocolHandler.decode_message`. + """ + + message_types_by_opcode = _message_types_by_opcode.copy() + """ + Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses + this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized + result decoding implementations. + """ + + column_encryption_policy = None + """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler""" + + @classmethod + def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version): + """ + Encodes a message using the specified frame parameters, and compressor + + :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver + :param stream_id: protocol stream id for the frame header + :param protocol_version: version for the frame header, and used encoding contents + :param compressor: optional compression function to be used on the body + """ + flags = 0 + body = io.BytesIO() + if msg.custom_payload: + if protocol_version < 4: + raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") + flags |= CUSTOM_PAYLOAD_FLAG + write_bytesmap(body, msg.custom_payload) + msg.send_body(body, protocol_version) + body = body.getvalue() + + # With checksumming, the compression is done at the segment frame encoding + if (not ProtocolVersion.has_checksumming_support(protocol_version) + and compressor and len(body) > 0): + body = compressor(body) + flags |= COMPRESSED_FLAG + + if msg.tracing: + flags |= TRACING_FLAG + + if allow_beta_protocol_version: + flags |= USE_BETA_FLAG + + buff = io.BytesIO() + cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body)) + buff.write(body) + + return buff.getvalue() + + @staticmethod + def _write_header(f, version, flags, stream_id, opcode, length): + """ + Write a CQL protocol frame header. + """ + pack = v3_header_pack if version >= 3 else header_pack + f.write(pack(version, flags, stream_id, opcode)) + write_int(f, length) + + @classmethod + def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body, + decompressor, result_metadata): + """ + Decodes a native protocol message body + + :param protocol_version: version to use decoding contents + :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type + :param stream_id: native protocol stream id from the frame header + :param flags: native protocol flags bitmap from the header + :param opcode: native protocol opcode from the header + :param body: frame body + :param decompressor: optional decompression function to inflate the body + :return: a message decoded from the body and frame attributes + """ + if (not ProtocolVersion.has_checksumming_support(protocol_version) and + flags & COMPRESSED_FLAG): + if decompressor is None: + raise RuntimeError("No de-compressor available for compressed frame!") + body = decompressor(body) + flags ^= COMPRESSED_FLAG + + body = io.BytesIO(body) + if flags & TRACING_FLAG: + trace_id = UUID(bytes=body.read(16)) + flags ^= TRACING_FLAG + else: + trace_id = None + + if flags & WARNING_FLAG: + warnings = read_stringlist(body) + flags ^= WARNING_FLAG + else: + warnings = None + + if flags & CUSTOM_PAYLOAD_FLAG: + custom_payload = read_bytesmap(body) + flags ^= CUSTOM_PAYLOAD_FLAG + else: + custom_payload = None + + flags &= USE_BETA_MASK # will only be set if we asserted it in connection estabishment + + if flags: + log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) + + msg_class = cls.message_types_by_opcode[opcode] + msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata, cls.column_encryption_policy) + msg.stream_id = stream_id + msg.trace_id = trace_id + msg.custom_payload = custom_payload + msg.warnings = warnings + + if msg.warnings: + for w in msg.warnings: + log.warning("Server warning: %s", w) + + return msg + + +def cython_protocol_handler(colparser): + """ + Given a column parser to deserialize ResultMessages, return a suitable + Cython-based protocol handler. + + There are three Cython-based protocol handlers: + + - obj_parser.ListParser + decodes result messages into a list of tuples + + - obj_parser.LazyParser + decodes result messages lazily by returning an iterator + + - numpy_parser.NumPyParser + decodes result messages into NumPy arrays + + The default is to use obj_parser.ListParser + """ + from cassandra.row_parser import make_recv_results_rows + + class FastResultMessage(ResultMessage): + """ + Cython version of Result Message that has a faster implementation of + recv_results_row. + """ + # type_codes = ResultMessage.type_codes.copy() + code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) + recv_results_rows = make_recv_results_rows(colparser) + + class CythonProtocolHandler(_ProtocolHandler): + """ + Use FastResultMessage to decode query result message messages. + """ + + my_opcodes = _ProtocolHandler.message_types_by_opcode.copy() + my_opcodes[FastResultMessage.opcode] = FastResultMessage + message_types_by_opcode = my_opcodes + + col_parser = colparser + + return CythonProtocolHandler + + +if HAVE_CYTHON: + from cassandra.obj_parser import ListParser, LazyParser + ProtocolHandler = cython_protocol_handler(ListParser()) + LazyProtocolHandler = cython_protocol_handler(LazyParser()) +else: + # Use Python-based ProtocolHandler + ProtocolHandler = _ProtocolHandler + LazyProtocolHandler = None + + +if HAVE_CYTHON and HAVE_NUMPY: + from cassandra.numpy_parser import NumpyParser + NumpyProtocolHandler = cython_protocol_handler(NumpyParser()) +else: + NumpyProtocolHandler = None + + +def read_byte(f): + return int8_unpack(f.read(1)) + + +def write_byte(f, b): + f.write(uint8_pack(b)) + + +def read_int(f): + return int32_unpack(f.read(4)) + + +def read_uint_le(f, size=4): + """ + Read a sequence of little endian bytes and return an unsigned integer. + """ + + if size == 4: + value = uint32_le_unpack(f.read(4)) + else: + value = 0 + for i in range(size): + value |= (read_byte(f) & 0xFF) << 8 * i + + return value + + +def write_uint_le(f, i, size=4): + """ + Write an unsigned integer on a sequence of little endian bytes. + """ + if size == 4: + f.write(uint32_le_pack(i)) + else: + for j in range(size): + shift = j * 8 + write_byte(f, i >> shift & 0xFF) + + +def write_int(f, i): + f.write(int32_pack(i)) + + +def write_uint(f, i): + f.write(uint32_pack(i)) + + +def write_long(f, i): + f.write(uint64_pack(i)) + + +def read_short(f): + return uint16_unpack(f.read(2)) + + +def write_short(f, s): + f.write(uint16_pack(s)) + + +def read_consistency_level(f): + return read_short(f) + + +def write_consistency_level(f, cl): + write_short(f, cl) + + +def read_string(f): + size = read_short(f) + contents = f.read(size) + return contents.decode('utf8') + + +def read_binary_string(f): + size = read_short(f) + contents = f.read(size) + return contents + + +def write_string(f, s): + if isinstance(s, str): + s = s.encode('utf8') + write_short(f, len(s)) + f.write(s) + + +def read_binary_longstring(f): + size = read_int(f) + contents = f.read(size) + return contents + + +def read_longstring(f): + return read_binary_longstring(f).decode('utf8') + + +def write_longstring(f, s): + if isinstance(s, str): + s = s.encode('utf8') + write_int(f, len(s)) + f.write(s) + + +def read_stringlist(f): + numstrs = read_short(f) + return [read_string(f) for _ in range(numstrs)] + + +def write_stringlist(f, stringlist): + write_short(f, len(stringlist)) + for s in stringlist: + write_string(f, s) + + +def read_stringmap(f): + numpairs = read_short(f) + strmap = {} + for _ in range(numpairs): + k = read_string(f) + strmap[k] = read_string(f) + return strmap + + +def write_stringmap(f, strmap): + write_short(f, len(strmap)) + for k, v in strmap.items(): + write_string(f, k) + write_string(f, v) + + +def read_bytesmap(f): + numpairs = read_short(f) + bytesmap = {} + for _ in range(numpairs): + k = read_string(f) + bytesmap[k] = read_value(f) + return bytesmap + + +def write_bytesmap(f, bytesmap): + write_short(f, len(bytesmap)) + for k, v in bytesmap.items(): + write_string(f, k) + write_value(f, v) + + +def read_stringmultimap(f): + numkeys = read_short(f) + strmmap = {} + for _ in range(numkeys): + k = read_string(f) + strmmap[k] = read_stringlist(f) + return strmmap + + +def write_stringmultimap(f, strmmap): + write_short(f, len(strmmap)) + for k, v in strmmap.items(): + write_string(f, k) + write_stringlist(f, v) + + +def read_error_code_map(f): + numpairs = read_int(f) + error_code_map = {} + for _ in range(numpairs): + endpoint = read_inet_addr_only(f) + error_code_map[endpoint] = read_short(f) + return error_code_map + + +def read_value(f): + size = read_int(f) + if size < 0: + return None + return f.read(size) + + +def write_value(f, v): + if v is None: + write_int(f, -1) + elif v is _UNSET_VALUE: + write_int(f, -2) + else: + write_int(f, len(v)) + f.write(v) + + +def read_inet_addr_only(f): + size = read_byte(f) + addrbytes = f.read(size) + if size == 4: + addrfam = socket.AF_INET + elif size == 16: + addrfam = socket.AF_INET6 + else: + raise InternalError("bad inet address: %r" % (addrbytes,)) + return util.inet_ntop(addrfam, addrbytes) + + +def read_inet(f): + addr = read_inet_addr_only(f) + port = read_int(f) + return (addr, port) + + +def write_inet(f, addrtuple): + addr, port = addrtuple + if ':' in addr: + addrfam = socket.AF_INET6 + else: + addrfam = socket.AF_INET + addrbytes = util.inet_pton(addrfam, addr) + write_byte(f, len(addrbytes)) + f.write(addrbytes) + write_int(f, port) diff --git a/cassandra/query.py b/cassandra/query.py index c0143e22c5..40e4d63c9e 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1,23 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ This module holds classes for working with prepared statements and specifying consistency levels and retry policies for individual queries. """ -from datetime import datetime, timedelta +from collections import namedtuple +from datetime import datetime, timedelta, timezone +import re import struct import time +import warnings + +from cassandra import ConsistencyLevel, OperationTimedOut +from cassandra.util import unix_time_from_uuid1 +from cassandra.encoder import Encoder +import cassandra.encoder +from cassandra.policies import ColDesc +from cassandra.protocol import _UNSET_VALUE +from cassandra.util import OrderedDict, _sanitize_identifiers + +import logging +log = logging.getLogger(__name__) + +UNSET_VALUE = _UNSET_VALUE +""" +Specifies an unset value when binding a prepared statement. + +Unset values are ignored, allowing prepared statements to be used without specify + +See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics. + +.. versionadded:: 2.6.0 + +Only valid when using native protocol v4+ +""" + +NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') +START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') +END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') + +_clean_name_cache = {} + + +def _clean_column_name(name): + try: + return _clean_name_cache[name] + except KeyError: + clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + _clean_name_cache[name] = clean + return clean + + +def tuple_factory(colnames, rows): + """ + Returns each row as a tuple + + Example:: + + >>> from cassandra.query import tuple_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = tuple_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> print(rows[0]) + ('Bob', 42) + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return rows + +class PseudoNamedTupleRow(object): + """ + Helper class for pseudo_named_tuple_factory. These objects provide an + __iter__ interface, as well as index- and attribute-based access to values, + but otherwise do not attempt to implement the full namedtuple or iterable + interface. + """ + def __init__(self, ordered_dict): + self._dict = ordered_dict + self._tuple = tuple(ordered_dict.values()) + + def __getattr__(self, name): + return self._dict[name] + + def __getitem__(self, idx): + return self._tuple[idx] + + def __iter__(self): + return iter(self._tuple) + + def __repr__(self): + return '{t}({od})'.format(t=self.__class__.__name__, + od=self._dict) + + +def pseudo_namedtuple_factory(colnames, rows): + """ + Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback + factory for cases where :meth:`.named_tuple_factory` fails to create rows. + """ + return [PseudoNamedTupleRow(od) + for od in ordered_dict_factory(colnames, rows)] + + +def named_tuple_factory(colnames, rows): + """ + Returns each row as a `namedtuple `_. + This is the default row factory. + + Example:: + + >>> from cassandra.query import named_tuple_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = named_tuple_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> user = rows[0] + + >>> # you can access field by their name: + >>> print("name: %s, age: %d" % (user.name, user.age)) + name: Bob, age: 42 + + >>> # or you can access fields by their position (like a tuple) + >>> name, age = user + >>> print("name: %s, age: %d" % (name, age)) + name: Bob, age: 42 + >>> name = user[0] + >>> age = user[1] + >>> print("name: %s, age: %d" % (name, age)) + name: Bob, age: 42 + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + clean_column_names = map(_clean_column_name, colnames) + try: + Row = namedtuple('Row', clean_column_names) + except SyntaxError: + warnings.warn( + "Failed creating namedtuple for a result because there were too " + "many columns. This is due to a Python limitation that affects " + "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " + "created with {substitute_factory_name}, which lacks some namedtuple " + "features and is slower. To avoid slower performance accessing " + "values on row objects, Upgrade to Python 3.7, or use a different " + "row factory. (column names: {colnames})".format( + substitute_factory_name=pseudo_namedtuple_factory.__name__, + colnames=colnames + ) + ) + return pseudo_namedtuple_factory(colnames, rows) + except Exception: + clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt + log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + "Avoid this by choosing different names, using SELECT \"\" AS aliases, " + "or specifying a different row_factory on your Session" % + (colnames, clean_column_names)) + Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + + return [Row(*row) for row in rows] + + +def dict_factory(colnames, rows): + """ + Returns each row as a dict. + + Example:: + + >>> from cassandra.query import dict_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = dict_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> print(rows[0]) + {u'age': 42, u'name': u'Bob'} + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return [dict(zip(colnames, row)) for row in rows] + + +def ordered_dict_factory(colnames, rows): + """ + Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict, + so the order of the columns is preserved. + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return [OrderedDict(zip(colnames, row)) for row in rows] + + +FETCH_SIZE_UNSET = object() -from cassandra import ConsistencyLevel -from cassandra.cqltypes import unix_time_from_uuid1 -from cassandra.decoder import (cql_encoders, cql_encode_object, - cql_encode_sequence) class Statement(object): """ - An abstract class representing a single query. There are two subclasses: - :class:`.SimpleStatement` and :class:`.BoundStatement`. These can - be passed to :meth:`.Session.execute()`. + An abstract class representing a single query. There are three subclasses: + :class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`. + These can be passed to :meth:`.Session.execute()`. """ retry_policy = None @@ -27,34 +225,86 @@ class Statement(object): will be retried. """ - trace = None + consistency_level = None """ - If :meth:`.Session.execute()` is run with `trace` set to :const:`True`, - this will be set to a :class:`.QueryTrace` instance. + The :class:`.ConsistencyLevel` to be used for this operation. Defaults + to :const:`None`, which means that the default consistency level for + the Session this is executed in will be used. """ - consistency_level = ConsistencyLevel.ONE + fetch_size = FETCH_SIZE_UNSET """ - The :class:`.ConsistencyLevel` to be used for this operation. Defaults - to :attr:`.ConsistencyLevel.ONE`. + How many rows will be fetched at a time. This overrides the default + of :attr:`.Session.default_fetch_size` + + This only takes effect when protocol version 2 or higher is used. + See :attr:`.Cluster.protocol_version` for details. + + .. versionadded:: 2.0.0 + """ + + keyspace = None + """ + The string name of the keyspace this query acts on. This is used when + :class:`~.TokenAwarePolicy` is configured in the profile load balancing policy. + + It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`, + but must be set explicitly on :class:`.SimpleStatement`. + + .. versionadded:: 2.1.3 + """ + + custom_payload = None + """ + :ref:`custom_payload` to be passed to the server. + + These are only allowed when using protocol version 4 or higher. + + .. versionadded:: 2.6.0 + """ + + is_idempotent = False + """ + Flag indicating whether this statement is safe to run multiple times in speculative execution. """ + _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, tracing_enabled=False, - consistency_level=ConsistencyLevel.ONE, routing_key=None): - self.retry_policy = retry_policy - self.tracing_enabled = tracing_enabled - self.consistency_level = consistency_level + def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, + is_idempotent=False): + if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors + raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + if retry_policy is not None: + self.retry_policy = retry_policy + if consistency_level is not None: + self.consistency_level = consistency_level self._routing_key = routing_key + if serial_consistency_level is not None: + self.serial_consistency_level = serial_consistency_level + if fetch_size is not FETCH_SIZE_UNSET: + self.fetch_size = fetch_size + if keyspace is not None: + self.keyspace = keyspace + if custom_payload is not None: + self.custom_payload = custom_payload + self.is_idempotent = is_idempotent + + def _key_parts_packed(self, parts): + for p in parts: + l = len(p) + yield struct.pack(">H%dsB" % l, l, p, 0) def _get_routing_key(self): return self._routing_key def _set_routing_key(self, key): if isinstance(key, (list, tuple)): - self._routing_key = "".join(struct.pack("HsB", len(component), component, 0) - for component in key) + if len(key) == 1: + self._routing_key = key[0] + else: + self._routing_key = b"".join(self._key_parts_packed(key)) else: self._routing_key = key @@ -74,37 +324,87 @@ def _del_routing_key(self): components should be strings. """) - @property - def keyspace(self): - """ - The string name of the keyspace this query acts on. + def _get_serial_consistency_level(self): + return self._serial_consistency_level + + def _set_serial_consistency_level(self, serial_consistency_level): + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): + raise ValueError( + "serial_consistency_level must be either ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL") + self._serial_consistency_level = serial_consistency_level + + def _del_serial_consistency_level(self): + self._serial_consistency_level = None + + serial_consistency_level = property( + _get_serial_consistency_level, + _set_serial_consistency_level, + _del_serial_consistency_level, """ - return None + The serial consistency level is only used by conditional updates + (``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). For + those, the ``serial_consistency_level`` defines the consistency level of + the serial phase (or "paxos" phase) while the normal + :attr:`~.consistency_level` defines the consistency for the "learn" phase, + i.e. what type of reads will be guaranteed to see the update right away. + For example, if a conditional write has a :attr:`~.consistency_level` of + :attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a + :attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write. + But if the regular :attr:`~.consistency_level` of that write is + :attr:`~.ConsistencyLevel.ANY`, then only a read with a + :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is + guaranteed to see it (even a read with consistency + :attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough). + + The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL` + or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full + linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only + guarantees it in the local data center. + + The serial consistency level is ignored for any query that is not a + conditional update. Serial reads should use the regular + :attr:`consistency_level`. + + Serial consistency levels may only be used against Cassandra 2.0+ + and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher. + + See :doc:`/lwt` for a discussion on how to work with results returned from + conditional statements. + + .. versionadded:: 2.0.0 + """) class SimpleStatement(Statement): """ - A simple, un-prepared query. All attributes of :class:`Statement` apply - to this class as well. + A simple, un-prepared query. """ - def __init__(self, query_string, *args, **kwargs): + def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. + + See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, *args, **kwargs) + Statement.__init__(self, retry_policy, consistency_level, routing_key, + serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property def query_string(self): return self._query_string - def __repr__(self): - consistency = ConsistencyLevel.value_to_name[self.consistency_level] + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) + __repr__ = __str__ class PreparedStatement(object): @@ -112,72 +412,115 @@ class PreparedStatement(object): A statement that has been prepared against at least one Cassandra node. Instances of this class should not be created directly, but through :meth:`.Session.prepare()`. + + A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement + may affect performance (as the operation requires a network roundtrip). + + |prepared_stmt_head|: Do not use ``*`` in prepared statements if you might + change the schema of the table being queried. The driver and server each + maintain a map between metadata for a schema and statements that were + prepared against that schema. When a user changes a schema, e.g. by adding + or removing a column, the server invalidates its mappings involving that + schema. However, there is currently no way to propagate that invalidation + to drivers. Thus, after a schema change, the driver will incorrectly + interpret the results of ``SELECT *`` queries prepared before the schema + change. This is currently being addressed in `CASSANDRA-10786 + `_. + + .. |prepared_stmt_head| raw:: html + + A note about * in prepared statements """ - column_metadata = None + column_metadata = None #TODO: make this bind_metadata in next major + retry_policy = None + consistency_level = None + custom_payload = None + fetch_size = FETCH_SIZE_UNSET + keyspace = None # change to prepared_keyspace in major release + protocol_version = None query_id = None query_string = None - keyspace = None - + result_metadata = None + result_metadata_id = None + column_encryption_policy = None routing_key_indexes = None + _routing_key_index_set = None + serial_consistency_level = None # TODO never used? - consistency_level = ConsistencyLevel.ONE - - def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace): + def __init__(self, column_metadata, query_id, routing_key_indexes, query, + keyspace, protocol_version, result_metadata, result_metadata_id, + column_encryption_policy=None): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace + self.protocol_version = protocol_version + self.result_metadata = result_metadata + self.result_metadata_id = result_metadata_id + self.column_encryption_policy = column_encryption_policy + self.is_idempotent = False @classmethod - def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace): + def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy=None): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, query, keyspace) + return PreparedStatement(column_metadata, query_id, None, + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy) - partition_key_columns = None - routing_key_indexes = None + if pk_indexes: + routing_key_indexes = pk_indexes + else: + routing_key_indexes = None - ks_name, table_name, _, _ = column_metadata[0] - ks_meta = cluster_metadata.keyspaces.get(ks_name) - if ks_meta: - table_meta = ks_meta.tables.get(table_name) - if table_meta: - partition_key_columns = table_meta.partition_key + first_col = column_metadata[0] + ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name) + if ks_meta: + table_meta = ks_meta.tables.get(first_col.table_name) + if table_meta: + partition_key_columns = table_meta.partition_key - # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata)) + # make a map of {column_name: index} for each column in the statement + statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) - # a list of which indexes in the statement correspond to partition key items - try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: - pass # we're missing a partition key component in the prepared - # statement; just leave routing_key_indexes as None + # a list of which indexes in the statement correspond to partition key items + try: + routing_key_indexes = [statement_indexes[c.name] + for c in partition_key_columns] + except KeyError: # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None - return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, keyspace) + return PreparedStatement(column_metadata, query_id, routing_key_indexes, + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy) def bind(self, values): """ Creates and returns a :class:`BoundStatement` instance using `values`. - The `values` parameter *must* be a sequence, such as a tuple or list, - even if there is only one value to bind. + + See :meth:`BoundStatement.bind` for rules on input ``values``. """ return BoundStatement(self).bind(values) - def __repr__(self): - consistency = ConsistencyLevel.value_to_name[self.consistency_level] + def is_routing_key_index(self, i): + if self._routing_key_index_set is None: + self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + return i in self._routing_key_index_set + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) + __repr__ = __str__ class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. These may be created directly or through :meth:`.PreparedStatement.bind()`. - - All attributes of :class:`Statement` apply to this class as well. """ prepared_statement = None @@ -190,50 +533,131 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, *args, **kwargs): + def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. - All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`. + + See :class:`Statement` attributes for a description of the other parameters. """ - self.consistency_level = prepared_statement.consistency_level self.prepared_statement = prepared_statement + + self.retry_policy = prepared_statement.retry_policy + self.consistency_level = prepared_statement.consistency_level + self.serial_consistency_level = prepared_statement.serial_consistency_level + self.fetch_size = prepared_statement.fetch_size + self.custom_payload = prepared_statement.custom_payload + self.is_idempotent = prepared_statement.is_idempotent self.values = [] - Statement.__init__(self, *args, **kwargs) + meta = prepared_statement.column_metadata + if meta: + self.keyspace = meta[0].keyspace_name + + Statement.__init__(self, retry_policy, consistency_level, routing_key, + serial_consistency_level, fetch_size, keyspace, custom_payload, + prepared_statement.is_idempotent) def bind(self, values): """ Binds a sequence of values for the prepared statement parameters - and returns this instance. Note that `values` *must* be a - sequence, even if you are only binding one value. + and returns this instance. Note that `values` *must* be: + + * a sequence, even if you are only binding one value, or + * a dict that relates 1-to-1 between dict keys and columns + + .. versionchanged:: 2.6.0 + + :data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters + in a sequence, or by name in a dict. Additionally, when using protocol v4+: + + * short sequences will be extended to match bind parameters with UNSET_VALUE + * names may be omitted from a dict with UNSET_VALUE implied. + + .. versionchanged:: 3.0.0 + + method will not throw if extra keys are present in bound dict (PYTHON-178) """ + if values is None: + values = () + proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata - if len(values) > len(col_meta): + ce_policy = self.prepared_statement.column_encryption_policy + + # special case for binding dicts + if isinstance(values, dict): + values_dict = values + values = [] + + # sort values accordingly + for col in col_meta: + try: + values.append(values_dict[col.name]) + except KeyError: + if proto_version >= 4: + values.append(UNSET_VALUE) + else: + raise KeyError( + 'Column name `%s` not found in bound dict.' % + (col.name)) + + value_len = len(values) + col_meta_len = len(col_meta) + + if value_len > col_meta_len: raise ValueError( "Too many arguments provided to bind() (got %d, expected %d)" % (len(values), len(col_meta))) + # this is fail-fast for clarity pre-v4. When v4 can be assumed, + # the error will be better reported when UNSET_VALUE is implicitly added. + if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ + value_len < len(self.prepared_statement.routing_key_indexes): + raise ValueError( + "Too few arguments provided to bind() (got %d, required %d for routing key)" % + (value_len, len(self.prepared_statement.routing_key_indexes))) + + self.raw_values = values self.values = [] for value, col_spec in zip(values, col_meta): if value is None: self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: - col_type = col_spec[-1] - try: - self.values.append(col_type.serialize(value)) - except struct.error: - col_name = col_spec[2] - expected_type = col_type + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy and ce_policy.contains_column(col_desc) + col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_bytes = col_type.serialize(value, proto_version) + if uses_ce: + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) - err = InvalidParameterTypeError(col_name=col_name, - expected_type=expected_type, - actual_type=actual_type) - raise err + if proto_version >= 4: + diff = col_meta_len - len(self.values) + if diff: + for _ in range(diff): + self._append_unset_value() return self + def _append_unset_value(self): + next_index = len(self.values) + if self.prepared_statement.is_routing_key_index(next_index): + col_meta = self.prepared_statement.column_metadata[next_index] + raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) + self.values.append(UNSET_VALUE) + @property def routing_key(self): if not self.prepared_statement.routing_key_indexes: @@ -246,78 +670,248 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - components = [] - for statement_index in routing_indexes: - val = self.values[statement_index] - components.append(struct.pack("HsB", len(val), val, 0)) - - self._routing_key = "".join(components) + self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) return self._routing_key - @property - def keyspace(self): - meta = self.prepared_statement.column_metadata - if meta: - return meta[0][0] - else: - return None + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.prepared_statement.query_string, self.raw_values, consistency)) + __repr__ = __str__ -class ValueSequence(object): +class BatchType(object): """ - A wrapper class that is used to specify that a sequence of values should - be treated as a CQL list of values instead of a single column collection when used - as part of the `parameters` argument for :meth:`.Session.execute()`. + A BatchType is used with :class:`.BatchStatement` instances to control + the atomicity of the batch operation. - This is typically needed when supplying a list of keys to select. - For example:: + .. versionadded:: 2.0.0 + """ - >>> my_user_ids = ('alice', 'bob', 'charles') - >>> query = "SELECT * FROM users WHERE user_id IN ?" - >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) + LOGGED = None + """ + Atomic batch operation. + """ + + UNLOGGED = None + """ + Non-atomic batch operation. + """ + COUNTER = None + """ + Batches of counter operations. """ - def __init__(self, sequence): - self.sequence = sequence + def __init__(self, name, value): + self.name = name + self.value = value def __str__(self): - return cql_encode_sequence(self.sequence) + return self.name + + def __repr__(self): + return "BatchType.%s" % (self.name, ) -def bind_params(query, params): - if isinstance(params, dict): - return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) - for k, v in params.iteritems()) - else: - return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) - for v in params) +BatchType.LOGGED = BatchType("LOGGED", 0) +BatchType.UNLOGGED = BatchType("UNLOGGED", 1) +BatchType.COUNTER = BatchType("COUNTER", 2) -class TraceUnavailable(Exception): +class BatchStatement(Statement): """ - Raised when complete trace details cannot be fetched from Cassandra. + A protocol-level batch of operations which are applied atomically + by default. + + .. versionadded:: 2.0.0 """ - pass + batch_type = None + """ + The :class:`.BatchType` for the batch operation. Defaults to + :attr:`.BatchType.LOGGED`. + """ -class InvalidParameterTypeError(TypeError): + serial_consistency_level = None """ - Raised when a used tries to bind a prepared statement with an argument of an - invalid type. + The same as :attr:`.Statement.serial_consistency_level`, but is only + supported when using protocol version 3 or higher. """ - def __init__(self, col_name, expected_type, actual_type): - self.col_name = col_name - self.expected_type = expected_type - self.actual_type = actual_type + _statements_and_parameters = None + _session = None + + def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, + consistency_level=None, serial_consistency_level=None, + session=None, custom_payload=None): + """ + `batch_type` specifies The :class:`.BatchType` for the batch operation. + Defaults to :attr:`.BatchType.LOGGED`. + + `retry_policy` should be a :class:`~.RetryPolicy` instance for + controlling retries on the operation. + + `consistency_level` should be a :class:`~.ConsistencyLevel` value + to be used for all operations in the batch. + + `custom_payload` is a :ref:`custom_payload` passed to the server. + Note: as Statement objects are added to the batch, this map is + updated with any values found in their custom payloads. These are + only allowed when using protocol version 4 or higher. + + Example usage: + + .. code-block:: python + + insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)") + batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM) + + for (name, age) in users_to_insert: + batch.add(insert_user, (name, age)) + + session.execute(batch) + + You can also mix different types of operations within a batch: + + .. code-block:: python + + batch = BatchStatement() + batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age)) + batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,)) + session.execute(batch) + + .. versionadded:: 2.0.0 + + .. versionchanged:: 2.1.0 + Added `serial_consistency_level` as a parameter + + .. versionchanged:: 2.6.0 + Added `custom_payload` as a parameter + """ + self.batch_type = batch_type + self._statements_and_parameters = [] + self._session = session + Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + + def clear(self): + """ + This is a convenience method to clear a batch statement for reuse. + + *Note:* it should not be used concurrently with uncompleted execution futures executing the same + ``BatchStatement``. + """ + del self._statements_and_parameters[:] + self.keyspace = None + self.routing_key = None + if self.custom_payload: + self.custom_payload.clear() + + def add(self, statement, parameters=None): + """ + Adds a :class:`.Statement` and optional sequence of parameters + to be used with the statement to the batch. + + Like with other statements, parameters must be a sequence, even + if there is only one item. + """ + if isinstance(statement, str): + if parameters: + encoder = Encoder() if self._session is None else self._session.encoder + statement = bind_params(statement, parameters, encoder) + self._add_statement_and_params(False, statement, ()) + elif isinstance(statement, PreparedStatement): + query_id = statement.query_id + bound_statement = statement.bind(() if parameters is None else parameters) + self._update_state(bound_statement) + self._add_statement_and_params(True, query_id, bound_statement.values) + elif isinstance(statement, BoundStatement): + if parameters: + raise ValueError( + "Parameters cannot be passed with a BoundStatement " + "to BatchStatement.add()") + self._update_state(statement) + self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + else: + # it must be a SimpleStatement + query_string = statement.query_string + if parameters: + encoder = Encoder() if self._session is None else self._session.encoder + query_string = bind_params(query_string, parameters, encoder) + self._update_state(statement) + self._add_statement_and_params(False, query_string, ()) + return self + + def add_all(self, statements, parameters): + """ + Adds a sequence of :class:`.Statement` objects and a matching sequence + of parameters to the batch. Statement and parameter sequences must be of equal length or + one will be truncated. :const:`None` can be used in the parameters position where are needed. + """ + for statement, value in zip(statements, parameters): + self.add(statement, value) + + def _add_statement_and_params(self, is_prepared, statement, parameters): + if len(self._statements_and_parameters) >= 0xFFFF: + raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + self._statements_and_parameters.append((is_prepared, statement, parameters)) + + def _maybe_set_routing_attributes(self, statement): + if self.routing_key is None: + if statement.keyspace and statement.routing_key: + self.routing_key = statement.routing_key + self.keyspace = statement.keyspace + + def _update_custom_payload(self, statement): + if statement.custom_payload: + if self.custom_payload is None: + self.custom_payload = {} + self.custom_payload.update(statement.custom_payload) + + def _update_state(self, statement): + self._maybe_set_routing_attributes(statement) + self._update_custom_payload(statement) + + def __len__(self): + return len(self._statements_and_parameters) + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.batch_type, len(self), consistency)) + __repr__ = __str__ + + +ValueSequence = cassandra.encoder.ValueSequence +""" +A wrapper class that is used to specify that a sequence of values should +be treated as a CQL list of values instead of a single column collection when used +as part of the `parameters` argument for :meth:`.Session.execute()`. + +This is typically needed when supplying a list of keys to select. +For example:: - values = (self.col_name, self.expected_type, self.actual_type) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s' % values) + >>> my_user_ids = ('alice', 'bob', 'charles') + >>> query = "SELECT * FROM users WHERE user_id IN %s" + >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) + +""" - super(InvalidParameterTypeError, self).__init__(message) + +def bind_params(query, params, encoder): + if isinstance(params, dict): + return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + else: + return query % tuple(encoder.cql_encode_all_types(v) for v in params) + + +class TraceUnavailable(Exception): + """ + Raised when complete trace details cannot be fetched from Cassandra. + """ + pass class QueryTrace(object): @@ -343,6 +937,13 @@ class QueryTrace(object): A :class:`datetime.timedelta` measure of the duration of the query. """ + client = None + """ + The IP address of the client that issued this request + + This is only available when using Cassandra 2.2+ + """ + coordinator = None """ The IP address of the host that acted as coordinator for this request. @@ -373,39 +974,78 @@ class QueryTrace(object): _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 - _MAX_ATTEMPTS = 5 def __init__(self, trace_id, session): self.trace_id = trace_id self._session = session - def populate(self): + def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored asynchronously by Cassandra, this may need to retry the session - detail fetch up to five times before raising :exc:`.TraceUnavailable`. + detail fetch. If the trace is still not available after `max_wait` + seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is + :const:`None`, this will retry forever. + + `wait_for_complete=False` bypasses the wait for duration to be populated. + This can be used to query events from partial sessions. - Currently intended for internal use only. + `query_cl` specifies a consistency level to use for polling the trace tables, + if different from the session default. """ attempt = 0 - while attempt <= self._MAX_ATTEMPTS: - attempt += 1 - session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,)) - if not session_results or session_results[0].duration is None: - time.sleep(self._BASE_RETRY_SLEEP * attempt) + start = time.time() + while True: + time_spent = time.time() - start + if max_wait is not None and time_spent >= max_wait: + raise TraceUnavailable( + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + + log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) + session_results = self._execute( + SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + + # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries + session_row = session_results.one() if session_results else None + is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + if not session_results or (wait_for_complete and not is_complete): + time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + attempt += 1 continue + if is_complete: + log.debug("Fetched trace info for trace ID: %s", self.trace_id) + else: + log.debug("Fetching partial trace info for trace ID: %s", self.trace_id) - session_row = session_results[0] self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) + self.duration = timedelta(microseconds=session_row.duration) if is_complete else None self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters - - event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,)) + # since C* 2.2 + self.client = getattr(session_row, 'client', None) + + log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + time_spent = time.time() - start + event_results = self._execute( + SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) + break + + def _execute(self, query, parameters, time_spent, max_wait): + timeout = (max_wait - time_spent) if max_wait is not None else None + future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + # in case the user switched the row factory, set it to namedtuple for this query + future.row_factory = named_tuple_factory + future.send_request() + + try: + return future.result() + except OperationTimedOut: + raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ @@ -447,10 +1087,27 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) + self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) self.source = source - self.source_elapsed = timedelta(microseconds=source_elapsed) + if source_elapsed is not None: + self.source_elapsed = timedelta(microseconds=source_elapsed) + else: + self.source_elapsed = None self.thread_name = thread_name def __str__(self): return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + + +# TODO remove next major since we can target using the `host` attribute of session.execute +class HostTargetingStatement(object): + """ + Wraps any query statement and attaches a target host, making + it usable in a targeted LBP without modifying the user's statement. + """ + def __init__(self, inner_statement, target_host): + self.__class__ = type(inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx new file mode 100644 index 0000000000..d172f1bcaf --- /dev/null +++ b/cassandra/row_parser.pyx @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.parsing cimport ParseDesc, ColumnParser +from cassandra.policies import ColDesc +from cassandra.obj_parser import TupleRowParser +from cassandra.deserializers import make_deserializers + +include "ioutils.pyx" + +def make_recv_results_rows(ColumnParser colparser): + def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata, column_encryption_policy): + """ + Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples) + This is used as the recv_results_rows method of (Fast)ResultMessage + """ + self.recv_results_metadata(f, user_type_map) + + column_metadata = self.column_metadata or result_metadata + + self.column_names = [md[2] for md in column_metadata] + self.column_types = [md[3] for md in column_metadata] + + desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(self.column_types), protocol_version) + reader = BytesIOReader(f.read()) + try: + self.parsed_rows = colparser.parse_rows(reader, desc) + except Exception as e: + # Use explicitly the TupleRowParser to display better error messages for column decoding failures + rowparser = TupleRowParser() + reader.buf_ptr = reader.buf + reader.pos = 0 + rowcount = read_int(reader) + for i in range(rowcount): + rowparser.unpack_row(reader, desc) + + return recv_results_rows diff --git a/cassandra/segment.py b/cassandra/segment.py new file mode 100644 index 0000000000..2d7a107566 --- /dev/null +++ b/cassandra/segment.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import zlib + +from cassandra import DriverException +from cassandra.marshal import int32_pack +from cassandra.protocol import write_uint_le, read_uint_le + +CRC24_INIT = 0x875060 +CRC24_POLY = 0x1974F0B +CRC24_LENGTH = 3 +CRC32_LENGTH = 4 +CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca") + + +class CrcException(Exception): + """ + CRC mismatch error. + + TODO: here to avoid import cycles with cassandra.connection. In the next + major, the exceptions should be declared in a separated exceptions.py + file. + """ + pass + + +def compute_crc24(data, length): + crc = CRC24_INIT + + for _ in range(length): + crc ^= (data & 0xff) << 16 + data >>= 8 + + for i in range(8): + crc <<= 1 + if crc & 0x1000000 != 0: + crc ^= CRC24_POLY + + return crc + + +def compute_crc32(data, value): + crc32 = zlib.crc32(data, value) + return crc32 + + +class SegmentHeader(object): + + payload_length = None + uncompressed_payload_length = None + is_self_contained = None + + def __init__(self, payload_length, uncompressed_payload_length, is_self_contained): + self.payload_length = payload_length + self.uncompressed_payload_length = uncompressed_payload_length + self.is_self_contained = is_self_contained + + @property + def segment_length(self): + """ + Return the total length of the segment, including the CRC. + """ + hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \ + else SegmentCodec.COMPRESSED_HEADER_LENGTH + return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH + + +class Segment(object): + + MAX_PAYLOAD_LENGTH = 128 * 1024 - 1 + + payload = None + is_self_contained = None + + def __init__(self, payload, is_self_contained): + self.payload = payload + self.is_self_contained = is_self_contained + + +class SegmentCodec(object): + + COMPRESSED_HEADER_LENGTH = 5 + UNCOMPRESSED_HEADER_LENGTH = 3 + FLAG_OFFSET = 17 + + compressor = None + decompressor = None + + def __init__(self, compressor=None, decompressor=None): + self.compressor = compressor + self.decompressor = decompressor + + @property + def header_length(self): + return self.COMPRESSED_HEADER_LENGTH if self.compression \ + else self.UNCOMPRESSED_HEADER_LENGTH + + @property + def header_length_with_crc(self): + return (self.COMPRESSED_HEADER_LENGTH if self.compression + else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH + + @property + def compression(self): + return self.compressor and self.decompressor + + def compress(self, data): + # the uncompressed length is already encoded in the header, so + # we remove it here + return self.compressor(data)[4:] + + def decompress(self, encoded_data, uncompressed_length): + return self.decompressor(int32_pack(uncompressed_length) + encoded_data) + + def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained): + if payload_length > Segment.MAX_PAYLOAD_LENGTH: + raise DriverException('Payload length exceed Segment.MAX_PAYLOAD_LENGTH') + + header_data = payload_length + + flag_offset = self.FLAG_OFFSET + if self.compression: + header_data |= uncompressed_length << flag_offset + flag_offset += 17 + + if is_self_contained: + header_data |= 1 << flag_offset + + write_uint_le(buffer, header_data, size=self.header_length) + header_crc = compute_crc24(header_data, self.header_length) + write_uint_le(buffer, header_crc, size=CRC24_LENGTH) + + def _encode_segment(self, buffer, payload, is_self_contained): + """ + Encode a message to a single segment. + """ + uncompressed_payload = payload + uncompressed_payload_length = len(payload) + + if self.compression: + compressed_payload = self.compress(uncompressed_payload) + if len(compressed_payload) >= uncompressed_payload_length: + encoded_payload = uncompressed_payload + uncompressed_payload_length = 0 + else: + encoded_payload = compressed_payload + else: + encoded_payload = uncompressed_payload + + payload_length = len(encoded_payload) + self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained) + payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL) + buffer.write(encoded_payload) + write_uint_le(buffer, payload_crc) + + def encode(self, buffer, msg): + """ + Encode a message to one of more segments. + """ + msg_length = len(msg) + + if msg_length > Segment.MAX_PAYLOAD_LENGTH: + payloads = [] + for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH): + payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH]) + else: + payloads = [msg] + + is_self_contained = len(payloads) == 1 + for payload in payloads: + self._encode_segment(buffer, payload, is_self_contained) + + def decode_header(self, buffer): + header_data = read_uint_le(buffer, self.header_length) + + expected_header_crc = read_uint_le(buffer, CRC24_LENGTH) + actual_header_crc = compute_crc24(header_data, self.header_length) + if actual_header_crc != expected_header_crc: + raise CrcException('CRC mismatch on header {:x}. Received {:x}", computed {:x}.'.format( + header_data, expected_header_crc, actual_header_crc)) + + payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH + header_data >>= 17 + + if self.compression: + uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH + header_data >>= 17 + else: + uncompressed_payload_length = -1 + + is_self_contained = (header_data & 1) == 1 + + return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained) + + def decode(self, buffer, header): + encoded_payload = buffer.read(header.payload_length) + expected_payload_crc = read_uint_le(buffer) + + actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL) + if actual_payload_crc != expected_payload_crc: + raise CrcException('CRC mismatch on payload. Received {:x}", computed {:x}.'.format( + expected_payload_crc, actual_payload_crc)) + + payload = encoded_payload + if self.compression and header.uncompressed_payload_length > 0: + payload = self.decompress(encoded_payload, header.uncompressed_payload_length) + + return Segment(payload, header.is_self_contained) diff --git a/cassandra/timestamps.py b/cassandra/timestamps.py new file mode 100644 index 0000000000..e2a2c1ea4c --- /dev/null +++ b/cassandra/timestamps.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains utilities for generating timestamps for client-side +timestamp specification. +""" + +import logging +import time +from threading import Lock + +log = logging.getLogger(__name__) + +class MonotonicTimestampGenerator(object): + """ + An object that, when called, returns ``int(time.time() * 1e6)`` when + possible, but, if the value returned by ``time.time`` doesn't increase, + drifts into the future and logs warnings. + Exposed configuration attributes can be configured with arguments to + ``__init__`` or by changing attributes on an initialized object. + + .. versionadded:: 3.8.0 + """ + + warn_on_drift = True + """ + If true, log warnings when timestamps drift into the future as allowed by + :attr:`warning_threshold` and :attr:`warning_interval`. + """ + + warning_threshold = 1 + """ + This object will only issue warnings when the returned timestamp drifts + more than ``warning_threshold`` seconds into the future. + Defaults to 1 second. + """ + + warning_interval = 1 + """ + This object will only issue warnings every ``warning_interval`` seconds. + Defaults to 1 second. + """ + + def __init__(self, warn_on_drift=True, warning_threshold=1, warning_interval=1): + self.lock = Lock() + with self.lock: + self.last = 0 + self._last_warn = 0 + self.warn_on_drift = warn_on_drift + self.warning_threshold = warning_threshold + self.warning_interval = warning_interval + + def _next_timestamp(self, now, last): + """ + Returns the timestamp that should be used if ``now`` is the current + time and ``last`` is the last timestamp returned by this object. + Intended for internal and testing use only; to generate timestamps, + call an instantiated ``MonotonicTimestampGenerator`` object. + + :param int now: an integer to be used as the current time, typically + representing the current time in microseconds since the UNIX epoch + :param int last: an integer representing the last timestamp returned by + this object + """ + if now > last: + self.last = now + return now + else: + self._maybe_warn(now=now) + self.last = last + 1 + return self.last + + def __call__(self): + """ + Makes ``MonotonicTimestampGenerator`` objects callable; defers + internally to _next_timestamp. + """ + with self.lock: + return self._next_timestamp(now=int(time.time() * 1e6), + last=self.last) + + def _maybe_warn(self, now): + # should be called from inside the self.lock. + diff = self.last - now + since_last_warn = now - self._last_warn + + warn = (self.warn_on_drift and + (diff >= self.warning_threshold * 1e6) and + (since_last_warn >= self.warning_interval * 1e6)) + if warn: + log.warning( + "Clock skew detected: current tick ({now}) was {diff} " + "microseconds behind the last generated timestamp " + "({last}), returned timestamps will be artificially " + "incremented to guarantee monotonicity.".format( + now=now, diff=diff, last=self.last)) + self._last_warn = now diff --git a/cassandra/tuple.pxd b/cassandra/tuple.pxd new file mode 100644 index 0000000000..b519e177bb --- /dev/null +++ b/cassandra/tuple.pxd @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cpython.tuple cimport ( + PyTuple_New, + # Return value: New reference. + # Return a new tuple object of size len, or NULL on failure. + PyTuple_SET_ITEM, + # Like PyTuple_SetItem(), but does no error checking, and should + # only be used to fill in brand new tuples. Note: This function + # ``steals'' a reference to o. + ) + +from cpython.ref cimport ( + Py_INCREF + # void Py_INCREF(object o) + # Increment the reference count for object o. The object must not + # be NULL; if you aren't sure that it isn't NULL, use + # Py_XINCREF(). + ) + +cdef inline tuple tuple_new(Py_ssize_t n): + """Allocate a new tuple object""" + return PyTuple_New(n) + +cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item): + """Insert new object into tuple. No item must have been set yet.""" + # PyTuple_SET_ITEM steals a reference, so we need to INCREF + Py_INCREF(item) + PyTuple_SET_ITEM(tup, idx, item) diff --git a/cassandra/type_codes.pxd b/cassandra/type_codes.pxd new file mode 100644 index 0000000000..336263b83c --- /dev/null +++ b/cassandra/type_codes.pxd @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef enum: + CUSTOM_TYPE + AsciiType + LongType + BytesType + BooleanType + CounterColumnType + DecimalType + DoubleType + FloatType + Int32Type + UTF8Type + DateType + UUIDType + VarcharType + IntegerType + TimeUUIDType + InetAddressType + SimpleDateType + TimeType + ShortType + ByteType + ListType + MapType + SetType + UserType + TupleType + diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py new file mode 100644 index 0000000000..eab9a3344a --- /dev/null +++ b/cassandra/type_codes.py @@ -0,0 +1,67 @@ +""" +Module with constants for Cassandra type codes. + +These constants are useful for + + a) mapping messages to cqltypes (cassandra/cqltypes.py) + b) optimized dispatching for (de)serialization (cassandra/encoding.py) + +Type codes are repeated here from the Cassandra binary protocol specification: + + 0x0000 Custom: the value is a [string], see above. + 0x0001 Ascii + 0x0002 Bigint + 0x0003 Blob + 0x0004 Boolean + 0x0005 Counter + 0x0006 Decimal + 0x0007 Double + 0x0008 Float + 0x0009 Int + 0x000A Text + 0x000B Timestamp + 0x000C Uuid + 0x000D Varchar + 0x000E Varint + 0x000F Timeuuid + 0x0010 Inet + 0x0011 SimpleDateType + 0x0012 TimeType + 0x0013 ShortType + 0x0014 ByteType + 0x0015 DurationType + 0x0020 List: the value is an [option], representing the type + of the elements of the list. + 0x0021 Map: the value is two [option], representing the types of the + keys and values of the map + 0x0022 Set: the value is an [option], representing the type + of the elements of the set +""" + +CUSTOM_TYPE = 0x0000 +AsciiType = 0x0001 +LongType = 0x0002 +BytesType = 0x0003 +BooleanType = 0x0004 +CounterColumnType = 0x0005 +DecimalType = 0x0006 +DoubleType = 0x0007 +FloatType = 0x0008 +Int32Type = 0x0009 +UTF8Type = 0x000A +DateType = 0x000B +UUIDType = 0x000C +VarcharType = 0x000D +IntegerType = 0x000E +TimeUUIDType = 0x000F +InetAddressType = 0x0010 +SimpleDateType = 0x0011 +TimeType = 0x0012 +ShortType = 0x0013 +ByteType = 0x0014 +DurationType = 0x0015 +ListType = 0x0020 +MapType = 0x0021 +SetType = 0x0022 +UserType = 0x0030 +TupleType = 0x0031 diff --git a/cassandra/util.py b/cassandra/util.py index bbdb85087f..408211ed05 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1,142 +1,222 @@ -from __future__ import with_statement - -# OrderedDict from Python 2.7+ - -# Copyright (c) 2009 Raymond Hettinger +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Permission is hereby granted, free of charge, to any person -# obtaining a copy of this software and associated documentation files -# (the "Software"), to deal in the Software without restriction, -# including without limitation the rights to use, copy, modify, merge, -# publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, -# subject to the following conditions: +# http://www.apache.org/licenses/LICENSE-2.0 # -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -# OTHER DEALINGS IN THE SOFTWARE. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from UserDict import DictMixin +from _weakref import ref +import calendar +from collections import OrderedDict +from collections.abc import Mapping +import datetime +from functools import total_ordering +from itertools import chain +import keyword +import logging +import pickle +import random +import re +import socket +import sys +import time +import uuid -class OrderedDict(dict, DictMixin): - """ A dictionary which maintains the insertion order of keys. """ +_HAS_GEOMET = True +try: + from geomet import wkt +except: + _HAS_GEOMET = False - def __init__(self, *args, **kwds): - """ A dictionary which maintains the insertion order of keys. """ - if len(args) > 1: - raise TypeError('expected at most 1 arguments, got %d' % len(args)) - try: - self.__end - except AttributeError: - self.clear() - self.update(*args, **kwds) +from cassandra import DriverException - def clear(self): - self.__end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.__map = {} # key --> [key, prev, next] - dict.clear(self) - - def __setitem__(self, key, value): - if key not in self: - end = self.__end - curr = end[1] - curr[2] = end[1] = self.__map[key] = [key, curr, end] - dict.__setitem__(self, key, value) +DATETIME_EPOC = datetime.datetime(1970, 1, 1).replace(tzinfo=None) +UTC_DATETIME_EPOC = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - def __delitem__(self, key): - dict.__delitem__(self, key) - key, prev, next = self.__map.pop(key) - prev[2] = next - next[1] = prev +_nan = float('nan') - def __iter__(self): - end = self.__end - curr = end[2] - while curr is not end: - yield curr[0] - curr = curr[2] +log = logging.getLogger(__name__) - def __reversed__(self): - end = self.__end - curr = end[1] - while curr is not end: - yield curr[0] - curr = curr[1] - - def popitem(self, last=True): - if not self: - raise KeyError('dictionary is empty') - if last: - key = reversed(self).next() - else: - key = iter(self).next() - value = self.pop(key) - return key, value +assert sys.byteorder in ('little', 'big') +is_little_endian = sys.byteorder == 'little' - def __reduce__(self): - items = [[k, self[k]] for k in self] - tmp = self.__map, self.__end - del self.__map, self.__end - inst_dict = vars(self).copy() - self.__map, self.__end = tmp - if inst_dict: - return (self.__class__, (items,), inst_dict) - return self.__class__, (items,) - - def keys(self): - return list(self) - - setdefault = DictMixin.setdefault - update = DictMixin.update - pop = DictMixin.pop - values = DictMixin.values - items = DictMixin.items - iterkeys = DictMixin.iterkeys - itervalues = DictMixin.itervalues - iteritems = DictMixin.iteritems - def __repr__(self): - if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, self.items()) +def datetime_from_timestamp(timestamp): + """ + Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner. + Works around a Windows issue with large negative timestamps (PYTHON-119), + and rounding differences in Python 3.4 (PYTHON-340). - def copy(self): - return self.__class__(self) + :param timestamp: a unix timestamp, in seconds + """ + dt = DATETIME_EPOC + datetime.timedelta(seconds=timestamp) + return dt - @classmethod - def fromkeys(cls, iterable, value=None): - d = cls() - for key in iterable: - d[key] = value - return d - def __eq__(self, other): - if isinstance(other, OrderedDict): - if len(self) != len(other): - return False - for p, q in zip(self.items(), other.items()): - if p != q: - return False - return True - return dict.__eq__(self, other) +def utc_datetime_from_ms_timestamp(timestamp): + """ + Creates a UTC datetime from a timestamp in milliseconds. See + :meth:`datetime_from_timestamp`. - def __ne__(self, other): - return not self == other + Raises an `OverflowError` if the timestamp is out of range for + :class:`~datetime.datetime`. + :param timestamp: timestamp, in milliseconds + """ + return UTC_DATETIME_EPOC + datetime.timedelta(milliseconds=timestamp) -# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset) -from _weakref import ref +def ms_timestamp_from_datetime(dt): + """ + Converts a datetime to a timestamp expressed in milliseconds. + + :param dt: a :class:`datetime.datetime` + """ + return int(round((dt - UTC_DATETIME_EPOC).total_seconds() * 1000)) + + +def unix_time_from_uuid1(uuid_arg): + """ + Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision + as :meth:`time.time()` returns. This is useful for examining the + results of queries returning a v1 :class:`~uuid.UUID`. + + :param uuid_arg: a version 1 :class:`~uuid.UUID` + """ + return (uuid_arg.time - 0x01B21DD213814000) / 1e7 + + +def datetime_from_uuid1(uuid_arg): + """ + Creates a timezone-agnostic datetime from the timestamp in the + specified type-1 UUID. + + :param uuid_arg: a version 1 :class:`~uuid.UUID` + """ + return datetime_from_timestamp(unix_time_from_uuid1(uuid_arg)) + + +def min_uuid_from_time(timestamp): + """ + Generates the minimum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. + + See :func:`uuid_from_time` for argument and return types. + """ + return uuid_from_time(timestamp, 0x808080808080, 0x80) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128) + + +def max_uuid_from_time(timestamp): + """ + Generates the maximum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. + + See :func:`uuid_from_time` for argument and return types. + """ + return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f) # Max signed bytes (0x7f = 127) + + +def uuid_from_time(time_arg, node=None, clock_seq=None): + """ + Converts a datetime or timestamp to a type 1 :class:`uuid.UUID`. + + :param time_arg: + The time to use for the timestamp portion of the UUID. + This can either be a :class:`datetime` object or a timestamp + in seconds (as returned from :meth:`time.time()`). + :type datetime: :class:`datetime` or timestamp + + :param node: + None integer for the UUID (up to 48 bits). If not specified, this + field is randomized. + :type node: long + + :param clock_seq: + Clock sequence field for the UUID (up to 14 bits). If not specified, + a random sequence is generated. + :type clock_seq: int + + :rtype: :class:`uuid.UUID` + + """ + if hasattr(time_arg, 'utctimetuple'): + seconds = int(calendar.timegm(time_arg.utctimetuple())) + microseconds = (seconds * 1e6) + time_arg.time().microsecond + else: + microseconds = int(time_arg * 1e6) + + # 0x01b21dd213814000 is the number of 100-ns intervals between the + # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. + intervals = int(microseconds * 10) + 0x01b21dd213814000 + + time_low = intervals & 0xffffffff + time_mid = (intervals >> 32) & 0xffff + time_hi_version = (intervals >> 48) & 0x0fff + + if clock_seq is None: + clock_seq = random.getrandbits(14) + else: + if clock_seq > 0x3fff: + raise ValueError('clock_seq is out of range (need a 14-bit value)') + + clock_seq_low = clock_seq & 0xff + clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f) + + if node is None: + node = random.getrandbits(48) + + return uuid.UUID(fields=(time_low, time_mid, time_hi_version, + clock_seq_hi_variant, clock_seq_low, node), version=1) + +LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080') +""" The lowest possible TimeUUID, as sorted by Cassandra. """ + +HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f') +""" The highest possible TimeUUID, as sorted by Cassandra. """ + + +def _addrinfo_or_none(contact_point, port): + """ + A helper function that wraps socket.getaddrinfo and returns None + when it fails to, e.g. resolve one of the hostnames. Used to address + PYTHON-895. + """ + try: + value = socket.getaddrinfo(contact_point, port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + return value + except socket.gaierror: + log.debug('Could not resolve hostname "{}" ' + 'with port {}'.format(contact_point, port)) + return None + + +def _addrinfo_to_ip_strings(addrinfo): + """ + Helper function that consumes the data output by socket.getaddrinfo and + extracts the IP address from the sockaddr portion of the result. + + Since this is meant to be used in conjunction with _addrinfo_or_none, + this will pass None and EndPoint instances through unaffected. + """ + if addrinfo is None: + return None + return [(entry[4][0], entry[4][1]) for entry in addrinfo] + + +def _resolve_contact_points_to_string_map(contact_points): + return OrderedDict( + ('{cp}:{port}'.format(cp=cp, port=port), _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port))) + for cp, port in contact_points + ) class _IterationGuard(object): @@ -167,6 +247,7 @@ def __exit__(self, e, t, b): class WeakSet(object): def __init__(self, data=None): self.data = set() + def _remove(item, selfref=ref(self)): self = selfref() if self is not None: @@ -174,6 +255,7 @@ def _remove(item, selfref=ref(self)): self._pending_removals.append(item) else: self.data.discard(item) + self._remove = _remove # A list of keys to be removed self._pending_removals = [] @@ -274,6 +356,7 @@ def difference_update(self, other): self.data.clear() else: self.data.difference_update(ref(item) for item in other) + def __isub__(self, other): if self._pending_removals: self._commit_removals() @@ -291,6 +374,7 @@ def intersection_update(self, other): if self._pending_removals: self._commit_removals() self.data.intersection_update(ref(item) for item in other) + def __iand__(self, other): if self._pending_removals: self._commit_removals() @@ -327,6 +411,7 @@ def symmetric_difference_update(self, other): self.data.clear() else: self.data.symmetric_difference_update(ref(item) for item in other) + def __ixor__(self, other): if self._pending_removals: self._commit_removals() @@ -341,4 +426,1357 @@ def union(self, other): __or__ = union def isdisjoint(self, other): - return len(self.intersection(other)) == 0 \ No newline at end of file + return len(self.intersection(other)) == 0 + + +class SortedSet(object): + ''' + A sorted set based on sorted list + + A sorted set implementation is used in this case because it does not + require its elements to be immutable/hashable. + + #Not implemented: update functions, inplace operators + ''' + + def __init__(self, iterable=()): + self._items = [] + self.update(iterable) + + def __len__(self): + return len(self._items) + + def __getitem__(self, i): + return self._items[i] + + def __iter__(self): + return iter(self._items) + + def __reversed__(self): + return reversed(self._items) + + def __repr__(self): + return '%s(%r)' % ( + self.__class__.__name__, + self._items) + + def __reduce__(self): + return self.__class__, (self._items,) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._items == other._items + else: + try: + return len(other) == len(self._items) and all(item in self for item in other) + except TypeError: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, self.__class__): + return self._items != other._items + else: + try: + return len(other) != len(self._items) or any(item not in self for item in other) + except TypeError: + return NotImplemented + + def __le__(self, other): + return self.issubset(other) + + def __lt__(self, other): + return len(other) > len(self._items) and self.issubset(other) + + def __ge__(self, other): + return self.issuperset(other) + + def __gt__(self, other): + return len(self._items) > len(other) and self.issuperset(other) + + def __and__(self, other): + return self._intersect(other) + __rand__ = __and__ + + def __iand__(self, other): + isect = self._intersect(other) + self._items = isect._items + return self + + def __or__(self, other): + return self.union(other) + __ror__ = __or__ + + def __ior__(self, other): + union = self.union(other) + self._items = union._items + return self + + def __sub__(self, other): + return self._diff(other) + + def __rsub__(self, other): + return sortedset(other) - self + + def __isub__(self, other): + diff = self._diff(other) + self._items = diff._items + return self + + def __xor__(self, other): + return self.symmetric_difference(other) + __rxor__ = __xor__ + + def __ixor__(self, other): + sym_diff = self.symmetric_difference(other) + self._items = sym_diff._items + return self + + def __contains__(self, item): + i = self._find_insertion(item) + return i < len(self._items) and self._items[i] == item + + def __delitem__(self, i): + del self._items[i] + + def __delslice__(self, i, j): + del self._items[i:j] + + def add(self, item): + i = self._find_insertion(item) + if i < len(self._items): + if self._items[i] != item: + self._items.insert(i, item) + else: + self._items.append(item) + + def update(self, iterable): + for i in iterable: + self.add(i) + + def clear(self): + del self._items[:] + + def copy(self): + new = sortedset() + new._items = list(self._items) + return new + + def isdisjoint(self, other): + return len(self._intersect(other)) == 0 + + def issubset(self, other): + return len(self._intersect(other)) == len(self._items) + + def issuperset(self, other): + return len(self._intersect(other)) == len(other) + + def pop(self): + if not self._items: + raise KeyError("pop from empty set") + return self._items.pop() + + def remove(self, item): + i = self._find_insertion(item) + if i < len(self._items): + if self._items[i] == item: + self._items.pop(i) + return + raise KeyError('%r' % item) + + def union(self, *others): + union = sortedset() + union._items = list(self._items) + for other in others: + for item in other: + union.add(item) + return union + + def intersection(self, *others): + isect = self.copy() + for other in others: + isect = isect._intersect(other) + if not isect: + break + return isect + + def difference(self, *others): + diff = self.copy() + for other in others: + diff = diff._diff(other) + if not diff: + break + return diff + + def symmetric_difference(self, other): + diff_self_other = self._diff(other) + diff_other_self = other.difference(self) + return diff_self_other.union(diff_other_self) + + def _diff(self, other): + diff = sortedset() + for item in self._items: + if item not in other: + diff.add(item) + return diff + + def _intersect(self, other): + isect = sortedset() + for item in self._items: + if item in other: + isect.add(item) + return isect + + def _find_insertion(self, x): + # this uses bisect_left algorithm unless it has elements it can't compare, + # in which case it defaults to grouping non-comparable items at the beginning or end, + # and scanning sequentially to find an insertion point + a = self._items + lo = 0 + hi = len(a) + try: + while lo < hi: + mid = (lo + hi) // 2 + if a[mid] < x: lo = mid + 1 + else: hi = mid + except TypeError: + # could not compare a[mid] with x + # start scanning to find insertion point while swallowing type errors + lo = 0 + compared_one = False # flag is used to determine whether un-comparables are grouped at the front or back + while lo < hi: + try: + if a[lo] == x or a[lo] >= x: break + compared_one = True + except TypeError: + if compared_one: break + lo += 1 + return lo + +sortedset = SortedSet # backwards-compatibility + + +class OrderedMap(Mapping): + ''' + An ordered map that accepts non-hashable types for keys. It also maintains the + insertion order of items, behaving as OrderedDict in that regard. These maps + are constructed and read just as normal mapping types, except that they may + contain arbitrary collections and other non-hashable items as keys:: + + >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'), + ... ({'three': 3, 'four': 4}, 'value2')]) + >>> list(od.keys()) + [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}] + >>> list(od.values()) + ['value', 'value2'] + + These constructs are needed to support nested collections in Cassandra 2.1.3+, + where frozen collections can be specified as parameters to others:: + + CREATE TABLE example ( + ... + value map>, double> + ... + ) + + This class derives from the (immutable) Mapping API. Objects in these maps + are not intended be modified. + ''' + + def __init__(self, *args, **kwargs): + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + + self._items = [] + self._index = {} + if args: + e = args[0] + if callable(getattr(e, 'keys', None)): + for k in e.keys(): + self._insert(k, e[k]) + else: + for k, v in e: + self._insert(k, v) + + for k, v in kwargs.items(): + self._insert(k, v) + + def _insert(self, key, value): + flat_key = self._serialize_key(key) + i = self._index.get(flat_key, -1) + if i >= 0: + self._items[i] = (key, value) + else: + self._items.append((key, value)) + self._index[flat_key] = len(self._items) - 1 + + __setitem__ = _insert + + def __getitem__(self, key): + try: + index = self._index[self._serialize_key(key)] + return self._items[index][1] + except KeyError: + raise KeyError(str(key)) + + def __delitem__(self, key): + # not efficient -- for convenience only + try: + index = self._index.pop(self._serialize_key(key)) + self._index = dict((k, i if i < index else i - 1) for k, i in self._index.items()) + self._items.pop(index) + except KeyError: + raise KeyError(str(key)) + + def __iter__(self): + for i in self._items: + yield i[0] + + def __len__(self): + return len(self._items) + + def __eq__(self, other): + if isinstance(other, OrderedMap): + return self._items == other._items + try: + d = dict(other) + return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items) + except KeyError: + return False + except TypeError: + pass + return NotImplemented + + def __repr__(self): + return '%s([%s])' % ( + self.__class__.__name__, + ', '.join("(%r, %r)" % (k, v) for k, v in self._items)) + + def __str__(self): + return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items) + + def popitem(self): + try: + kv = self._items.pop() + del self._index[self._serialize_key(kv[0])] + return kv + except IndexError: + raise KeyError() + + def _serialize_key(self, key): + return pickle.dumps(key) + + +class OrderedMapSerializedKey(OrderedMap): + + def __init__(self, cass_type, protocol_version): + super(OrderedMapSerializedKey, self).__init__() + self.cass_key_type = cass_type + self.protocol_version = protocol_version + + def _insert_unchecked(self, key, flat_key, value): + self._items.append((key, value)) + self._index[flat_key] = len(self._items) - 1 + + def _serialize_key(self, key): + return self.cass_key_type.serialize(key, self.protocol_version) + + +@total_ordering +class Time(object): + ''' + Idealized time, independent of day. + + Up to nanosecond resolution + ''' + + MICRO = 1000 + MILLI = 1000 * MICRO + SECOND = 1000 * MILLI + MINUTE = 60 * SECOND + HOUR = 60 * MINUTE + DAY = 24 * HOUR + + nanosecond_time = 0 + + def __init__(self, value): + """ + Initializer value can be: + + - integer_type: absolute nanoseconds in the day + - datetime.time: built-in time + - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]" + """ + if isinstance(value, int): + self._from_timestamp(value) + elif isinstance(value, datetime.time): + self._from_time(value) + elif isinstance(value, str): + self._from_timestring(value) + else: + raise TypeError('Time arguments must be a whole number, datetime.time, or string') + + @property + def hour(self): + """ + The hour component of this time (0-23) + """ + return self.nanosecond_time // Time.HOUR + + @property + def minute(self): + """ + The minute component of this time (0-59) + """ + minutes = self.nanosecond_time // Time.MINUTE + return minutes % 60 + + @property + def second(self): + """ + The second component of this time (0-59) + """ + seconds = self.nanosecond_time // Time.SECOND + return seconds % 60 + + @property + def nanosecond(self): + """ + The fractional seconds component of the time, in nanoseconds + """ + return self.nanosecond_time % Time.SECOND + + def time(self): + """ + Return a built-in datetime.time (nanosecond precision truncated to micros). + """ + return datetime.time(hour=self.hour, minute=self.minute, second=self.second, + microsecond=self.nanosecond // Time.MICRO) + + def _from_timestamp(self, t): + if t >= Time.DAY: + raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) + self.nanosecond_time = t + + def _from_timestring(self, s): + try: + parts = s.split('.') + base_time = time.strptime(parts[0], "%H:%M:%S") + self.nanosecond_time = (base_time.tm_hour * Time.HOUR + + base_time.tm_min * Time.MINUTE + + base_time.tm_sec * Time.SECOND) + + if len(parts) > 1: + # right pad to 9 digits + nano_time_str = parts[1] + "0" * (9 - len(parts[1])) + self.nanosecond_time += int(nano_time_str) + + except ValueError: + raise ValueError("can't interpret %r as a time" % (s,)) + + def _from_time(self, t): + self.nanosecond_time = (t.hour * Time.HOUR + + t.minute * Time.MINUTE + + t.second * Time.SECOND + + t.microsecond * Time.MICRO) + + def __hash__(self): + return self.nanosecond_time + + def __eq__(self, other): + if isinstance(other, Time): + return self.nanosecond_time == other.nanosecond_time + + if isinstance(other, int): + return self.nanosecond_time == other + + return self.nanosecond_time % Time.MICRO == 0 and \ + datetime.time(hour=self.hour, minute=self.minute, second=self.second, + microsecond=self.nanosecond // Time.MICRO) == other + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + if not isinstance(other, Time): + return NotImplemented + return self.nanosecond_time < other.nanosecond_time + + def __repr__(self): + return "Time(%s)" % self.nanosecond_time + + def __str__(self): + return "%02d:%02d:%02d.%09d" % (self.hour, self.minute, + self.second, self.nanosecond) + + +@total_ordering +class Date(object): + ''' + Idealized date: year, month, day + + Offers wider year range than datetime.date. For Dates that cannot be represented + as a datetime.date (because datetime.MINYEAR, datetime.MAXYEAR), this type falls back + to printing days_from_epoch offset. + ''' + + MINUTE = 60 + HOUR = 60 * MINUTE + DAY = 24 * HOUR + + date_format = "%Y-%m-%d" + + days_from_epoch = 0 + + def __init__(self, value): + """ + Initializer value can be: + + - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. + - datetime.date: built-in date + - string_type: a string time of the form "yyyy-mm-dd" + """ + if isinstance(value, int): + self.days_from_epoch = value + elif isinstance(value, (datetime.date, datetime.datetime)): + self._from_timetuple(value.timetuple()) + elif isinstance(value, str): + self._from_datestring(value) + else: + raise TypeError('Date arguments must be a whole number, datetime.date, or string') + + @property + def seconds(self): + """ + Absolute seconds from epoch (can be negative) + """ + return self.days_from_epoch * Date.DAY + + def date(self): + """ + Return a built-in datetime.date for Dates falling in the years [datetime.MINYEAR, datetime.MAXYEAR] + + ValueError is raised for Dates outside this range. + """ + try: + dt = datetime_from_timestamp(self.seconds) + return datetime.date(dt.year, dt.month, dt.day) + except Exception: + raise ValueError("%r exceeds ranges for built-in datetime.date" % self) + + def _from_timetuple(self, t): + self.days_from_epoch = calendar.timegm(t) // Date.DAY + + def _from_datestring(self, s): + if s[0] == '+': + s = s[1:] + dt = datetime.datetime.strptime(s, self.date_format) + self._from_timetuple(dt.timetuple()) + + def __hash__(self): + return self.days_from_epoch + + def __eq__(self, other): + if isinstance(other, Date): + return self.days_from_epoch == other.days_from_epoch + + if isinstance(other, int): + return self.days_from_epoch == other + + try: + return self.date() == other + except Exception: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + if not isinstance(other, Date): + return NotImplemented + return self.days_from_epoch < other.days_from_epoch + + def __repr__(self): + return "Date(%s)" % self.days_from_epoch + + def __str__(self): + try: + dt = datetime_from_timestamp(self.seconds) + return "%04d-%02d-%02d" % (dt.year, dt.month, dt.day) + except: + # If we overflow datetime.[MIN|MAX] + return str(self.days_from_epoch) + + +inet_pton = socket.inet_pton +inet_ntop = socket.inet_ntop + + +# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic +def _positional_rename_invalid_identifiers(field_names): + names_out = list(field_names) + for index, name in enumerate(field_names): + if (not all(c.isalnum() or c == '_' for c in name) + or keyword.iskeyword(name) + or not name + or name[0].isdigit() + or name.startswith('_')): + names_out[index] = 'field_%d_' % index + return names_out + + +def _sanitize_identifiers(field_names): + names_out = _positional_rename_invalid_identifiers(field_names) + if len(names_out) != len(set(names_out)): + observed_names = set() + for index, name in enumerate(names_out): + while names_out[index] in observed_names: + names_out[index] = "%s_" % (names_out[index],) + observed_names.add(names_out[index]) + return names_out + + +def list_contents_to_tuple(to_convert): + if isinstance(to_convert, list): + for n, i in enumerate(to_convert): + if isinstance(to_convert[n], list): + to_convert[n] = tuple(to_convert[n]) + return tuple(to_convert) + else: + return to_convert + + +class Point(object): + """ + Represents a point geometry for DSE + """ + + x = None + """ + x coordinate of the point + """ + + y = None + """ + y coordinate of the point + """ + + def __init__(self, x=_nan, y=_nan): + self.x = x + self.y = y + + def __eq__(self, other): + return isinstance(other, Point) and self.x == other.x and self.y == other.y + + def __hash__(self): + return hash((self.x, self.y)) + + def __str__(self): + """ + Well-known text representation of the point + """ + return "POINT (%r %r)" % (self.x, self.y) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.x, self.y) + + @staticmethod + def from_wkt(s): + """ + Parse a Point geometry from a wkt string and return a new Point object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'Point': + raise ValueError("Invalid WKT geometry type. Expected 'Point', got '{0}': '{1}'".format(geom['type'], s)) + + coords = geom['coordinates'] + if len(coords) < 2: + x = y = _nan + else: + x = coords[0] + y = coords[1] + + return Point(x=x, y=y) + + +class LineString(object): + """ + Represents a linestring geometry for DSE + """ + + coords = None + """ + Tuple of (x, y) coordinates in the linestring + """ + def __init__(self, coords=tuple()): + """ + 'coords`: a sequence of (x, y) coordinates of points in the linestring + """ + self.coords = tuple(coords) + + def __eq__(self, other): + return isinstance(other, LineString) and self.coords == other.coords + + def __hash__(self): + return hash(self.coords) + + def __str__(self): + """ + Well-known text representation of the LineString + """ + if not self.coords: + return "LINESTRING EMPTY" + return "LINESTRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.coords) + + @staticmethod + def from_wkt(s): + """ + Parse a LineString geometry from a wkt string and return a new LineString object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'LineString': + raise ValueError("Invalid WKT geometry type. Expected 'LineString', got '{0}': '{1}'".format(geom['type'], s)) + + geom['coordinates'] = list_contents_to_tuple(geom['coordinates']) + + return LineString(coords=geom['coordinates']) + + +class _LinearRing(object): + # no validation, no implicit closing; just used for poly composition, to + # mimic that of shapely.geometry.Polygon + def __init__(self, coords=tuple()): + self.coords = list_contents_to_tuple(coords) + + def __eq__(self, other): + return isinstance(other, _LinearRing) and self.coords == other.coords + + def __hash__(self): + return hash(self.coords) + + def __str__(self): + if not self.coords: + return "LINEARRING EMPTY" + return "LINEARRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.coords) + + +class Polygon(object): + """ + Represents a polygon geometry for DSE + """ + + exterior = None + """ + _LinearRing representing the exterior of the polygon + """ + + interiors = None + """ + Tuple of _LinearRings representing interior holes in the polygon + """ + + def __init__(self, exterior=tuple(), interiors=None): + """ + 'exterior`: a sequence of (x, y) coordinates of points in the linestring + `interiors`: None, or a sequence of sequences or (x, y) coordinates of points describing interior linear rings + """ + self.exterior = _LinearRing(exterior) + self.interiors = tuple(_LinearRing(e) for e in interiors) if interiors else tuple() + + def __eq__(self, other): + return isinstance(other, Polygon) and self.exterior == other.exterior and self.interiors == other.interiors + + def __hash__(self): + return hash((self.exterior, self.interiors)) + + def __str__(self): + """ + Well-known text representation of the polygon + """ + if not self.exterior.coords: + return "POLYGON EMPTY" + rings = [ring.coords for ring in chain((self.exterior,), self.interiors)] + rings = ["(%s)" % ', '.join("%r %r" % (x, y) for x, y in ring) for ring in rings] + return "POLYGON (%s)" % ', '.join(rings) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.exterior.coords, [ring.coords for ring in self.interiors]) + + @staticmethod + def from_wkt(s): + """ + Parse a Polygon geometry from a wkt string and return a new Polygon object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'Polygon': + raise ValueError("Invalid WKT geometry type. Expected 'Polygon', got '{0}': '{1}'".format(geom['type'], s)) + + coords = geom['coordinates'] + exterior = coords[0] if len(coords) > 0 else tuple() + interiors = coords[1:] if len(coords) > 1 else None + + return Polygon(exterior=exterior, interiors=interiors) + + +_distance_wkt_pattern = re.compile("distance *\\( *\\( *([\\d\\.-]+) *([\\d+\\.-]+) *\\) *([\\d+\\.-]+) *\\) *$", re.IGNORECASE) + + +class Distance(object): + """ + Represents a Distance geometry for DSE + """ + + x = None + """ + x coordinate of the center point + """ + + y = None + """ + y coordinate of the center point + """ + + radius = None + """ + radius to represent the distance from the center point + """ + + def __init__(self, x=_nan, y=_nan, radius=_nan): + self.x = x + self.y = y + self.radius = radius + + def __eq__(self, other): + return isinstance(other, Distance) and self.x == other.x and self.y == other.y and self.radius == other.radius + + def __hash__(self): + return hash((self.x, self.y, self.radius)) + + def __str__(self): + """ + Well-known text representation of the point + """ + return "DISTANCE ((%r %r) %r)" % (self.x, self.y, self.radius) + + def __repr__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.x, self.y, self.radius) + + @staticmethod + def from_wkt(s): + """ + Parse a Distance geometry from a wkt string and return a new Distance object. + """ + + distance_match = _distance_wkt_pattern.match(s) + + if distance_match is None: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + x, y, radius = distance_match.groups() + return Distance(x, y, radius) + + +class Duration(object): + """ + Cassandra Duration Type + """ + + months = 0 + "" + days = 0 + "" + nanoseconds = 0 + "" + + def __init__(self, months=0, days=0, nanoseconds=0): + self.months = months + self.days = days + self.nanoseconds = nanoseconds + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds + + def __repr__(self): + return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds) + + def __str__(self): + has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0 + return '%s%dmo%dd%dns' % ( + '-' if has_negative_values else '', + abs(self.months), + abs(self.days), + abs(self.nanoseconds) + ) + + +class DateRangePrecision(object): + """ + An "enum" representing the valid values for :attr:`DateRange.precision`. + """ + YEAR = 'YEAR' + """ + """ + + MONTH = 'MONTH' + """ + """ + + DAY = 'DAY' + """ + """ + + HOUR = 'HOUR' + """ + """ + + MINUTE = 'MINUTE' + """ + """ + + SECOND = 'SECOND' + """ + """ + + MILLISECOND = 'MILLISECOND' + """ + """ + + PRECISIONS = (YEAR, MONTH, DAY, HOUR, + MINUTE, SECOND, MILLISECOND) + """ + """ + + @classmethod + def _to_int(cls, precision): + return cls.PRECISIONS.index(precision.upper()) + + @classmethod + def _round_to_precision(cls, ms, precision, default_dt): + try: + dt = utc_datetime_from_ms_timestamp(ms) + except OverflowError: + return ms + precision_idx = cls._to_int(precision) + replace_kwargs = {} + if precision_idx <= cls._to_int(DateRangePrecision.YEAR): + replace_kwargs['month'] = default_dt.month + if precision_idx <= cls._to_int(DateRangePrecision.MONTH): + replace_kwargs['day'] = default_dt.day + if precision_idx <= cls._to_int(DateRangePrecision.DAY): + replace_kwargs['hour'] = default_dt.hour + if precision_idx <= cls._to_int(DateRangePrecision.HOUR): + replace_kwargs['minute'] = default_dt.minute + if precision_idx <= cls._to_int(DateRangePrecision.MINUTE): + replace_kwargs['second'] = default_dt.second + if precision_idx <= cls._to_int(DateRangePrecision.SECOND): + # truncate to nearest 1000, so we deal in ms, not us + replace_kwargs['microsecond'] = (default_dt.microsecond // 1000) * 1000 + if precision_idx == cls._to_int(DateRangePrecision.MILLISECOND): + replace_kwargs['microsecond'] = int(round(dt.microsecond, -3)) + return ms_timestamp_from_datetime(dt.replace(**replace_kwargs)) + + @classmethod + def round_up_to_precision(cls, ms, precision): + # PYTHON-912: this is the only case in which we can't take as upper bound + # datetime.datetime.max because the month from ms may be February, and we'd + # be setting 31 as the month day + if precision == cls.MONTH: + date_ms = utc_datetime_from_ms_timestamp(ms) + upper_date = datetime.datetime.max.replace(year=date_ms.year, month=date_ms.month, + day=calendar.monthrange(date_ms.year, date_ms.month)[1]) + else: + upper_date = datetime.datetime.max + return cls._round_to_precision(ms, precision, upper_date) + + @classmethod + def round_down_to_precision(cls, ms, precision): + return cls._round_to_precision(ms, precision, datetime.datetime.min) + + +@total_ordering +class DateRangeBound(object): + """DateRangeBound(value, precision) + Represents a single date value and its precision for :class:`DateRange`. + + .. attribute:: milliseconds + + Integer representing milliseconds since the UNIX epoch. May be negative. + + .. attribute:: precision + + String representing the precision of a bound. Must be a valid + :class:`DateRangePrecision` member. + + :class:`DateRangeBound` uses a millisecond offset from the UNIX epoch to + allow :class:`DateRange` to represent values `datetime.datetime` cannot. + For such values, string representions will show this offset rather than the + CQL representation. + """ + milliseconds = None + precision = None + + def __init__(self, value, precision): + """ + :param value: a value representing ms since the epoch. Accepts an + integer or a datetime. + :param precision: a string representing precision + """ + if precision is not None: + try: + self.precision = precision.upper() + except AttributeError: + raise TypeError('precision must be a string; got %r' % precision) + + if value is None: + milliseconds = None + elif isinstance(value, int): + milliseconds = value + elif isinstance(value, datetime.datetime): + value = value.replace( + microsecond=int(round(value.microsecond, -3)) + ) + milliseconds = ms_timestamp_from_datetime(value) + else: + raise ValueError('%r is not a valid value for DateRangeBound' % value) + + self.milliseconds = milliseconds + self.validate() + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.milliseconds == other.milliseconds and + self.precision == other.precision) + + def __lt__(self, other): + return ((str(self.milliseconds), str(self.precision)) < + (str(other.milliseconds), str(other.precision))) + + def datetime(self): + """ + Return :attr:`milliseconds` as a :class:`datetime.datetime` if possible. + Raises an `OverflowError` if the value is out of range. + """ + return utc_datetime_from_ms_timestamp(self.milliseconds) + + def validate(self): + attrs = self.milliseconds, self.precision + if attrs == (None, None): + return + if None in attrs: + raise TypeError( + ("%s.datetime and %s.precision must not be None unless both " + "are None; Got: %r") % (self.__class__.__name__, + self.__class__.__name__, + self) + ) + if self.precision not in DateRangePrecision.PRECISIONS: + raise ValueError( + "%s.precision: expected value in %r; got %r" % ( + self.__class__.__name__, + DateRangePrecision.PRECISIONS, + self.precision + ) + ) + + @classmethod + def from_value(cls, value): + """ + Construct a new :class:`DateRangeBound` from a given value. If + possible, use the `value['milliseconds']` and `value['precision']` keys + of the argument. Otherwise, use the argument as a `(milliseconds, + precision)` iterable. + + :param value: a dictlike or iterable object + """ + if isinstance(value, cls): + return value + + # if possible, use as a mapping + try: + milliseconds, precision = value.get('milliseconds'), value.get('precision') + except AttributeError: + milliseconds = precision = None + if milliseconds is not None and precision is not None: + return DateRangeBound(value=milliseconds, precision=precision) + + # otherwise, use as an iterable + return DateRangeBound(*value) + + def round_up(self): + if self.milliseconds is None or self.precision is None: + return self + self.milliseconds = DateRangePrecision.round_up_to_precision( + self.milliseconds, self.precision + ) + return self + + def round_down(self): + if self.milliseconds is None or self.precision is None: + return self + self.milliseconds = DateRangePrecision.round_down_to_precision( + self.milliseconds, self.precision + ) + return self + + _formatter_map = { + DateRangePrecision.YEAR: '%Y', + DateRangePrecision.MONTH: '%Y-%m', + DateRangePrecision.DAY: '%Y-%m-%d', + DateRangePrecision.HOUR: '%Y-%m-%dT%HZ', + DateRangePrecision.MINUTE: '%Y-%m-%dT%H:%MZ', + DateRangePrecision.SECOND: '%Y-%m-%dT%H:%M:%SZ', + DateRangePrecision.MILLISECOND: '%Y-%m-%dT%H:%M:%S', + } + + def __str__(self): + if self == OPEN_BOUND: + return '*' + + try: + dt = self.datetime() + except OverflowError: + return '%sms' % (self.milliseconds,) + + formatted = dt.strftime(self._formatter_map[self.precision]) + + if self.precision == DateRangePrecision.MILLISECOND: + # we'd like to just format with '%Y-%m-%dT%H:%M:%S.%fZ', but %f + # gives us more precision than we want, so we strftime up to %S and + # do the rest ourselves + return '%s.%03dZ' % (formatted, dt.microsecond / 1000) + + return formatted + + def __repr__(self): + return '%s(milliseconds=%r, precision=%r)' % ( + self.__class__.__name__, self.milliseconds, self.precision + ) + + +OPEN_BOUND = DateRangeBound(value=None, precision=None) +""" +Represents `*`, an open value or bound for :class:`DateRange`. +""" + + +@total_ordering +class DateRange(object): + """DateRange(lower_bound=None, upper_bound=None, value=None) + DSE DateRange Type + + .. attribute:: lower_bound + + :class:`~DateRangeBound` representing the lower bound of a bounded range. + + .. attribute:: upper_bound + + :class:`~DateRangeBound` representing the upper bound of a bounded range. + + .. attribute:: value + + :class:`~DateRangeBound` representing the value of a single-value range. + + As noted in its documentation, :class:`DateRangeBound` uses a millisecond + offset from the UNIX epoch to allow :class:`DateRange` to represent values + `datetime.datetime` cannot. For such values, string representions will show + this offset rather than the CQL representation. + """ + lower_bound = None + upper_bound = None + value = None + + def __init__(self, lower_bound=None, upper_bound=None, value=None): + """ + :param lower_bound: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as a + :attr:`lower_bound`. Mutually exclusive with `value`. If + `upper_bound` is specified and this is not, the :attr:`lower_bound` + will be open. + :param upper_bound: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as a + :attr:`upper_bound`. Mutually exclusive with `value`. If + `lower_bound` is specified and this is not, the :attr:`upper_bound` + will be open. + :param value: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as :attr:`value`. Mutually + exclusive with `lower_bound` and `lower_bound`. + """ + + # if necessary, transform non-None args to DateRangeBounds + lower_bound = (DateRangeBound.from_value(lower_bound).round_down() + if lower_bound else lower_bound) + upper_bound = (DateRangeBound.from_value(upper_bound).round_up() + if upper_bound else upper_bound) + value = (DateRangeBound.from_value(value).round_down() + if value else value) + + # if we're using a 2-ended range but one bound isn't specified, specify + # an open bound + if lower_bound is None and upper_bound is not None: + lower_bound = OPEN_BOUND + if upper_bound is None and lower_bound is not None: + upper_bound = OPEN_BOUND + + self.lower_bound, self.upper_bound, self.value = ( + lower_bound, upper_bound, value + ) + self.validate() + + def validate(self): + if self.value is None: + if self.lower_bound is None or self.upper_bound is None: + raise ValueError( + '%s instances where value attribute is None must set ' + 'lower_bound or upper_bound; got %r' % ( + self.__class__.__name__, + self + ) + ) + else: # self.value is not None + if self.lower_bound is not None or self.upper_bound is not None: + raise ValueError( + '%s instances where value attribute is not None must not ' + 'set lower_bound or upper_bound; got %r' % ( + self.__class__.__name__, + self + ) + ) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.lower_bound == other.lower_bound and + self.upper_bound == other.upper_bound and + self.value == other.value) + + def __lt__(self, other): + return ((str(self.lower_bound), str(self.upper_bound), str(self.value)) < + (str(other.lower_bound), str(other.upper_bound), str(other.value))) + + def __str__(self): + if self.value: + return str(self.value) + else: + return '[%s TO %s]' % (self.lower_bound, self.upper_bound) + + def __repr__(self): + return '%s(lower_bound=%r, upper_bound=%r, value=%r)' % ( + self.__class__.__name__, + self.lower_bound, self.upper_bound, self.value + ) + +VERSION_REGEX = re.compile("^(\\d+)\\.(\\d+)(\\.\\d+)?(\\.\\d+)?([~\\-]\\w[.\\w]*(?:-\\w[.\\w]*)*)?(\\+[.\\w]+)?$") + +@total_ordering +class Version(object): + """ + Representation of a Cassandra version. Mostly follows the implementation of the same logic in the Java driver; + see https://github.com/apache/cassandra-java-driver/blob/4.19.2/core/src/main/java/com/datastax/oss/driver/api/core/Version.java. + + Cassandra versions are assumed to correspond to major.minor.patch with an optional additional numeric build field as well as a + string prerelease field. + """ + + def __init__(self, version): + self._version = version + + match = VERSION_REGEX.match(version) + if not match: + raise ValueError("Version string {0} did not match expected format".format(version)) + + self.major = int(match[1]) + self.minor = int(match[2]) + + try: + self.patch = self._cleanup_int(match[3]) + except: + self.patch = 0 + + try: + self.build = self._cleanup_int(match[4]) + except: + self.build = 0 + + try: + self.prerelease = self._cleanup_str(match[5]) + except: + self.prerelease = "" + + # This is used in a few places below so let's just build it now + self._tuple = (self.major, self.minor, self.patch, self.build, self.prerelease) + + # Trim off the leading '.' characters and convert the discovered value to an integer + def _cleanup_int(self, instr): + return int(instr[1:]) if instr else 0 + + # Trim off the leading '.' or '~' characters and just return the string directly + def _cleanup_str(self, instr): + return instr[1:] if instr else "" + + def __hash__(self): + return hash(self._tuple) + + def __repr__(self): + version_string = "Version({0}, {1}, {2}".format(self.major, self.minor, self.patch) + if self.build: + version_string += ", {}".format(self.build) + if self.prerelease: + version_string += ", {}".format(self.prerelease) + version_string += ")" + + return version_string + + def __str__(self): + return self._version + + # Methods below leverage left-to-right positional comparison of tuples + def __eq__(self, other): + if not isinstance(other, Version): + return NotImplemented + + return self._tuple == other._tuple + + def __gt__(self, other): + if not isinstance(other, Version): + return NotImplemented + + # We start by comparing the first four fields directly + self_tuple = self._tuple[:4] + other_tuple = (other.major, other.minor, other.patch, other.build) + if self_tuple != other_tuple: + return self_tuple > other_tuple + # If we're still around we have to check prereleases... prereleases always come before + # the corresponding version + elif self.prerelease and not other.prerelease: + return False + elif other.prerelease and not self.prerelease: + return True + else: + return self.prerelease > other.prerelease diff --git a/distribute_setup.py b/distribute_setup.py deleted file mode 100644 index 3553b2135e..0000000000 --- a/distribute_setup.py +++ /dev/null @@ -1,556 +0,0 @@ -#!python -"""Bootstrap distribute installation - -If you want to use setuptools in your package's setup.py, just include this -file in the same directory with it, and add this to the top of your setup.py:: - - from distribute_setup import use_setuptools - use_setuptools() - -If you want to require a specific version of setuptools, set a download -mirror, or use an alternate download directory, you can do so by supplying -the appropriate options to ``use_setuptools()``. - -This file can also be run as a script to install or upgrade setuptools. -""" -import os -import shutil -import sys -import time -import fnmatch -import tempfile -import tarfile -import optparse - -from distutils import log - -try: - from site import USER_SITE -except ImportError: - USER_SITE = None - -try: - import subprocess - - def _python_cmd(*args): - args = (sys.executable,) + args - return subprocess.call(args) == 0 - -except ImportError: - # will be used for python 2.3 - def _python_cmd(*args): - args = (sys.executable,) + args - # quoting arguments if windows - if sys.platform == 'win32': - def quote(arg): - if ' ' in arg: - return '"%s"' % arg - return arg - args = [quote(arg) for arg in args] - return os.spawnl(os.P_WAIT, sys.executable, *args) == 0 - -DEFAULT_VERSION = "0.6.49" -DEFAULT_URL = "http://pypi.python.org/packages/source/d/distribute/" -SETUPTOOLS_FAKED_VERSION = "0.6c11" - -SETUPTOOLS_PKG_INFO = """\ -Metadata-Version: 1.0 -Name: setuptools -Version: %s -Summary: xxxx -Home-page: xxx -Author: xxx -Author-email: xxx -License: xxx -Description: xxx -""" % SETUPTOOLS_FAKED_VERSION - - -def _install(tarball, install_args=()): - # extracting the tarball - tmpdir = tempfile.mkdtemp() - log.warn('Extracting in %s', tmpdir) - old_wd = os.getcwd() - try: - os.chdir(tmpdir) - tar = tarfile.open(tarball) - _extractall(tar) - tar.close() - - # going in the directory - subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) - os.chdir(subdir) - log.warn('Now working in %s', subdir) - - # installing - log.warn('Installing Distribute') - if not _python_cmd('setup.py', 'install', *install_args): - log.warn('Something went wrong during the installation.') - log.warn('See the error message above.') - # exitcode will be 2 - return 2 - finally: - os.chdir(old_wd) - shutil.rmtree(tmpdir) - - -def _build_egg(egg, tarball, to_dir): - # extracting the tarball - tmpdir = tempfile.mkdtemp() - log.warn('Extracting in %s', tmpdir) - old_wd = os.getcwd() - try: - os.chdir(tmpdir) - tar = tarfile.open(tarball) - _extractall(tar) - tar.close() - - # going in the directory - subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) - os.chdir(subdir) - log.warn('Now working in %s', subdir) - - # building an egg - log.warn('Building a Distribute egg in %s', to_dir) - _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir) - - finally: - os.chdir(old_wd) - shutil.rmtree(tmpdir) - # returning the result - log.warn(egg) - if not os.path.exists(egg): - raise IOError('Could not build the egg.') - - -def _do_download(version, download_base, to_dir, download_delay): - egg = os.path.join(to_dir, 'distribute-%s-py%d.%d.egg' - % (version, sys.version_info[0], sys.version_info[1])) - if not os.path.exists(egg): - tarball = download_setuptools(version, download_base, - to_dir, download_delay) - _build_egg(egg, tarball, to_dir) - sys.path.insert(0, egg) - import setuptools - setuptools.bootstrap_install_from = egg - - -def use_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, - to_dir=os.curdir, download_delay=15, no_fake=True): - # making sure we use the absolute path - to_dir = os.path.abspath(to_dir) - was_imported = 'pkg_resources' in sys.modules or \ - 'setuptools' in sys.modules - try: - try: - import pkg_resources - - # Setuptools 0.7b and later is a suitable (and preferable) - # substitute for any Distribute version. - try: - pkg_resources.require("setuptools>=0.7b") - return - except (pkg_resources.DistributionNotFound, - pkg_resources.VersionConflict): - pass - - if not hasattr(pkg_resources, '_distribute'): - if not no_fake: - _fake_setuptools() - raise ImportError - except ImportError: - return _do_download(version, download_base, to_dir, download_delay) - try: - pkg_resources.require("distribute>=" + version) - return - except pkg_resources.VersionConflict: - e = sys.exc_info()[1] - if was_imported: - sys.stderr.write( - "The required version of distribute (>=%s) is not available,\n" - "and can't be installed while this script is running. Please\n" - "install a more recent version first, using\n" - "'easy_install -U distribute'." - "\n\n(Currently using %r)\n" % (version, e.args[0])) - sys.exit(2) - else: - del pkg_resources, sys.modules['pkg_resources'] # reload ok - return _do_download(version, download_base, to_dir, - download_delay) - except pkg_resources.DistributionNotFound: - return _do_download(version, download_base, to_dir, - download_delay) - finally: - if not no_fake: - _create_fake_setuptools_pkg_info(to_dir) - - -def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, - to_dir=os.curdir, delay=15): - """Download distribute from a specified location and return its filename - - `version` should be a valid distribute version number that is available - as an egg for download under the `download_base` URL (which should end - with a '/'). `to_dir` is the directory where the egg will be downloaded. - `delay` is the number of seconds to pause before an actual download - attempt. - """ - # making sure we use the absolute path - to_dir = os.path.abspath(to_dir) - try: - from urllib.request import urlopen - except ImportError: - from urllib2 import urlopen - tgz_name = "distribute-%s.tar.gz" % version - url = download_base + tgz_name - saveto = os.path.join(to_dir, tgz_name) - src = dst = None - if not os.path.exists(saveto): # Avoid repeated downloads - try: - log.warn("Downloading %s", url) - src = urlopen(url) - # Read/write all in one block, so we don't create a corrupt file - # if the download is interrupted. - data = src.read() - dst = open(saveto, "wb") - dst.write(data) - finally: - if src: - src.close() - if dst: - dst.close() - return os.path.realpath(saveto) - - -def _no_sandbox(function): - def __no_sandbox(*args, **kw): - try: - from setuptools.sandbox import DirectorySandbox - if not hasattr(DirectorySandbox, '_old'): - def violation(*args): - pass - DirectorySandbox._old = DirectorySandbox._violation - DirectorySandbox._violation = violation - patched = True - else: - patched = False - except ImportError: - patched = False - - try: - return function(*args, **kw) - finally: - if patched: - DirectorySandbox._violation = DirectorySandbox._old - del DirectorySandbox._old - - return __no_sandbox - - -def _patch_file(path, content): - """Will backup the file then patch it""" - f = open(path) - existing_content = f.read() - f.close() - if existing_content == content: - # already patched - log.warn('Already patched.') - return False - log.warn('Patching...') - _rename_path(path) - f = open(path, 'w') - try: - f.write(content) - finally: - f.close() - return True - -_patch_file = _no_sandbox(_patch_file) - - -def _same_content(path, content): - f = open(path) - existing_content = f.read() - f.close() - return existing_content == content - - -def _rename_path(path): - new_name = path + '.OLD.%s' % time.time() - log.warn('Renaming %s to %s', path, new_name) - os.rename(path, new_name) - return new_name - - -def _remove_flat_installation(placeholder): - if not os.path.isdir(placeholder): - log.warn('Unkown installation at %s', placeholder) - return False - found = False - for file in os.listdir(placeholder): - if fnmatch.fnmatch(file, 'setuptools*.egg-info'): - found = True - break - if not found: - log.warn('Could not locate setuptools*.egg-info') - return - - log.warn('Moving elements out of the way...') - pkg_info = os.path.join(placeholder, file) - if os.path.isdir(pkg_info): - patched = _patch_egg_dir(pkg_info) - else: - patched = _patch_file(pkg_info, SETUPTOOLS_PKG_INFO) - - if not patched: - log.warn('%s already patched.', pkg_info) - return False - # now let's move the files out of the way - for element in ('setuptools', 'pkg_resources.py', 'site.py'): - element = os.path.join(placeholder, element) - if os.path.exists(element): - _rename_path(element) - else: - log.warn('Could not find the %s element of the ' - 'Setuptools distribution', element) - return True - -_remove_flat_installation = _no_sandbox(_remove_flat_installation) - - -def _after_install(dist): - log.warn('After install bootstrap.') - placeholder = dist.get_command_obj('install').install_purelib - _create_fake_setuptools_pkg_info(placeholder) - - -def _create_fake_setuptools_pkg_info(placeholder): - if not placeholder or not os.path.exists(placeholder): - log.warn('Could not find the install location') - return - pyver = '%s.%s' % (sys.version_info[0], sys.version_info[1]) - setuptools_file = 'setuptools-%s-py%s.egg-info' % \ - (SETUPTOOLS_FAKED_VERSION, pyver) - pkg_info = os.path.join(placeholder, setuptools_file) - if os.path.exists(pkg_info): - log.warn('%s already exists', pkg_info) - return - - log.warn('Creating %s', pkg_info) - try: - f = open(pkg_info, 'w') - except EnvironmentError: - log.warn("Don't have permissions to write %s, skipping", pkg_info) - return - try: - f.write(SETUPTOOLS_PKG_INFO) - finally: - f.close() - - pth_file = os.path.join(placeholder, 'setuptools.pth') - log.warn('Creating %s', pth_file) - f = open(pth_file, 'w') - try: - f.write(os.path.join(os.curdir, setuptools_file)) - finally: - f.close() - -_create_fake_setuptools_pkg_info = _no_sandbox( - _create_fake_setuptools_pkg_info -) - - -def _patch_egg_dir(path): - # let's check if it's already patched - pkg_info = os.path.join(path, 'EGG-INFO', 'PKG-INFO') - if os.path.exists(pkg_info): - if _same_content(pkg_info, SETUPTOOLS_PKG_INFO): - log.warn('%s already patched.', pkg_info) - return False - _rename_path(path) - os.mkdir(path) - os.mkdir(os.path.join(path, 'EGG-INFO')) - pkg_info = os.path.join(path, 'EGG-INFO', 'PKG-INFO') - f = open(pkg_info, 'w') - try: - f.write(SETUPTOOLS_PKG_INFO) - finally: - f.close() - return True - -_patch_egg_dir = _no_sandbox(_patch_egg_dir) - - -def _before_install(): - log.warn('Before install bootstrap.') - _fake_setuptools() - - -def _under_prefix(location): - if 'install' not in sys.argv: - return True - args = sys.argv[sys.argv.index('install') + 1:] - for index, arg in enumerate(args): - for option in ('--root', '--prefix'): - if arg.startswith('%s=' % option): - top_dir = arg.split('root=')[-1] - return location.startswith(top_dir) - elif arg == option: - if len(args) > index: - top_dir = args[index + 1] - return location.startswith(top_dir) - if arg == '--user' and USER_SITE is not None: - return location.startswith(USER_SITE) - return True - - -def _fake_setuptools(): - log.warn('Scanning installed packages') - try: - import pkg_resources - except ImportError: - # we're cool - log.warn('Setuptools or Distribute does not seem to be installed.') - return - ws = pkg_resources.working_set - try: - setuptools_dist = ws.find( - pkg_resources.Requirement.parse('setuptools', replacement=False) - ) - except TypeError: - # old distribute API - setuptools_dist = ws.find( - pkg_resources.Requirement.parse('setuptools') - ) - - if setuptools_dist is None: - log.warn('No setuptools distribution found') - return - # detecting if it was already faked - setuptools_location = setuptools_dist.location - log.warn('Setuptools installation detected at %s', setuptools_location) - - # if --root or --preix was provided, and if - # setuptools is not located in them, we don't patch it - if not _under_prefix(setuptools_location): - log.warn('Not patching, --root or --prefix is installing Distribute' - ' in another location') - return - - # let's see if its an egg - if not setuptools_location.endswith('.egg'): - log.warn('Non-egg installation') - res = _remove_flat_installation(setuptools_location) - if not res: - return - else: - log.warn('Egg installation') - pkg_info = os.path.join(setuptools_location, 'EGG-INFO', 'PKG-INFO') - if (os.path.exists(pkg_info) and - _same_content(pkg_info, SETUPTOOLS_PKG_INFO)): - log.warn('Already patched.') - return - log.warn('Patching...') - # let's create a fake egg replacing setuptools one - res = _patch_egg_dir(setuptools_location) - if not res: - return - log.warn('Patching complete.') - _relaunch() - - -def _relaunch(): - log.warn('Relaunching...') - # we have to relaunch the process - # pip marker to avoid a relaunch bug - _cmd1 = ['-c', 'install', '--single-version-externally-managed'] - _cmd2 = ['-c', 'install', '--record'] - if sys.argv[:3] == _cmd1 or sys.argv[:3] == _cmd2: - sys.argv[0] = 'setup.py' - args = [sys.executable] + sys.argv - sys.exit(subprocess.call(args)) - - -def _extractall(self, path=".", members=None): - """Extract all members from the archive to the current working - directory and set owner, modification time and permissions on - directories afterwards. `path' specifies a different directory - to extract to. `members' is optional and must be a subset of the - list returned by getmembers(). - """ - import copy - import operator - from tarfile import ExtractError - directories = [] - - if members is None: - members = self - - for tarinfo in members: - if tarinfo.isdir(): - # Extract directories with a safe mode. - directories.append(tarinfo) - tarinfo = copy.copy(tarinfo) - tarinfo.mode = 448 # decimal for oct 0700 - self.extract(tarinfo, path) - - # Reverse sort directories. - if sys.version_info < (2, 4): - def sorter(dir1, dir2): - return cmp(dir1.name, dir2.name) - directories.sort(sorter) - directories.reverse() - else: - directories.sort(key=operator.attrgetter('name'), reverse=True) - - # Set correct owner, mtime and filemode on directories. - for tarinfo in directories: - dirpath = os.path.join(path, tarinfo.name) - try: - self.chown(tarinfo, dirpath) - self.utime(tarinfo, dirpath) - self.chmod(tarinfo, dirpath) - except ExtractError: - e = sys.exc_info()[1] - if self.errorlevel > 1: - raise - else: - self._dbg(1, "tarfile: %s" % e) - - -def _build_install_args(options): - """ - Build the arguments to 'python setup.py install' on the distribute package - """ - install_args = [] - if options.user_install: - if sys.version_info < (2, 6): - log.warn("--user requires Python 2.6 or later") - raise SystemExit(1) - install_args.append('--user') - return install_args - -def _parse_args(): - """ - Parse the command line for options - """ - parser = optparse.OptionParser() - parser.add_option( - '--user', dest='user_install', action='store_true', default=False, - help='install in user site package (requires Python 2.6 or later)') - parser.add_option( - '--download-base', dest='download_base', metavar="URL", - default=DEFAULT_URL, - help='alternative URL from where to download the distribute package') - options, args = parser.parse_args() - # positional arguments are ignored - return options - -def main(version=DEFAULT_VERSION): - """Install or upgrade setuptools and EasyInstall""" - options = _parse_args() - tarball = download_setuptools(download_base=options.download_base) - return _install(tarball, _build_install_args(options)) - -if __name__ == '__main__': - sys.exit(main()) diff --git a/docs.yaml b/docs.yaml new file mode 100644 index 0000000000..63269a3001 --- /dev/null +++ b/docs.yaml @@ -0,0 +1,117 @@ +title: DataStax Python Driver +summary: DataStax Python Driver for Apache Cassandra® +output: docs/_build/ +swiftype_drivers: pythondrivers +sections: + - title: N/A + prefix: / + type: sphinx + directory: docs + virtualenv_init: | + set -x + CASS_DRIVER_NO_CYTHON=1 pip install -r test-datastax-requirements.txt + # for newer versions this is redundant, but in older versions we need to + # install, e.g., the cassandra driver, and those versions don't specify + # the cassandra driver version in requirements files + CASS_DRIVER_NO_CYTHON=1 python setup.py develop + pip install "jinja2==2.8.1;python_version<'3.6'" "sphinx>=1.3,<2" geomet + # build extensions like libev + CASS_DRIVER_NO_CYTHON=1 python setup.py build_ext --inplace --force +versions: + - name: '3.29' + ref: 434b1f52 + - name: '3.28' + ref: 4325afb6 + - name: '3.27' + ref: 910f0282 + - name: '3.26' + ref: f1e9126 + - name: '3.25' + ref: a83c36a5 + - name: '3.24' + ref: 21cac12b + - name: '3.23' + ref: a40a2af7 + - name: '3.22' + ref: 1ccd5b99 + - name: '3.21' + ref: 5589d96b + - name: '3.20' + ref: d30d166f + - name: '3.19' + ref: ac2471f9 + - name: '3.18' + ref: ec36b957 + - name: '3.17' + ref: 38e359e1 + - name: '3.16' + ref: '3.16.0' + - name: '3.15' + ref: '2ce0bd97' + - name: '3.14' + ref: '9af8bd19' + - name: '3.13' + ref: '3.13.0' + - name: '3.12' + ref: '43b9c995' + - name: '3.11' + ref: '3.11.0' + - name: '3.10' + ref: 64572368 + - name: 3.9 + ref: 3.9-doc + - name: 3.8 + ref: 3.8-doc + - name: 3.7 + ref: 3.7-doc + - name: 3.6 + ref: 3.6-doc + - name: 3.5 + ref: 3.5-doc +redirects: + - \A\/(.*)/\Z: /\1.html +rewrites: + - search: http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH + replace: https://docs.datastax.com/en/dse/6.7/cql/cql/cql_reference/cql_commands/cqlBatch.html + - search: http://www.datastax.com/documentation/cql/3.1/ + replace: https://docs.datastax.com/en/archived/cql/3.1/ + - search: 'https://community.datastax.com' + replace: 'https://www.datastax.com/dev/community' + - search: 'https://docs.datastax.com/en/astra/aws/doc/index.html' + replace: 'https://docs.datastax.com/en/astra-serverless/docs/connect/drivers/connect-python.html' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#collections' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#batchStmt' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#batchStmt' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement' +checks: + external_links: + exclude: + - 'https://twitter.com/dsJavaDriver' + - 'https://twitter.com/datastaxeng' + - 'https://twitter.com/datastax' + - 'https://projectreactor.io' + - 'https://docs.datastax.com/en/drivers/java/4.[0-9]+/com/datastax/oss/driver/internal/' + - 'http://www.planetcassandra.org/blog/user-defined-functions-in-cassandra-3-0/' + - 'http://www.planetcassandra.org/making-the-change-from-thrift-to-cql/' + - 'https://academy.datastax.com/slack' + - 'https://community.datastax.com/index.html' + - 'https://micrometer.io/docs' + - 'http://datastax.github.io/java-driver/features/shaded_jar/' + - 'http://aka.ms/vcpython27' + internal_links: + exclude: + - 'netty_pipeline/' + - '../core/' + - '%5Bguava%20eviction%5D' diff --git a/docs/.nav b/docs/.nav new file mode 100644 index 0000000000..79f3029073 --- /dev/null +++ b/docs/.nav @@ -0,0 +1,21 @@ +installation +getting_started +execution_profiles +lwt +object_mapper +performance +query_paging +security +upgrading +user_defined_types +dates_and_times +cloud +column_encryption +geo_types +graph +classic_graph +graph_fluent +CHANGELOG +faq +api + diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst new file mode 100644 index 0000000000..592a2c0efa --- /dev/null +++ b/docs/CHANGELOG.rst @@ -0,0 +1,5 @@ +********* +CHANGELOG +********* + +.. include:: ../CHANGELOG.rst diff --git a/docs/Makefile b/docs/Makefile index b076f89380..bf300ec71d 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -72,17 +72,17 @@ qthelp: @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/CassandraDriver.qhcp" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/cassandra-driver.qhcp" @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/CassandraDriver.qhc" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/cassandra-driver.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/CassandraDriver" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/CassandraDriver" + @echo "# mkdir -p $$HOME/.local/share/devhelp/cassandra-driver" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/cassandra-driver" @echo "# devhelp" epub: @@ -100,7 +100,7 @@ latex: latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." - make -C $(BUILDDIR)/latex all-pdf + $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: diff --git a/docs/api/cassandra.rst b/docs/api/cassandra.rst index 50e4856e2f..d46aae56cb 100644 --- a/docs/api/cassandra.rst +++ b/docs/api/cassandra.rst @@ -3,9 +3,34 @@ .. module:: cassandra +.. data:: __version_info__ + + The version of the driver in a tuple format + +.. data:: __version__ + + The version of the driver in a string format + .. autoclass:: ConsistencyLevel :members: +.. autoclass:: ProtocolVersion + :members: + +.. autoclass:: UserFunctionDescriptor + :members: + :inherited-members: + +.. autoclass:: UserAggregateDescriptor + :members: + :inherited-members: + +.. autoexception:: DriverException() + :members: + +.. autoexception:: RequestExecutionException() + :members: + .. autoexception:: Unavailable() :members: @@ -18,6 +43,27 @@ .. autoexception:: WriteTimeout() :members: +.. autoexception:: CoordinationFailure() + :members: + +.. autoexception:: ReadFailure() + :members: + +.. autoexception:: WriteFailure() + :members: + +.. autoexception:: FunctionFailure() + :members: + +.. autoexception:: RequestValidationException() + :members: + +.. autoexception:: ConfigurationException() + :members: + +.. autoexception:: AlreadyExists() + :members: + .. autoexception:: InvalidRequest() :members: @@ -26,3 +72,6 @@ .. autoexception:: AuthenticationFailed() :members: + +.. autoexception:: OperationTimedOut() + :members: diff --git a/docs/api/cassandra/auth.rst b/docs/api/cassandra/auth.rst new file mode 100644 index 0000000000..58c964cf89 --- /dev/null +++ b/docs/api/cassandra/auth.rst @@ -0,0 +1,22 @@ +``cassandra.auth`` - Authentication +=================================== + +.. module:: cassandra.auth + +.. autoclass:: AuthProvider + :members: + +.. autoclass:: Authenticator + :members: + +.. autoclass:: PlainTextAuthProvider + :members: + +.. autoclass:: PlainTextAuthenticator + :members: + +.. autoclass:: SaslAuthProvider + :members: + +.. autoclass:: SaslAuthenticator + :members: diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 17a39a10f7..a9a9d378a4 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -4,16 +4,225 @@ .. module:: cassandra.cluster .. autoclass:: Cluster ([contact_points=('127.0.0.1',)][, port=9042][, executor_threads=2], **attr_kwargs) + + .. autoattribute:: contact_points + + .. autoattribute:: port + + .. autoattribute:: cql_version + + .. autoattribute:: protocol_version + + .. autoattribute:: compression + + .. autoattribute:: auth_provider + + .. autoattribute:: load_balancing_policy + + .. autoattribute:: reconnection_policy + + .. autoattribute:: default_retry_policy + :annotation: = + + .. autoattribute:: conviction_policy_factory + + .. autoattribute:: address_translator + + .. autoattribute:: metrics_enabled + + .. autoattribute:: metrics + + .. autoattribute:: ssl_context + + .. autoattribute:: ssl_options + + .. autoattribute:: sockopts + + .. autoattribute:: max_schema_agreement_wait + + .. autoattribute:: metadata + + .. autoattribute:: connection_class + + .. autoattribute:: control_connection_timeout + + .. autoattribute:: idle_heartbeat_interval + + .. autoattribute:: idle_heartbeat_timeout + + .. autoattribute:: schema_event_refresh_window + + .. autoattribute:: topology_event_refresh_window + + .. autoattribute:: status_event_refresh_window + + .. autoattribute:: prepare_on_all_hosts + + .. autoattribute:: reprepare_on_up + + .. autoattribute:: connect_timeout + + .. autoattribute:: schema_metadata_enabled + :annotation: = True + + .. autoattribute:: token_metadata_enabled + :annotation: = True + + .. autoattribute:: timestamp_generator + + .. autoattribute:: endpoint_factory + + .. autoattribute:: cloud + + .. automethod:: connect + + .. automethod:: shutdown + + .. automethod:: register_user_type + + .. automethod:: register_listener + + .. automethod:: unregister_listener + + .. automethod:: add_execution_profile + + .. automethod:: set_max_requests_per_connection + + .. automethod:: get_max_requests_per_connection + + .. automethod:: set_min_requests_per_connection + + .. automethod:: get_min_requests_per_connection + + .. automethod:: get_core_connections_per_host + + .. automethod:: set_core_connections_per_host + + .. automethod:: get_max_connections_per_host + + .. automethod:: set_max_connections_per_host + + .. automethod:: get_control_connection_host + + .. automethod:: refresh_schema_metadata + + .. automethod:: refresh_keyspace_metadata + + .. automethod:: refresh_table_metadata + + .. automethod:: refresh_user_type_metadata + + .. automethod:: refresh_user_function_metadata + + .. automethod:: refresh_user_aggregate_metadata + + .. automethod:: refresh_nodes + + .. automethod:: set_meta_refresh_enabled + +.. autoclass:: ExecutionProfile (load_balancing_policy=, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=, speculative_execution_policy=None) :members: - :exclude-members: on_up, on_down, add_host, remove_host, connection_factory + :exclude-members: consistency_level -.. autoclass:: Session () + .. autoattribute:: consistency_level + :annotation: = LOCAL_ONE + +.. autoclass:: GraphExecutionProfile (load_balancing_policy=_NOT_SET, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=30.0, row_factory=None, graph_options=None, continuous_paging_options=_NOT_SET) + :members: + +.. autoclass:: GraphAnalyticsExecutionProfile (load_balancing_policy=None, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=3600. * 24. * 7., row_factory=None, graph_options=None) :members: - :exclude-members: on_up, on_down, on_add, on_remove, add_host, prepare_on_all_hosts, submit + +.. autodata:: EXEC_PROFILE_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT + :annotation: + +.. autoclass:: Session () + + .. autoattribute:: default_timeout + :annotation: = 10.0 + + .. autoattribute:: default_consistency_level + :annotation: = LOCAL_ONE + + .. autoattribute:: default_serial_consistency_level + :annotation: = None + + .. autoattribute:: row_factory + :annotation: = + + .. autoattribute:: default_fetch_size + + .. autoattribute:: use_client_timestamp + + .. autoattribute:: timestamp_generator + + .. autoattribute:: encoder + + .. autoattribute:: client_protocol_handler + + .. automethod:: execute(statement[, parameters][, timeout][, trace][, custom_payload][, paging_state][, host][, execute_as]) + + .. automethod:: execute_async(statement[, parameters][, trace][, custom_payload][, paging_state][, host][, execute_as]) + + .. automethod:: execute_graph(statement[, parameters][, trace][, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT][, execute_as]) + + .. automethod:: execute_graph_async(statement[, parameters][, trace][, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT][, execute_as]) + + .. automethod:: prepare(statement) + + .. automethod:: shutdown() + + .. automethod:: set_keyspace(keyspace) + + .. automethod:: get_execution_profile + + .. automethod:: execution_profile_clone_update + + .. automethod:: add_request_init_listener + + .. automethod:: remove_request_init_listener .. autoclass:: ResponseFuture () + + .. autoattribute:: query + + .. automethod:: result() + + .. automethod:: get_query_trace() + + .. automethod:: get_all_query_traces() + + .. autoattribute:: custom_payload() + + .. autoattribute:: is_schema_agreed + + .. autoattribute:: has_more_pages + + .. autoattribute:: warnings + + .. automethod:: start_fetching_next_page() + + .. automethod:: add_callback(fn, *args, **kwargs) + + .. automethod:: add_errback(fn, *args, **kwargs) + + .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None) + +.. autoclass:: ResultSet () :members: - :exclude-members: send_request + +.. autoexception:: QueryExhausted () .. autoexception:: NoHostAvailable () :members: + +.. autoexception:: UserTypeDoesNotExist () diff --git a/docs/api/cassandra/concurrent.rst b/docs/api/cassandra/concurrent.rst new file mode 100644 index 0000000000..f4bab6f048 --- /dev/null +++ b/docs/api/cassandra/concurrent.rst @@ -0,0 +1,8 @@ +``cassandra.concurrent`` - Utilities for Concurrent Statement Execution +======================================================================= + +.. module:: cassandra.concurrent + +.. autofunction:: execute_concurrent + +.. autofunction:: execute_concurrent_with_args diff --git a/docs/api/cassandra/connection.rst b/docs/api/cassandra/connection.rst index 8a21d57618..32cca590c0 100644 --- a/docs/api/cassandra/connection.rst +++ b/docs/api/cassandra/connection.rst @@ -4,5 +4,18 @@ .. module:: cassandra.connection .. autoexception:: ConnectionException () +.. autoexception:: ConnectionShutdown () .. autoexception:: ConnectionBusy () .. autoexception:: ProtocolError () + +.. autoclass:: EndPoint + :members: + +.. autoclass:: EndPointFactory + :members: + +.. autoclass:: SniEndPoint + +.. autoclass:: SniEndPointFactory + +.. autoclass:: UnixSocketEndPoint diff --git a/docs/api/cassandra/cqlengine/columns.rst b/docs/api/cassandra/cqlengine/columns.rst new file mode 100644 index 0000000000..d44be8adb8 --- /dev/null +++ b/docs/api/cassandra/cqlengine/columns.rst @@ -0,0 +1,89 @@ +``cassandra.cqlengine.columns`` - Column types for object mapping models +======================================================================== + +.. module:: cassandra.cqlengine.columns + +Columns +------- + +Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. +For a model to be valid it needs at least one primary key column and one non-primary key column. + +Just as in CQL, the order you define your columns in is important, and is the same order they are defined in on a model's corresponding table. + +Each column on your model definitions needs to be an instance of a Column class. + +.. autoclass:: Column(**kwargs) + + .. autoattribute:: primary_key + + .. autoattribute:: partition_key + + .. autoattribute:: index + + .. autoattribute:: custom_index + + .. autoattribute:: db_field + + .. autoattribute:: default + + .. autoattribute:: required + + .. autoattribute:: clustering_order + + .. autoattribute:: discriminator_column + + .. autoattribute:: static + +Column Types +------------ + +Columns of all types are initialized by passing :class:`.Column` attributes to the constructor by keyword. + +.. autoclass:: Ascii(**kwargs) + +.. autoclass:: BigInt(**kwargs) + +.. autoclass:: Blob(**kwargs) + +.. autoclass:: Bytes(**kwargs) + +.. autoclass:: Boolean(**kwargs) + +.. autoclass:: Counter + +.. autoclass:: Date(**kwargs) + +.. autoclass:: DateTime(**kwargs) + + .. autoattribute:: truncate_microseconds + +.. autoclass:: Decimal(**kwargs) + +.. autoclass:: Double(**kwargs) + +.. autoclass:: Float + +.. autoclass:: Integer(**kwargs) + +.. autoclass:: List + +.. autoclass:: Map + +.. autoclass:: Set + +.. autoclass:: SmallInt(**kwargs) + +.. autoclass:: Text + +.. autoclass:: Time(**kwargs) + +.. autoclass:: TimeUUID(**kwargs) + +.. autoclass:: TinyInt(**kwargs) + +.. autoclass:: UserDefinedType + +.. autoclass:: UUID(**kwargs) + +.. autoclass:: VarInt(**kwargs) diff --git a/docs/api/cassandra/cqlengine/connection.rst b/docs/api/cassandra/cqlengine/connection.rst new file mode 100644 index 0000000000..0f584fcca2 --- /dev/null +++ b/docs/api/cassandra/cqlengine/connection.rst @@ -0,0 +1,16 @@ +``cassandra.cqlengine.connection`` - Connection management for cqlengine +======================================================================== + +.. module:: cassandra.cqlengine.connection + +.. autofunction:: default + +.. autofunction:: set_session + +.. autofunction:: setup + +.. autofunction:: register_connection + +.. autofunction:: unregister_connection + +.. autofunction:: set_default_connection diff --git a/docs/api/cassandra/cqlengine/management.rst b/docs/api/cassandra/cqlengine/management.rst new file mode 100644 index 0000000000..fb483abc81 --- /dev/null +++ b/docs/api/cassandra/cqlengine/management.rst @@ -0,0 +1,19 @@ +``cassandra.cqlengine.management`` - Schema management for cqlengine +======================================================================== + +.. module:: cassandra.cqlengine.management + +A collection of functions for managing keyspace and table schema. + +.. autofunction:: create_keyspace_simple + +.. autofunction:: create_keyspace_network_topology + +.. autofunction:: drop_keyspace + +.. autofunction:: sync_table + +.. autofunction:: sync_type + +.. autofunction:: drop_table + diff --git a/docs/api/cassandra/cqlengine/models.rst b/docs/api/cassandra/cqlengine/models.rst new file mode 100644 index 0000000000..ee689a2b48 --- /dev/null +++ b/docs/api/cassandra/cqlengine/models.rst @@ -0,0 +1,197 @@ +``cassandra.cqlengine.models`` - Table models for object mapping +================================================================ + +.. module:: cassandra.cqlengine.models + +Model +----- +.. autoclass:: Model(\*\*kwargs) + + The initializer creates an instance of the model. Pass in keyword arguments for columns you've defined on the model. + + .. code-block:: python + + class Person(Model): + id = columns.UUID(primary_key=True) + first_name = columns.Text() + last_name = columns.Text() + + person = Person(first_name='Blake', last_name='Eggleston') + person.first_name #returns 'Blake' + person.last_name #returns 'Eggleston' + + Model attributes define how the model maps to tables in the database. These are class variables that should be set + when defining Model deriviatives. + + .. autoattribute:: __abstract__ + :annotation: = False + + .. autoattribute:: __table_name__ + + .. autoattribute:: __table_name_case_sensitive__ + + .. autoattribute:: __keyspace__ + + .. autoattribute:: __connection__ + + .. attribute:: __default_ttl__ + :annotation: = None + + Will be deprecated in release 4.0. You can set the default ttl by configuring the table ``__options__``. See :ref:`ttl-change` for more details. + + .. autoattribute:: __discriminator_value__ + + See :ref:`model_inheritance` for usage examples. + + Each table can have its own set of configuration options, including compaction. Unspecified, these default to sensible values in + the server. To override defaults, set options using the model ``__options__`` attribute, which allows options specified a dict. + + When a table is synced, it will be altered to match the options set on your table. + This means that if you are changing settings manually they will be changed back on resync. + + Do not use the options settings of cqlengine if you want to manage your compaction settings manually. + + See the `list of supported table properties for more information + `_. + + .. attribute:: __options__ + + For example: + + .. code-block:: python + + class User(Model): + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy', + 'sstable_size_in_mb': '64', + 'tombstone_threshold': '.2'}, + 'comment': 'User data stored here'} + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + or : + + .. code-block:: python + + class TimeData(Model): + __options__ = {'compaction': {'class': 'SizeTieredCompactionStrategy', + 'bucket_low': '.3', + 'bucket_high': '2', + 'min_threshold': '2', + 'max_threshold': '64', + 'tombstone_compaction_interval': '86400'}, + 'gc_grace_seconds': '0'} + + .. autoattribute:: __compute_routing_key__ + + + The base methods allow creating, storing, and querying modeled objects. + + .. automethod:: create + + .. method:: if_not_exists() + + Check the existence of an object before insertion. The existence of an + object is determined by its primary key(s). And please note using this flag + would incur performance cost. + + If the insertion isn't applied, a :class:`~cassandra.cqlengine.query.LWTException` is raised. + + .. code-block:: python + + try: + TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111') + except LWTException as e: + # handle failure case + print(e.existing # dict containing LWT result fields) + + This method is supported on Cassandra 2.0 or later. + + .. method:: if_exists() + + Check the existence of an object before an update or delete. The existence of an + object is determined by its primary key(s). And please note using this flag + would incur performance cost. + + If the update or delete isn't applied, a :class:`~cassandra.cqlengine.query.LWTException` is raised. + + .. code-block:: python + + try: + TestIfExistsModel.objects(id=id).if_exists().update(count=9, text='111111111111') + except LWTException as e: + # handle failure case + pass + + This method is supported on Cassandra 2.0 or later. + + .. automethod:: save + + .. automethod:: update + + .. method:: iff(**values) + + Checks to ensure that the values specified are correct on the Cassandra cluster. + Simply specify the column(s) and the expected value(s). As with if_not_exists, + this incurs a performance cost. + + If the insertion isn't applied, a :class:`~cassandra.cqlengine.query.LWTException` is raised. + + .. code-block:: python + + t = TestTransactionModel(text='some text', count=5) + try: + t.iff(count=5).update('other text') + except LWTException as e: + # handle failure case + print(e.existing # existing object) + + .. automethod:: get + + .. automethod:: filter + + .. automethod:: all + + .. automethod:: delete + + .. method:: batch(batch_object) + + Sets the batch object to run instance updates and inserts queries with. + + See :doc:`/cqlengine/batches` for usage examples + + .. automethod:: timeout + + .. method:: timestamp(timedelta_or_datetime) + + Sets the timestamp for the query + + .. method:: ttl(ttl_in_sec) + + Sets the ttl values to run instance updates and inserts queries with. + + .. method:: using(connection=None) + + Change the context on the fly of the model instance (keyspace, connection) + + .. automethod:: column_family_name + + Models also support dict-like access: + + .. method:: len(m) + + Returns the number of columns defined in the model + + .. method:: m[col_name] + + Returns the value of column ``col_name`` + + .. method:: m[col_name] = value + + Set ``m[col_name]`` to value + + .. automethod:: keys + + .. automethod:: values + + .. automethod:: items diff --git a/docs/api/cassandra/cqlengine/query.rst b/docs/api/cassandra/cqlengine/query.rst new file mode 100644 index 0000000000..ce8f764b6b --- /dev/null +++ b/docs/api/cassandra/cqlengine/query.rst @@ -0,0 +1,71 @@ +``cassandra.cqlengine.query`` - Query and filter model objects +================================================================= + +.. module:: cassandra.cqlengine.query + +QuerySet +-------- +QuerySet objects are typically obtained by calling :meth:`~.cassandra.cqlengine.models.Model.objects` on a model class. +The methods here are used to filter, order, and constrain results. + +.. autoclass:: ModelQuerySet + + .. automethod:: all + + .. automethod:: batch + + .. automethod:: consistency + + .. automethod:: count + + .. method:: len(queryset) + + Returns the number of rows matched by this query. This function uses :meth:`~.cassandra.cqlengine.query.ModelQuerySet.count` internally. + + *Note: This function executes a SELECT COUNT() and has a performance cost on large datasets* + + .. automethod:: distinct + + .. automethod:: filter + + .. automethod:: get + + .. automethod:: limit + + .. automethod:: fetch_size + + .. automethod:: if_not_exists + + .. automethod:: if_exists + + .. automethod:: order_by + + .. automethod:: allow_filtering + + .. automethod:: only + + .. automethod:: defer + + .. automethod:: timestamp + + .. automethod:: ttl + + .. automethod:: using + + .. _blind_updates: + + .. automethod:: update + +.. autoclass:: BatchQuery + :members: + + .. automethod:: add_query + .. automethod:: execute + +.. autoclass:: ContextQuery + +.. autoclass:: DoesNotExist + +.. autoclass:: MultipleObjectsReturned + +.. autoclass:: LWTException diff --git a/docs/api/cassandra/cqlengine/usertype.rst b/docs/api/cassandra/cqlengine/usertype.rst new file mode 100644 index 0000000000..ebed187da9 --- /dev/null +++ b/docs/api/cassandra/cqlengine/usertype.rst @@ -0,0 +1,10 @@ +``cassandra.cqlengine.usertype`` - Model classes for User Defined Types +======================================================================= + +.. module:: cassandra.cqlengine.usertype + +UserType +-------- +.. autoclass:: UserType + + .. autoattribute:: __type_name__ diff --git a/docs/api/cassandra/datastax/graph/fluent/index.rst b/docs/api/cassandra/datastax/graph/fluent/index.rst new file mode 100644 index 0000000000..5547e0fdd7 --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/index.rst @@ -0,0 +1,24 @@ +:mod:`cassandra.datastax.graph.fluent` +====================================== + +.. module:: cassandra.datastax.graph.fluent + +.. autoclass:: DseGraph + + .. autoattribute:: DSE_GRAPH_QUERY_LANGUAGE + + .. automethod:: create_execution_profile + + .. automethod:: query_from_traversal + + .. automethod:: traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, traversal_class=None) + + .. automethod:: batch(session=None, execution_profile=None) + +.. autoclass:: DSESessionRemoteGraphConnection(session[, graph_name, execution_profile]) + +.. autoclass:: BaseGraphRowFactory + +.. autoclass:: graph_traversal_row_factory + +.. autoclass:: graph_traversal_dse_object_row_factory diff --git a/docs/api/cassandra/datastax/graph/fluent/predicates.rst b/docs/api/cassandra/datastax/graph/fluent/predicates.rst new file mode 100644 index 0000000000..f6e86f6451 --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/predicates.rst @@ -0,0 +1,14 @@ +:mod:`cassandra.datastax.graph.fluent.predicates` +================================================= + +.. module:: cassandra.datastax.graph.fluent.predicates + + +.. autoclass:: Search + :members: + +.. autoclass:: CqlCollection + :members: + +.. autoclass:: Geo + :members: diff --git a/docs/api/cassandra/datastax/graph/fluent/query.rst b/docs/api/cassandra/datastax/graph/fluent/query.rst new file mode 100644 index 0000000000..3dd859f96e --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/query.rst @@ -0,0 +1,8 @@ +:mod:`cassandra.datastax.graph.fluent.query` +============================================ + +.. module:: cassandra.datastax.graph.fluent.query + + +.. autoclass:: TraversalBatch + :members: diff --git a/docs/api/cassandra/datastax/graph/index.rst b/docs/api/cassandra/datastax/graph/index.rst new file mode 100644 index 0000000000..dafd5f65fd --- /dev/null +++ b/docs/api/cassandra/datastax/graph/index.rst @@ -0,0 +1,121 @@ +``cassandra.datastax.graph`` - Graph Statements, Options, and Row Factories +=========================================================================== + +.. _api-datastax-graph: + +.. module:: cassandra.datastax.graph + +.. autofunction:: single_object_row_factory + +.. autofunction:: graph_result_row_factory + +.. autofunction:: graph_object_row_factory + +.. autofunction:: graph_graphson2_row_factory + +.. autofunction:: graph_graphson3_row_factory + +.. function:: to_int(value) + + Wraps a value to be explicitly serialized as a graphson Int. + +.. function:: to_bigint(value) + + Wraps a value to be explicitly serialized as a graphson Bigint. + +.. function:: to_smallint(value) + + Wraps a value to be explicitly serialized as a graphson Smallint. + +.. function:: to_float(value) + + Wraps a value to be explicitly serialized as a graphson Float. + +.. function:: to_double(value) + + Wraps a value to be explicitly serialized as a graphson Double. + +.. autoclass:: GraphProtocol + :members: + +.. autoclass:: GraphOptions + + .. autoattribute:: graph_name + + .. autoattribute:: graph_source + + .. autoattribute:: graph_language + + .. autoattribute:: graph_read_consistency_level + + .. autoattribute:: graph_write_consistency_level + + .. autoattribute:: is_default_source + + .. autoattribute:: is_analytics_source + + .. autoattribute:: is_graph_source + + .. automethod:: set_source_default + + .. automethod:: set_source_analytics + + .. automethod:: set_source_graph + + +.. autoclass:: SimpleGraphStatement + :members: + +.. autoclass:: Result + :members: + +.. autoclass:: Vertex + :members: + +.. autoclass:: VertexProperty + :members: + +.. autoclass:: Edge + :members: + +.. autoclass:: Path + :members: + +.. autoclass:: T + :members: + +.. autoclass:: GraphSON1Serializer + :members: + +.. autoclass:: GraphSON1Deserializer + + .. automethod:: deserialize_date + + .. automethod:: deserialize_timestamp + + .. automethod:: deserialize_time + + .. automethod:: deserialize_duration + + .. automethod:: deserialize_int + + .. automethod:: deserialize_bigint + + .. automethod:: deserialize_double + + .. automethod:: deserialize_float + + .. automethod:: deserialize_uuid + + .. automethod:: deserialize_blob + + .. automethod:: deserialize_decimal + + .. automethod:: deserialize_point + + .. automethod:: deserialize_linestring + + .. automethod:: deserialize_polygon + +.. autoclass:: GraphSON2Reader + :members: diff --git a/docs/api/cassandra/decoder.rst b/docs/api/cassandra/decoder.rst new file mode 100644 index 0000000000..e213cc6d74 --- /dev/null +++ b/docs/api/cassandra/decoder.rst @@ -0,0 +1,20 @@ +``cassandra.decoder`` - Data Return Formats +=========================================== + +.. module:: cassandra.decoder + +.. function:: tuple_factory + + **Deprecated in 2.0.0.** Use :meth:`cassandra.query.tuple_factory` + +.. function:: named_tuple_factory + + **Deprecated in 2.0.0.** Use :meth:`cassandra.query.named_tuple_factory` + +.. function:: dict_factory + + **Deprecated in 2.0.0.** Use :meth:`cassandra.query.dict_factory` + +.. function:: ordered_dict_factory + + **Deprecated in 2.0.0.** Use :meth:`cassandra.query.ordered_dict_factory` diff --git a/docs/api/cassandra/encoder.rst b/docs/api/cassandra/encoder.rst new file mode 100644 index 0000000000..de3b180510 --- /dev/null +++ b/docs/api/cassandra/encoder.rst @@ -0,0 +1,36 @@ +``cassandra.encoder`` - Encoders for non-prepared Statements +============================================================ + +.. module:: cassandra.encoder + +.. autoclass:: Encoder () + + .. autoattribute:: cassandra.encoder.Encoder.mapping + + .. automethod:: cassandra.encoder.Encoder.cql_encode_none () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_object () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_all_types () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_sequence () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_str () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_unicode () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_bytes () + + Converts strings, buffers, and bytearrays into CQL blob literals. + + .. automethod:: cassandra.encoder.Encoder.cql_encode_datetime () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_date () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_map_collection () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_list_collection () + + .. automethod:: cassandra.encoder.Encoder.cql_encode_set_collection () + + .. automethod:: cql_encode_tuple () diff --git a/docs/api/cassandra/graph.rst b/docs/api/cassandra/graph.rst new file mode 100644 index 0000000000..43ddd3086c --- /dev/null +++ b/docs/api/cassandra/graph.rst @@ -0,0 +1,121 @@ +``cassandra.graph`` - Graph Statements, Options, and Row Factories +================================================================== + +.. note:: This module is only for backward compatibility for dse-driver users. Consider using :ref:`cassandra.datastax.graph `. + +.. module:: cassandra.graph + +.. autofunction:: single_object_row_factory + +.. autofunction:: graph_result_row_factory + +.. autofunction:: graph_object_row_factory + +.. autofunction:: graph_graphson2_row_factory + +.. autofunction:: graph_graphson3_row_factory + +.. function:: to_int(value) + + Wraps a value to be explicitly serialized as a graphson Int. + +.. function:: to_bigint(value) + + Wraps a value to be explicitly serialized as a graphson Bigint. + +.. function:: to_smallint(value) + + Wraps a value to be explicitly serialized as a graphson Smallint. + +.. function:: to_float(value) + + Wraps a value to be explicitly serialized as a graphson Float. + +.. function:: to_double(value) + + Wraps a value to be explicitly serialized as a graphson Double. + +.. autoclass:: GraphProtocol + :members: + +.. autoclass:: GraphOptions + + .. autoattribute:: graph_name + + .. autoattribute:: graph_source + + .. autoattribute:: graph_language + + .. autoattribute:: graph_read_consistency_level + + .. autoattribute:: graph_write_consistency_level + + .. autoattribute:: is_default_source + + .. autoattribute:: is_analytics_source + + .. autoattribute:: is_graph_source + + .. automethod:: set_source_default + + .. automethod:: set_source_analytics + + .. automethod:: set_source_graph + + +.. autoclass:: SimpleGraphStatement + :members: + +.. autoclass:: Result + :members: + +.. autoclass:: Vertex + :members: + +.. autoclass:: VertexProperty + :members: + +.. autoclass:: Edge + :members: + +.. autoclass:: Path + :members: + +.. autoclass:: GraphSON1Serializer + :members: + +.. autoclass:: GraphSON1Deserializer + + .. automethod:: deserialize_date + + .. automethod:: deserialize_timestamp + + .. automethod:: deserialize_time + + .. automethod:: deserialize_duration + + .. automethod:: deserialize_int + + .. automethod:: deserialize_bigint + + .. automethod:: deserialize_double + + .. automethod:: deserialize_float + + .. automethod:: deserialize_uuid + + .. automethod:: deserialize_blob + + .. automethod:: deserialize_decimal + + .. automethod:: deserialize_point + + .. automethod:: deserialize_linestring + + .. automethod:: deserialize_polygon + +.. autoclass:: GraphSON2Reader + :members: + +.. autoclass:: GraphSON3Reader + :members: diff --git a/docs/api/cassandra/io/asyncioreactor.rst b/docs/api/cassandra/io/asyncioreactor.rst new file mode 100644 index 0000000000..38ae63ca7f --- /dev/null +++ b/docs/api/cassandra/io/asyncioreactor.rst @@ -0,0 +1,7 @@ +``cassandra.io.asyncioreactor`` - ``asyncio`` Event Loop +===================================================================== + +.. module:: cassandra.io.asyncioreactor + +.. autoclass:: AsyncioConnection + :members: diff --git a/docs/api/cassandra/io/eventletreactor.rst b/docs/api/cassandra/io/eventletreactor.rst new file mode 100644 index 0000000000..1ba742c7e9 --- /dev/null +++ b/docs/api/cassandra/io/eventletreactor.rst @@ -0,0 +1,7 @@ +``cassandra.io.eventletreactor`` - ``eventlet``-compatible Connection +===================================================================== + +.. module:: cassandra.io.eventletreactor + +.. autoclass:: EventletConnection + :members: diff --git a/docs/api/cassandra/io/geventreactor.rst b/docs/api/cassandra/io/geventreactor.rst new file mode 100644 index 0000000000..603affe140 --- /dev/null +++ b/docs/api/cassandra/io/geventreactor.rst @@ -0,0 +1,7 @@ +``cassandra.io.geventreactor`` - ``gevent``-compatible Event Loop +================================================================= + +.. module:: cassandra.io.geventreactor + +.. autoclass:: GeventConnection + :members: diff --git a/docs/api/cassandra/io/twistedreactor.rst b/docs/api/cassandra/io/twistedreactor.rst new file mode 100644 index 0000000000..24e93bd432 --- /dev/null +++ b/docs/api/cassandra/io/twistedreactor.rst @@ -0,0 +1,9 @@ +``cassandra.io.twistedreactor`` - Twisted Event Loop +==================================================== + +.. module:: cassandra.io.twistedreactor + +.. class:: TwistedConnection + + An implementation of :class:`~cassandra.io.connection.Connection` that uses + Twisted's reactor as its event loop. diff --git a/docs/api/cassandra/metadata.rst b/docs/api/cassandra/metadata.rst index 7400b0f82b..91fe39fd99 100644 --- a/docs/api/cassandra/metadata.rst +++ b/docs/api/cassandra/metadata.rst @@ -3,9 +3,18 @@ .. module:: cassandra.metadata +.. autodata:: cql_keywords + :annotation: + +.. autodata:: cql_keywords_unreserved + :annotation: + +.. autodata:: cql_keywords_reserved + :annotation: + .. autoclass:: Metadata () :members: - :exclude-members: rebuild_schema, rebuild_token_map, add_host, remove_host, get_host + :exclude-members: rebuild_schema, rebuild_token_map, add_host, remove_host Schemas ------- @@ -13,15 +22,39 @@ Schemas .. autoclass:: KeyspaceMetadata () :members: +.. autoclass:: UserType () + :members: + +.. autoclass:: Function () + :members: + +.. autoclass:: Aggregate () + :members: + .. autoclass:: TableMetadata () :members: +.. autoclass:: TableMetadataV3 () + :members: + +.. autoclass:: TableMetadataDSE68 () + :members: + .. autoclass:: ColumnMetadata () :members: .. autoclass:: IndexMetadata () :members: +.. autoclass:: MaterializedViewMetadata () + :members: + +.. autoclass:: VertexMetadata () + :members: + +.. autoclass:: EdgeMetadata () + :members: + Tokens and Ring Topology ------------------------ @@ -39,3 +72,21 @@ Tokens and Ring Topology .. autoclass:: BytesToken :members: + +.. autoclass:: ReplicationStrategy + :members: + +.. autoclass:: ReplicationFactor + :members: + :exclude-members: create + +.. autoclass:: SimpleStrategy + :members: + +.. autoclass:: NetworkTopologyStrategy + :members: + +.. autoclass:: LocalStrategy + :members: + +.. autofunction:: group_keys_by_replica diff --git a/docs/api/cassandra/metrics.rst b/docs/api/cassandra/metrics.rst new file mode 100644 index 0000000000..0df7f8b5b9 --- /dev/null +++ b/docs/api/cassandra/metrics.rst @@ -0,0 +1,7 @@ +``cassandra.metrics`` - Performance Metrics +=========================================== + +.. module:: cassandra.metrics + +.. autoclass:: cassandra.metrics.Metrics () + :members: diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 833436a94c..387b19ed95 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -18,9 +18,38 @@ Load Balancing .. autoclass:: DCAwareRoundRobinPolicy :members: +.. autoclass:: WhiteListRoundRobinPolicy + :members: + .. autoclass:: TokenAwarePolicy :members: +.. autoclass:: HostFilterPolicy + + .. we document these methods manually so we can specify a param to predicate + + .. automethod:: predicate(host) + .. automethod:: distance + .. automethod:: make_query_plan + +.. autoclass:: DefaultLoadBalancingPolicy + :members: + +.. autoclass:: DSELoadBalancingPolicy + :members: + +Translating Server Node Addresses +--------------------------------- + +.. autoclass:: AddressTranslator + :members: + +.. autoclass:: IdentityTranslator + :members: + +.. autoclass:: EC2MultiRegionTranslator + :members: + Marking Hosts Up or Down ------------------------ @@ -56,3 +85,12 @@ Retrying Failed Operations .. autoclass:: DowngradingConsistencyRetryPolicy :members: + +Retrying Idempotent Operations +------------------------------ + +.. autoclass:: SpeculativeExecutionPolicy + :members: + +.. autoclass:: ConstantSpeculativeExecutionPolicy + :members: diff --git a/docs/api/cassandra/pool.rst b/docs/api/cassandra/pool.rst index c0f5418502..b14d30e19c 100644 --- a/docs/api/cassandra/pool.rst +++ b/docs/api/cassandra/pool.rst @@ -4,8 +4,8 @@ .. automodule:: cassandra.pool .. autoclass:: Host () - :members: - :exclude-members: set_location_info, get_and_set_reconnection_handler + :members: + :exclude-members: set_location_info, get_and_set_reconnection_handler -.. autoclass:: HealthMonitor () - :members: +.. autoexception:: NoConnectionsAvailable + :members: diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst new file mode 100644 index 0000000000..f615ab1a70 --- /dev/null +++ b/docs/api/cassandra/protocol.rst @@ -0,0 +1,55 @@ +``cassandra.protocol`` - Protocol Features +===================================================================== + +.. module:: cassandra.protocol + +.. _custom_payload: + +Custom Payloads +--------------- +Native protocol version 4+ allows for a custom payload to be sent between clients +and custom query handlers. The payload is specified as a string:binary_type dict +holding custom key/value pairs. + +By default these are ignored by the server. They can be useful for servers implementing +a custom QueryHandler. + +See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`. + +.. autoclass:: _ProtocolHandler + + .. autoattribute:: message_types_by_opcode + :annotation: = {default mapping} + + .. automethod:: encode_message + + .. automethod:: decode_message + +.. _faster_deser: + +Faster Deserialization +---------------------- +When python-driver is compiled with Cython, it uses a Cython-based deserialization path +to deserialize messages. By default, the driver will use a Cython-based parser that returns +lists of rows similar to the pure-Python version. In addition, there are two additional +ProtocolHandler classes that can be used to deserialize response messages: ``LazyProtocolHandler`` +and ``NumpyProtocolHandler``. They can be used as follows: + +.. code:: python + + from cassandra.protocol import NumpyProtocolHandler, LazyProtocolHandler + from cassandra.query import tuple_factory + s.client_protocol_handler = LazyProtocolHandler # for a result iterator + s.row_factory = tuple_factory #required for Numpy results + s.client_protocol_handler = NumpyProtocolHandler # for a dict of NumPy arrays as result + +These protocol handlers comprise different parsers, and return results as described below: + +- ProtocolHandler: this default implementation is a drop-in replacement for the pure-Python version. + The rows are all parsed upfront, before results are returned. + +- LazyProtocolHandler: near drop-in replacement for the above, except that it returns an iterator over rows, + lazily decoded into the default row format (this is more efficient since all decoded results are not materialized at once) + +- NumpyProtocolHander: deserializes results directly into NumPy arrays. This facilitates efficient integration with + analysis toolkits such as Pandas. diff --git a/docs/api/cassandra/query.rst b/docs/api/cassandra/query.rst index 39f700b90c..fcd79739b9 100644 --- a/docs/api/cassandra/query.rst +++ b/docs/api/cassandra/query.rst @@ -1,10 +1,15 @@ -``cassandra.query`` - Prepared Statements and Query Policies -============================================================ +``cassandra.query`` - Prepared Statements, Batch Statements, Tracing, and Row Factories +======================================================================================= .. module:: cassandra.query -.. autoclass:: Query - :members: +.. autofunction:: tuple_factory + +.. autofunction:: named_tuple_factory + +.. autofunction:: dict_factory + +.. autofunction:: ordered_dict_factory .. autoclass:: SimpleStatement :members: @@ -15,11 +20,40 @@ .. autoclass:: BoundStatement :members: -.. autoclass:: ValueSequence +.. autoclass:: Statement () :members: +.. autodata:: UNSET_VALUE + :annotation: + +.. autoclass:: BatchStatement (batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None) + :members: + +.. autoclass:: BatchType () + + .. autoattribute:: LOGGED + + .. autoattribute:: UNLOGGED + + .. autoattribute:: COUNTER + +.. autoclass:: cassandra.query.ValueSequence + + A wrapper class that is used to specify that a sequence of values should + be treated as a CQL list of values instead of a single column collection when used + as part of the `parameters` argument for :meth:`.Session.execute()`. + + This is typically needed when supplying a list of keys to select. + For example:: + + >>> my_user_ids = ('alice', 'bob', 'charles') + >>> query = "SELECT * FROM users WHERE user_id IN %s" + >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) + .. autoclass:: QueryTrace () :members: .. autoclass:: TraceEvent () :members: + +.. autoexception:: TraceUnavailable diff --git a/docs/api/cassandra/timestamps.rst b/docs/api/cassandra/timestamps.rst new file mode 100644 index 0000000000..7c7f534aea --- /dev/null +++ b/docs/api/cassandra/timestamps.rst @@ -0,0 +1,14 @@ +``cassandra.timestamps`` - Timestamp Generation +============================================= + +.. module:: cassandra.timestamps + +.. autoclass:: MonotonicTimestampGenerator (warn_on_drift=True, warning_threshold=0, warning_interval=0) + + .. autoattribute:: warn_on_drift + + .. autoattribute:: warning_threshold + + .. autoattribute:: warning_interval + + .. automethod:: _next_timestamp diff --git a/docs/api/cassandra/util.rst b/docs/api/cassandra/util.rst new file mode 100644 index 0000000000..848d4d5fc2 --- /dev/null +++ b/docs/api/cassandra/util.rst @@ -0,0 +1,5 @@ +``cassandra.util`` - Utilities +=================================== + +.. automodule:: cassandra.util + :members: diff --git a/docs/api/index.rst b/docs/api/index.rst index f70dfb5401..9e778d508c 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -1,18 +1,54 @@ API Documentation ================= -Cassandra Modules ------------------ - +Core Driver +----------- .. toctree:: :maxdepth: 2 cassandra cassandra/cluster cassandra/policies + cassandra/auth + cassandra/graph cassandra/metadata + cassandra/metrics cassandra/query cassandra/pool + cassandra/protocol + cassandra/encoder + cassandra/decoder + cassandra/concurrent cassandra/connection + cassandra/util + cassandra/timestamps + cassandra/io/asyncioreactor cassandra/io/asyncorereactor + cassandra/io/eventletreactor cassandra/io/libevreactor + cassandra/io/geventreactor + cassandra/io/twistedreactor + +.. _om_api: + +Object Mapper +------------- +.. toctree:: + :maxdepth: 1 + + cassandra/cqlengine/models + cassandra/cqlengine/columns + cassandra/cqlengine/query + cassandra/cqlengine/connection + cassandra/cqlengine/management + cassandra/cqlengine/usertype + +DataStax Graph +-------------- +.. toctree:: + :maxdepth: 1 + + cassandra/datastax/graph/index + cassandra/datastax/graph/fluent/index + cassandra/datastax/graph/fluent/query + cassandra/datastax/graph/fluent/predicates diff --git a/docs/classic_graph.rst b/docs/classic_graph.rst new file mode 100644 index 0000000000..ef68c86359 --- /dev/null +++ b/docs/classic_graph.rst @@ -0,0 +1,299 @@ +DataStax Classic Graph Queries +============================== + +Getting Started +~~~~~~~~~~~~~~~ + +First, we need to create a graph in the system. To access the system API, we +use the system execution profile :: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + + cluster = Cluster() + session = cluster.connect() + + graph_name = 'movies' + session.execute_graph("system.graph(name).ifNotExists().engine(Classic).create()", {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + + +To execute requests on our newly created graph, we need to setup an execution +profile. Additionally, we also need to set the schema_mode to `development` +for the schema creation:: + + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions + + graph_name = 'movies' + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name=graph_name)) + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + + session.execute_graph("schema.config().option('graph.schema_mode').set('development')") + + +We are ready to configure our graph schema. We will create a simple one for movies:: + + # properties are used to define a vertex + properties = """ + schema.propertyKey("genreId").Text().create(); + schema.propertyKey("personId").Text().create(); + schema.propertyKey("movieId").Text().create(); + schema.propertyKey("name").Text().create(); + schema.propertyKey("title").Text().create(); + schema.propertyKey("year").Int().create(); + schema.propertyKey("country").Text().create(); + """ + + session.execute_graph(properties) # we can execute multiple statements in a single request + + # A Vertex represents a "thing" in the world. + vertices = """ + schema.vertexLabel("genre").properties("genreId","name").create(); + schema.vertexLabel("person").properties("personId","name").create(); + schema.vertexLabel("movie").properties("movieId","title","year","country").create(); + """ + + session.execute_graph(vertices) + + # An edge represents a relationship between two vertices + edges = """ + schema.edgeLabel("belongsTo").single().connection("movie","genre").create(); + schema.edgeLabel("actor").connection("movie","person").create(); + """ + + session.execute_graph(edges) + + # Indexes to execute graph requests efficiently + indexes = """ + schema.vertexLabel("genre").index("genresById").materialized().by("genreId").add(); + schema.vertexLabel("genre").index("genresByName").materialized().by("name").add(); + schema.vertexLabel("person").index("personsById").materialized().by("personId").add(); + schema.vertexLabel("person").index("personsByName").materialized().by("name").add(); + schema.vertexLabel("movie").index("moviesById").materialized().by("movieId").add(); + schema.vertexLabel("movie").index("moviesByTitle").materialized().by("title").add(); + schema.vertexLabel("movie").index("moviesByYear").secondary().by("year").add(); + """ + +Next, we'll add some data:: + + session.execute_graph(""" + g.addV('genre').property('genreId', 1).property('name', 'Action').next(); + g.addV('genre').property('genreId', 2).property('name', 'Drama').next(); + g.addV('genre').property('genreId', 3).property('name', 'Comedy').next(); + g.addV('genre').property('genreId', 4).property('name', 'Horror').next(); + """) + + session.execute_graph(""" + g.addV('person').property('personId', 1).property('name', 'Mark Wahlberg').next(); + g.addV('person').property('personId', 2).property('name', 'Leonardo DiCaprio').next(); + g.addV('person').property('personId', 3).property('name', 'Iggy Pop').next(); + """) + + session.execute_graph(""" + g.addV('movie').property('movieId', 1).property('title', 'The Happening'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 2).property('title', 'The Italian Job'). + property('year', 2003).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 3).property('title', 'Revolutionary Road'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 4).property('title', 'The Man in the Iron Mask'). + property('year', 1998).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 5).property('title', 'Dead Man'). + property('year', 1995).property('country', 'United States').next(); + """) + +Now that our genre, actor and movie vertices are added, we'll create the relationships (edges) between them:: + + session.execute_graph(""" + genre_horror = g.V().hasLabel('genre').has('name', 'Horror').next(); + genre_drama = g.V().hasLabel('genre').has('name', 'Drama').next(); + genre_action = g.V().hasLabel('genre').has('name', 'Action').next(); + + leo = g.V().hasLabel('person').has('name', 'Leonardo DiCaprio').next(); + mark = g.V().hasLabel('person').has('name', 'Mark Wahlberg').next(); + iggy = g.V().hasLabel('person').has('name', 'Iggy Pop').next(); + + the_happening = g.V().hasLabel('movie').has('title', 'The Happening').next(); + the_italian_job = g.V().hasLabel('movie').has('title', 'The Italian Job').next(); + rev_road = g.V().hasLabel('movie').has('title', 'Revolutionary Road').next(); + man_mask = g.V().hasLabel('movie').has('title', 'The Man in the Iron Mask').next(); + dead_man = g.V().hasLabel('movie').has('title', 'Dead Man').next(); + + the_happening.addEdge('belongsTo', genre_horror); + the_italian_job.addEdge('belongsTo', genre_action); + rev_road.addEdge('belongsTo', genre_drama); + man_mask.addEdge('belongsTo', genre_drama); + man_mask.addEdge('belongsTo', genre_action); + dead_man.addEdge('belongsTo', genre_drama); + + the_happening.addEdge('actor', mark); + the_italian_job.addEdge('actor', mark); + rev_road.addEdge('actor', leo); + man_mask.addEdge('actor', leo); + dead_man.addEdge('actor', iggy); + """) + +We are all set. You can now query your graph. Here are some examples:: + + # Find all movies of the genre Drama + for r in session.execute_graph(""" + g.V().has('genre', 'name', 'Drama').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of the same genre than the movie 'Dead Man' + for r in session.execute_graph(""" + g.V().has('movie', 'title', 'Dead Man').out('belongsTo').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of Mark Wahlberg + for r in session.execute_graph(""" + g.V().has('person', 'name', 'Mark Wahlberg').in('actor').valueMap();"""): + print(r) + +To see a more graph examples, see `DataStax Graph Examples `_. + +Graph Types +~~~~~~~~~~~ + +Here are the supported graph types with their python representations: + +========== ================ +DSE Graph Python +========== ================ +boolean bool +bigint long, int (PY3) +int int +smallint int +varint int +float float +double double +uuid uuid.UUID +Decimal Decimal +inet str +timestamp datetime.datetime +date datetime.date +time datetime.time +duration datetime.timedelta +point Point +linestring LineString +polygon Polygon +blob bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +========== ================ + +Graph Row Factory +~~~~~~~~~~~~~~~~~ + +By default (with :class:`.GraphExecutionProfile.row_factory` set to :func:`.graph.graph_object_row_factory`), known graph result +types are unpacked and returned as specialized types (:class:`.Vertex`, :class:`.Edge`). If the result is not one of these +types, a :class:`.graph.Result` is returned, containing the graph result parsed from JSON and removed from its outer dict. +The class has some accessor convenience methods for accessing top-level properties by name (`type`, `properties` above), +or lists by index:: + + # dicts with `__getattr__` or `__getitem__` + result = session.execute_graph("[[key_str: 'value', key_int: 3]]", execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0] # Using system exec just because there is no graph defined + result # dse.graph.Result({u'key_str': u'value', u'key_int': 3}) + result.value # {u'key_int': 3, u'key_str': u'value'} (dict) + result.key_str # u'value' + result.key_int # 3 + result['key_str'] # u'value' + result['key_int'] # 3 + + # lists with `__getitem__` + result = session.execute_graph('[[0, 1, 2]]', execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0] + result # dse.graph.Result([0, 1, 2]) + result.value # [0, 1, 2] (list) + result[1] # 1 (list[1]) + +You can use a different row factory by setting :attr:`.Session.default_graph_row_factory` or passing it to +:meth:`.Session.execute_graph`. For example, :func:`.graph.single_object_row_factory` returns the JSON result string`, +unparsed. :func:`.graph.graph_result_row_factory` returns parsed, but unmodified results (such that all metadata is retained, +unlike :func:`.graph.graph_object_row_factory`, which sheds some as attributes and properties are unpacked). These results +also provide convenience methods for converting to known types (:meth:`~.Result.as_vertex`, :meth:`~.Result.as_edge`, :meth:`~.Result.as_path`). + +Vertex and Edge properties are never unpacked since their types are unknown. If you know your graph schema and want to +deserialize properties, use the :class:`.GraphSON1Deserializer`. It provides convenient methods to deserialize by types (e.g. +deserialize_date, deserialize_uuid, deserialize_polygon etc.) Example:: + + # ... + from cassandra.graph import GraphSON1Deserializer + + row = session.execute_graph("g.V().toList()")[0] + value = row.properties['my_property_key'][0].value # accessing the VertexProperty value + value = GraphSON1Deserializer.deserialize_timestamp(value) + + print(value) # 2017-06-26 08:27:05 + print(type(value)) # + + +Named Parameters +~~~~~~~~~~~~~~~~ + +Named parameters are passed in a dict to :meth:`.cluster.Session.execute_graph`:: + + result_set = session.execute_graph('[a, b]', {'a': 1, 'b': 2}, execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + [r.value for r in result_set] # [1, 2] + +All python types listed in `Graph Types`_ can be passed as named parameters and will be serialized +automatically to their graph representation: + +Example:: + + session.execute_graph(""" + g.addV('person'). + property('name', text_value). + property('age', integer_value). + property('birthday', timestamp_value). + property('house_yard', polygon_value).toList() + """, { + 'text_value': 'Mike Smith', + 'integer_value': 34, + 'timestamp_value': datetime.datetime(1967, 12, 30), + 'polygon_value': Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + }) + + +As with all Execution Profile parameters, graph options can be set in the cluster default (as shown in the first example) +or specified per execution:: + + ep = session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, + graph_options=GraphOptions(graph_name='something-else')) + session.execute_graph(statement, execution_profile=ep) + +Using GraphSON2 Protocol +~~~~~~~~~~~~~~~~~~~~~~~~ + +The default graph protocol used is GraphSON1. However GraphSON1 may +cause problems of type conversion happening during the serialization +of the query to the DSE Graph server, or the deserialization of the +responses back from a string Gremlin query. GraphSON2 offers better +support for the complex data types handled by DSE Graph. + +DSE >=5.0.4 now offers the possibility to use the GraphSON2 protocol +for graph queries. Enabling GraphSON2 can be done by `changing the +graph protocol of the execution profile` and `setting the graphson2 row factory`:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions, GraphProtocol, graph_graphson2_row_factory + + # Create a GraphSON2 execution profile + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name='types', + graph_protocol=GraphProtocol.GRAPHSON_2_0), + row_factory=graph_graphson2_row_factory) + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + session.execute_graph(...) + +Using GraphSON2, all properties will be automatically deserialized to +its Python representation. Note that it may bring significant +behavioral change at runtime. + +It is generally recommended to switch to GraphSON2 as it brings more +consistent support for complex data types in the Graph driver and will +be activated by default in the next major version (Python dse-driver +driver 3.0). diff --git a/docs/cloud.rst b/docs/cloud.rst new file mode 100644 index 0000000000..3230720ec9 --- /dev/null +++ b/docs/cloud.rst @@ -0,0 +1,105 @@ +Cloud +----- +Connecting +========== +To connect to a DataStax Astra cluster: + +1. Download the secure connect bundle from your Astra account. +2. Connect to your cluster with + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + auth_provider = PlainTextAuthProvider(username='user', password='pass') + cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider) + session = cluster.connect() + +Cloud Config Options +==================== + +use_default_tempdir ++++++++++++++++++++ + +The secure connect bundle needs to be extracted to load the certificates into the SSLContext. +By default, the zip location is used as the base dir for the extraction. In some environments, +the zip location file system is read-only (e.g Azure Function). With *use_default_tempdir* set to *True*, +the default temporary directory of the system will be used as base dir. + +.. code:: python + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + 'use_default_tempdir': True + } + ... + +connect_timeout ++++++++++++++++++++ + +As part of the process of connecting to Astra the Python driver will query a service to retrieve +current information about your cluster. You can control the connection timeout for this operation +using *connect_timeout*. If you observe errors in `read_metadata_info` you might consider increasing +this parameter. This timeout is specified in seconds. + +.. code:: python + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + 'connect_timeout': 120 + } + ... + +Astra Differences +================== +In most circumstances, the client code for interacting with an Astra cluster will be the same as interacting with any other Cassandra cluster. The exceptions being: + +* A cloud configuration must be passed to a :class:`~.Cluster` instance via the `cloud` attribute (as demonstrated above). +* An SSL connection will be established automatically. Manual SSL configuration is not allowed, and using `ssl_context` or `ssl_options` will result in an exception. +* A :class:`~.Cluster`'s `contact_points` attribute should not be used. The cloud config contains all of the necessary contact information. +* If a consistency level is not specified for an execution profile or query, then :attr:`.ConsistencyLevel.LOCAL_QUORUM` will be used as the default. + + +Limitations +=========== + +Event loops +^^^^^^^^^^^ +Evenlet isn't yet supported for python 3.7+ due to an `issue in Eventlet `_. + + +CqlEngine +========= + +When using the object mapper, you can configure cqlengine with :func:`~.cqlengine.connection.set_session`: + +.. code:: python + + from cassandra.cqlengine import connection + ... + + c = Cluster(cloud={'secure_connect_bundle':'/path/to/secure-connect-test.zip'}, + auth_provider=PlainTextAuthProvider('user', 'pass')) + s = c.connect('myastrakeyspace') + connection.set_session(s) + ... + +If you are using some third-party libraries (flask, django, etc.), you might not be able to change the +configuration mechanism. For this reason, the `hosts` argument of the default +:func:`~.cqlengine.connection.setup` function will be ignored if a `cloud` config is provided: + +.. code:: python + + from cassandra.cqlengine import connection + ... + + connection.setup( + None, # or anything else + "myastrakeyspace", cloud={ + 'secure_connect_bundle':'/path/to/secure-connect-test.zip' + }, + auth_provider=PlainTextAuthProvider('user', 'pass')) diff --git a/docs/column_encryption.rst b/docs/column_encryption.rst new file mode 100644 index 0000000000..ab67ef16d0 --- /dev/null +++ b/docs/column_encryption.rst @@ -0,0 +1,101 @@ +Column Encryption +================= + +Overview +-------- +Support for client-side encryption of data was added in version 3.27.0 of the Python driver. When using +this feature data will be encrypted on-the-fly according to a specified :class:`~.ColumnEncryptionPolicy` +instance. This policy is also used to decrypt data in returned rows. If a prepared statement is used +this decryption is transparent to the user; retrieved data will be decrypted and converted into the original +type (according to definitions in the encryption policy). Support for simple (i.e. non-prepared) queries is +also available, although in this case values must be manually encrypted and/or decrypted. The +:class:`~.ColumnEncryptionPolicy` instance provides methods to assist with these operations. + +Client-side encryption and decryption should work against all versions of Cassandra and DSE. It does not +utilize any server-side functionality to do its work. + +WARNING: Encryption format changes in 3.28.0 +------------------------------------------------ +Python driver 3.28.0 introduces a new encryption format for data written by :class:`~.AES256ColumnEncryptionPolicy`. +As a result, any encrypted data written by Python driver 3.27.0 will **NOT** be readable. +If you upgraded from 3.27.0, you should re-encrypt your data with 3.28.0. + +Configuration +------------- +Client-side encryption is enabled by creating an instance of a subclass of :class:`~.ColumnEncryptionPolicy` +and adding information about columns to be encrypted to it. This policy is then supplied to :class:`~.Cluster` +when it's created. + +.. code-block:: python + + import os + + from cassandra.policies import ColDesc + from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, AES256_KEY_SIZE_BYTES + + key = os.urandom(AES256_KEY_SIZE_BYTES) + cl_policy = AES256ColumnEncryptionPolicy() + col_desc = ColDesc('ks1','table1','column1') + cql_type = "int" + cl_policy.add_column(col_desc, key, cql_type) + cluster = Cluster(column_encryption_policy=cl_policy) + +:class:`~.AES256ColumnEncryptionPolicy` is a subclass of :class:`~.ColumnEncryptionPolicy` which provides +encryption and decryption via AES-256. This class is currently the only available column encryption policy +implementation, although users can certainly implement their own by subclassing :class:`~.ColumnEncryptionPolicy`. + +:class:`~.ColDesc` is a named tuple which uniquely identifies a column in a given keyspace and table. When we +have this tuple, the encryption key and the CQL type contained by this column we can add the column to the policy +using :func:`~.ColumnEncryptionPolicy.add_column`. Once we have added all column definitions to the policy we +pass it along to the cluster. + +The CQL type for the column only has meaning at the client; it is never sent to Cassandra. The encryption key +is also never sent to the server; all the server ever sees are random bytes reflecting the encrypted data. As a +result all columns containing client-side encrypted values should be declared with the CQL type "blob" at the +Cassandra server. + +Usage +----- + +Encryption +^^^^^^^^^^ +Client-side encryption shines most when used with prepared statements. A prepared statement is aware of information +about the columns in the query it was built from and we can use this information to transparently encrypt any +supplied parameters. For example, we can create a prepared statement to insert a value into column1 (as defined above) +by executing the following code after creating a :class:`~.Cluster` in the manner described above: + +.. code-block:: python + + session = cluster.connect() + prepared = session.prepare("insert into ks1.table1 (column1) values (?)") + session.execute(prepared, (1000,)) + +Our encryption policy will detect that "column1" is an encrypted column and take appropriate action. + +As mentioned above client-side encryption can also be used with simple queries, although such use cases are +certainly not transparent. :class:`~.ColumnEncryptionPolicy` provides a helper named +:func:`~.ColumnEncryptionPolicy.encode_and_encrypt` which will convert an input value into bytes using the +standard serialization methods employed by the driver. The result is then encrypted according to the configuration +of the policy. Using this approach the example above could be implemented along the lines of the following: + +.. code-block:: python + + session = cluster.connect() + session.execute("insert into ks1.table1 (column1) values (%s)",(cl_policy.encode_and_encrypt(col_desc, 1000),)) + +Decryption +^^^^^^^^^^ +Decryption of values returned from the server is always transparent. Whether we're executing a simple or prepared +statement encrypted columns will be decrypted automatically and made available via rows just like any other +result. + +Limitations +----------- +:class:`~.AES256ColumnEncryptionPolicy` uses the implementation of AES-256 provided by the +`cryptography `_ module. Any limitations of this module should be considered +when deploying client-side encryption. Note specifically that a Rust compiler is required for modern versions +of the cryptography package, although wheels exist for many common platforms. + +Client-side encryption has been implemented for both the default Cython and pure Python row processing logic. +This functionality has not yet been ported to the NumPy Cython implementation. During testing, +the NumPy processing works on Python 3.7 but fails for Python 3.8. diff --git a/docs/conf.py b/docs/conf.py index 46fdb24240..4c0dfb58d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -43,7 +43,7 @@ # General information about the project. project = u'Cassandra Driver' -copyright = u'2013, DataStax' +copyright = u'2013-2017 DataStax' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -96,7 +96,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinxdoc' +html_theme = 'custom' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -104,7 +104,7 @@ #html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +html_theme_path = ['./themes'] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". @@ -125,7 +125,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. @@ -136,7 +136,14 @@ #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +html_sidebars = { + '**': [ + 'about.html', + 'navigation.html', + 'relations.html', + 'searchbox.html' + ] +} # Additional templates that should be rendered to pages, maps page names to # template names. @@ -146,7 +153,7 @@ #html_domain_indices = True # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False @@ -183,8 +190,7 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'CassandraDriver.tex', u'Cassandra Driver Documentation', - u'DataStax', 'manual'), + ('index', 'cassandra-driver.tex', u'Cassandra Driver Documentation', u'DataStax', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -216,6 +222,6 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'cassandradriver', u'Cassandra Driver Documentation', - [u'Tyler Hobbs'], 1) + ('index', 'cassandra-driver', u'Cassandra Driver Documentation', + [u'DataStax'], 1) ] diff --git a/docs/cqlengine/batches.rst b/docs/cqlengine/batches.rst new file mode 100644 index 0000000000..306e7d01a6 --- /dev/null +++ b/docs/cqlengine/batches.rst @@ -0,0 +1,108 @@ +============= +Batch Queries +============= + +cqlengine supports batch queries using the BatchQuery class. Batch queries can be started and stopped manually, or within a context manager. To add queries to the batch object, you just need to precede the create/save/delete call with a call to batch, and pass in the batch object. + + +Batch Query General Use Pattern +=============================== + +You can only create, update, and delete rows with a batch query, attempting to read rows out of the database with a batch query will fail. + +.. code-block:: python + + from cassandra.cqlengine.query import BatchQuery + + #using a context manager + with BatchQuery() as b: + now = datetime.now() + em1 = ExampleModel.batch(b).create(example_type=0, description="1", created_at=now) + em2 = ExampleModel.batch(b).create(example_type=0, description="2", created_at=now) + em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) + + # -- or -- + + #manually + b = BatchQuery() + now = datetime.now() + em1 = ExampleModel.batch(b).create(example_type=0, description="1", created_at=now) + em2 = ExampleModel.batch(b).create(example_type=0, description="2", created_at=now) + em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) + b.execute() + + # updating in a batch + + b = BatchQuery() + em1.description = "new description" + em1.batch(b).save() + em2.description = "another new description" + em2.batch(b).save() + b.execute() + + # deleting in a batch + b = BatchQuery() + ExampleModel.objects(id=some_id).batch(b).delete() + ExampleModel.objects(id=some_id2).batch(b).delete() + b.execute() + + +Typically you will not want the block to execute if an exception occurs inside the `with` block. However, in the case that this is desirable, it's achievable by using the following syntax: + +.. code-block:: python + + with BatchQuery(execute_on_exception=True) as b: + LogEntry.batch(b).create(k=1, v=1) + mystery_function() # exception thrown in here + LogEntry.batch(b).create(k=1, v=2) # this code is never reached due to the exception, but anything leading up to here will execute in the batch. + +If an exception is thrown somewhere in the block, any statements that have been added to the batch will still be executed. This is useful for some logging situations. + +Batch Query Execution Callbacks +=============================== + +In order to allow secondary tasks to be chained to the end of batch, BatchQuery instances allow callbacks to be +registered with the batch, to be executed immediately after the batch executes. + +Multiple callbacks can be attached to same BatchQuery instance, they are executed in the same order that they +are added to the batch. + +The callbacks attached to a given batch instance are executed only if the batch executes. If the batch is used as a +context manager and an exception is raised, the queued up callbacks will not be run. + +.. code-block:: python + + def my_callback(*args, **kwargs): + pass + + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 'positional arg', named_arg='named arg value') + + # if you need reference to the batch within the callback, + # just trap it in the arguments to be passed to the callback: + batch.add_callback(my_callback, cqlengine_batch=batch) + + # once the batch executes... + batch.execute() + + # the effect of the above scheduled callbacks will be similar to + my_callback() + my_callback('positional arg', named_arg='named arg value') + my_callback(cqlengine_batch=batch) + +Failure in any of the callbacks does not affect the batch's execution, as the callbacks are started after the execution +of the batch is complete. + +Logged vs Unlogged Batches +--------------------------- +By default, queries in cqlengine are LOGGED, which carries additional overhead from UNLOGGED. To explicitly state which batch type to use, simply: + + +.. code-block:: python + + from cassandra.cqlengine.query import BatchType + with BatchQuery(batch_type=BatchType.Unlogged) as b: + LogEntry.batch(b).create(k=1, v=1) + LogEntry.batch(b).create(k=1, v=2) diff --git a/docs/cqlengine/connections.rst b/docs/cqlengine/connections.rst new file mode 100644 index 0000000000..fd44303514 --- /dev/null +++ b/docs/cqlengine/connections.rst @@ -0,0 +1,137 @@ +=========== +Connections +=========== + +Connections aim to ease the use of multiple sessions with cqlengine. Connections can be set on a model class, per query or using a context manager. + + +Register a new connection +========================= + +To use cqlengine, you need at least a default connection. If you initialize cqlengine's connections with with :func:`connection.setup <.connection.setup>`, a connection will be created automatically. If you want to use another cluster/session, you need to register a new cqlengine connection. You register a connection with :func:`~.connection.register_connection`: + +.. code-block:: python + + from cassandra.cqlengine import connection + + connection.setup(['127.0.0.1') + connection.register_connection('cluster2', ['127.0.0.2']) + +:func:`~.connection.register_connection` can take a list of hosts, as shown above, in which case it will create a connection with a new session. It can also take a `session` argument if you've already created a session: + +.. code-block:: python + + from cassandra.cqlengine import connection + from cassandra.cluster import Cluster + + session = Cluster(['127.0.0.1']).connect() + connection.register_connection('cluster3', session=session) + + +Change the default connection +============================= + +You can change the default cqlengine connection on registration: + +.. code-block:: python + + from cassandra.cqlengine import connection + + connection.register_connection('cluster2', ['127.0.0.2'] default=True) + +or on the fly using :func:`~.connection.set_default_connection` + +.. code-block:: python + + connection.set_default_connection('cluster2') + +Unregister a connection +======================= + +You can unregister a connection using :func:`~.connection.unregister_connection`: + +.. code-block:: python + + connection.unregister_connection('cluster2') + +Management +========== + +When using multiples connections, you also need to sync your models on all connections (and keyspaces) that you need operate on. Management commands have been improved to ease this part. Here is an example: + +.. code-block:: python + + from cassandra.cqlengine import management + + keyspaces = ['ks1', 'ks2'] + conns = ['cluster1', 'cluster2'] + + # registers your connections + # ... + + # create all keyspaces on all connections + for ks in keyspaces: + management.create_simple_keyspace(ks, connections=conns) + + # define your Automobile model + # ... + + # sync your models + management.sync_table(Automobile, keyspaces=keyspaces, connections=conns) + + +Connection Selection +==================== + +cqlengine will select the default connection, unless your specify a connection using one of the following methods. + +Default Model Connection +------------------------ + +You can specify a default connection per model: + +.. code-block:: python + + class Automobile(Model): + __keyspace__ = 'test' + __connection__ = 'cluster2' + manufacturer = columns.Text(primary_key=True) + year = columns.Integer(primary_key=True) + model = columns.Text(primary_key=True) + + print(len(Automobile.objects.all())) # executed on the connection 'cluster2' + +QuerySet and model instance +--------------------------- + +You can use the :attr:`using() <.query.ModelQuerySet.using>` method to select a connection (or keyspace): + +.. code-block:: python + + Automobile.objects.using(connection='cluster1').create(manufacturer='honda', year=2010, model='civic') + q = Automobile.objects.filter(manufacturer='Tesla') + autos = q.using(keyspace='ks2', connection='cluster2').all() + + for auto in autos: + auto.using(connection='cluster1').save() + +Context Manager +--------------- + +You can use the ContextQuery as well to select a connection: + +.. code-block:: python + + with ContextQuery(Automobile, connection='cluster1') as A: + A.objects.filter(manufacturer='honda').all() # executed on 'cluster1' + + +BatchQuery +---------- + +With a BatchQuery, you can select the connection with the context manager. Note that all operations in the batch need to use the same connection. + +.. code-block:: python + + with BatchQuery(connection='cluster1') as b: + Automobile.objects.batch(b).create(manufacturer='honda', year=2010, model='civic') diff --git a/docs/cqlengine/faq.rst b/docs/cqlengine/faq.rst new file mode 100644 index 0000000000..6c056d02ea --- /dev/null +++ b/docs/cqlengine/faq.rst @@ -0,0 +1,67 @@ +========================== +Frequently Asked Questions +========================== + +Why don't updates work correctly on models instantiated as Model(field=value, field2=value2)? +------------------------------------------------------------------------------------------------ + +The recommended way to create new rows is with the models .create method. The values passed into a model's init method are interpreted by the model as the values as they were read from a row. This allows the model to "know" which rows have changed since the row was read out of cassandra, and create suitable update statements. + +How to preserve ordering in batch query? +------------------------------------------- + +Statement Ordering is not supported by CQL3 batches. Therefore, +once cassandra needs resolving conflict(Updating the same column in one batch), +The algorithm below would be used. + +* If timestamps are different, pick the column with the largest timestamp (the value being a regular column or a tombstone) +* If timestamps are the same, and one of the columns in a tombstone ('null') - pick the tombstone +* If timestamps are the same, and none of the columns are tombstones, pick the column with the largest value + +Below is an example to show this scenario. + +.. code-block:: python + + class MyMode(Model): + id = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + with BatchQuery() as b: + MyModel.batch(b).create(id=1, count=2, text='123') + MyModel.batch(b).create(id=1, count=3, text='111') + + assert MyModel.objects(id=1).first().count == 3 + assert MyModel.objects(id=1).first().text == '123' + +The largest value of count is 3, and the largest value of text would be '123'. + +The workaround is applying timestamp to each statement, then Cassandra would +resolve to the statement with the lastest timestamp. + +.. code-block:: python + + with BatchQuery() as b: + MyModel.timestamp(datetime.now()).batch(b).create(id=1, count=2, text='123') + MyModel.timestamp(datetime.now()).batch(b).create(id=1, count=3, text='111') + + assert MyModel.objects(id=1).first().count == 3 + assert MyModel.objects(id=1).first().text == '111' + +How can I delete individual values from a row? +------------------------------------------------- + +When inserting with CQLEngine, ``None`` is equivalent to CQL ``NULL`` or to +issuing a ``DELETE`` on that column. For example: + +.. code-block:: python + + class MyModel(Model): + id = columns.Integer(primary_key=True) + text = columns.Text() + + m = MyModel.create(id=1, text='We can delete this with None') + assert MyModel.objects(id=1).first().text is not None + + m.update(text=None) + assert MyModel.objects(id=1).first().text is None diff --git a/docs/cqlengine/models.rst b/docs/cqlengine/models.rst new file mode 100644 index 0000000000..719513f4a9 --- /dev/null +++ b/docs/cqlengine/models.rst @@ -0,0 +1,218 @@ +====== +Models +====== + +.. module:: cqlengine.models + +A model is a python class representing a CQL table. Models derive from :class:`Model`, and +define basic table properties and columns for a table. + +Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. +For a model to be valid it needs at least one primary key column and one non-primary key column. Just as in CQL, the order you define +your columns in is important, and is the same order they are defined in on a model's corresponding table. + +Some basic examples defining models are shown below. Consult the :doc:`Model API docs ` and :doc:`Column API docs ` for complete details. + +Example Definitions +=================== + +This example defines a ``Person`` table, with the columns ``first_name`` and ``last_name`` + +.. code-block:: python + + from cassandra.cqlengine import columns + from cassandra.cqlengine.models import Model + + class Person(Model): + id = columns.UUID(primary_key=True) + first_name = columns.Text() + last_name = columns.Text() + + +The Person model would create this CQL table: + +.. code-block:: sql + + CREATE TABLE cqlengine.person ( + id uuid, + first_name text, + last_name text, + PRIMARY KEY (id) + ); + +Here's an example of a comment table created with clustering keys, in descending order: + +.. code-block:: python + + from cassandra.cqlengine import columns + from cassandra.cqlengine.models import Model + + class Comment(Model): + photo_id = columns.UUID(primary_key=True) + comment_id = columns.TimeUUID(primary_key=True, clustering_order="DESC") + comment = columns.Text() + +The Comment model's ``create table`` would look like the following: + +.. code-block:: sql + + CREATE TABLE comment ( + photo_id uuid, + comment_id timeuuid, + comment text, + PRIMARY KEY (photo_id, comment_id) + ) WITH CLUSTERING ORDER BY (comment_id DESC); + +To sync the models to the database, you may do the following*: + +.. code-block:: python + + from cassandra.cqlengine.management import sync_table + sync_table(Person) + sync_table(Comment) + +\*Note: synchronizing models causes schema changes, and should be done with caution. +Please see the discussion in :doc:`/api/cassandra/cqlengine/management` for considerations. + +For examples on manipulating data and creating queries, see :doc:`queryset` + +Manipulating model instances as dictionaries +============================================ + +Model instances can be accessed like dictionaries. + +.. code-block:: python + + class Person(Model): + first_name = columns.Text() + last_name = columns.Text() + + kevin = Person.create(first_name="Kevin", last_name="Deldycke") + dict(kevin) # returns {'first_name': 'Kevin', 'last_name': 'Deldycke'} + kevin['first_name'] # returns 'Kevin' + kevin.keys() # returns ['first_name', 'last_name'] + kevin.values() # returns ['Kevin', 'Deldycke'] + kevin.items() # returns [('first_name', 'Kevin'), ('last_name', 'Deldycke')] + + kevin['first_name'] = 'KEVIN5000' # changes the models first name + +Extending Model Validation +========================== + +Each time you save a model instance in cqlengine, the data in the model is validated against the schema you've defined +for your model. Most of the validation is fairly straightforward, it basically checks that you're not trying to do +something like save text into an integer column, and it enforces the ``required`` flag set on column definitions. +It also performs any transformations needed to save the data properly. + +However, there are often additional constraints or transformations you want to impose on your data, beyond simply +making sure that Cassandra won't complain when you try to insert it. To define additional validation on a model, +extend the model's validation method: + +.. code-block:: python + + class Member(Model): + person_id = UUID(primary_key=True) + name = Text(required=True) + + def validate(self): + super(Member, self).validate() + if self.name == 'jon': + raise ValidationError('no jon\'s allowed') + +*Note*: while not required, the convention is to raise a ``ValidationError`` (``from cassandra.cqlengine import ValidationError``) +if validation fails. + +.. _model_inheritance: + +Model Inheritance +================= +It is possible to save and load different model classes using a single CQL table. +This is useful in situations where you have different object types that you want to store in a single cassandra row. + +For instance, suppose you want a table that stores rows of pets owned by an owner: + +.. code-block:: python + + class Pet(Model): + __table_name__ = 'pet' + owner_id = UUID(primary_key=True) + pet_id = UUID(primary_key=True) + pet_type = Text(discriminator_column=True) + name = Text() + + def eat(self, food): + pass + + def sleep(self, time): + pass + + class Cat(Pet): + __discriminator_value__ = 'cat' + cuteness = Float() + + def tear_up_couch(self): + pass + + class Dog(Pet): + __discriminator_value__ = 'dog' + fierceness = Float() + + def bark_all_night(self): + pass + +After calling ``sync_table`` on each of these tables, the columns defined in each model will be added to the +``pet`` table. Additionally, saving ``Cat`` and ``Dog`` models will save the meta data needed to identify each row +as either a cat or dog. + +To setup a model structure with inheritance, follow these steps + +1. Create a base model with a column set as the distriminator (``distriminator_column=True`` in the column definition) +2. Create subclass models, and define a unique ``__discriminator_value__`` value on each +3. Run ``sync_table`` on each of the sub tables + +**About the discriminator value** + +The discriminator value is what cqlengine uses under the covers to map logical cql rows to the appropriate model type. The +base model maintains a map of discriminator values to subclasses. When a specialized model is saved, its discriminator value is +automatically saved into the discriminator column. The discriminator column may be any column type except counter and container types. +Additionally, if you set ``index=True`` on your discriminator column, you can execute queries against specialized subclasses, and a +``WHERE`` clause will be automatically added to your query, returning only rows of that type. Note that you must +define a unique ``__discriminator_value__`` to each subclass, and that you can only assign a single discriminator column per model. + +.. _user_types: + +User Defined Types +================== +cqlengine models User Defined Types (UDTs) much like tables, with fields defined by column type attributes. However, UDT instances +are only created, presisted, and queried via table Models. A short example to introduce the pattern:: + + from cassandra.cqlengine.columns import * + from cassandra.cqlengine.models import Model + from cassandra.cqlengine.usertype import UserType + + class address(UserType): + street = Text() + zipcode = Integer() + + class users(Model): + __keyspace__ = 'account' + name = Text(primary_key=True) + addr = UserDefinedType(address) + + users.create(name="Joe", addr=address(street="Easy St.", zipcode=99999)) + user = users.objects(name="Joe")[0] + print(user.name, user.addr) + # Joe address(street=u'Easy St.', zipcode=99999) + +UDTs are modeled by inheriting :class:`~.usertype.UserType`, and setting column type attributes. Types are then used in defining +models by declaring a column of type :class:`~.columns.UserDefinedType`, with the ``UserType`` class as a parameter. + +``sync_table`` will implicitly +synchronize any types contained in the table. Alternatively :func:`~.management.sync_type` can be used to create/alter types +explicitly. + +Upon declaration, types are automatically registered with the driver, so query results return instances of your ``UserType`` +class*. + +***Note**: UDTs were not added to the native protocol until v3. When setting up the cqlengine connection, be sure to specify +``protocol_version=3``. If using an earlier version, UDT queries will still work, but the returned type will be a namedtuple. diff --git a/docs/cqlengine/queryset.rst b/docs/cqlengine/queryset.rst new file mode 100644 index 0000000000..fa99585141 --- /dev/null +++ b/docs/cqlengine/queryset.rst @@ -0,0 +1,419 @@ +============== +Making Queries +============== + +.. module:: cqlengine.queryset + +Retrieving objects +================== +Once you've populated Cassandra with data, you'll probably want to retrieve some of it. This is accomplished with QuerySet objects. This section will describe how to use QuerySet objects to retrieve the data you're looking for. + +Retrieving all objects +---------------------- +The simplest query you can make is to return all objects from a table. + +This is accomplished with the ``.all()`` method, which returns a QuerySet of all objects in a table + +Using the Person example model, we would get all Person objects like this: + +.. code-block:: python + + all_objects = Person.objects.all() + +.. _retrieving-objects-with-filters: + +Retrieving objects with filters +------------------------------- +Typically, you'll want to query only a subset of the records in your database. + +That can be accomplished with the QuerySet's ``.filter(\*\*)`` method. + +For example, given the model definition: + +.. code-block:: python + + class Automobile(Model): + manufacturer = columns.Text(primary_key=True) + year = columns.Integer(primary_key=True) + model = columns.Text() + price = columns.Decimal() + options = columns.Set(columns.Text) + +...and assuming the Automobile table contains a record of every car model manufactured in the last 20 years or so, we can retrieve only the cars made by a single manufacturer like this: + + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + +You can also use the more convenient syntax: + +.. code-block:: python + + q = Automobile.objects(Automobile.manufacturer == 'Tesla') + +We can then further filter our query with another call to **.filter** + +.. code-block:: python + + q = q.filter(year=2012) + +*Note: all queries involving any filtering MUST define either an '=' or an 'in' relation to either a primary key column, or an indexed column.* + +Accessing objects in a QuerySet +=============================== + +There are several methods for getting objects out of a queryset + +* iterating over the queryset + .. code-block:: python + + for car in Automobile.objects.all(): + #...do something to the car instance + pass + +* list index + .. code-block:: python + + q = Automobile.objects.all() + q[0] #returns the first result + q[1] #returns the second result + + .. note:: + + * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array indexing will load every result up to the index value requested + * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. + +* list slicing + .. code-block:: python + + q = Automobile.objects.all() + q[1:] #returns all results except the first + q[1:9] #returns a slice of the results + + .. note:: + + * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array slicing will load every result up to the index value requested + * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. + +* calling :attr:`get() ` on the queryset + .. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) + car = q.get() + + this returns the object matching the queryset + +* calling :attr:`first() ` on the queryset + .. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) + car = q.first() + + this returns the first value in the queryset + +.. _query-filtering-operators: + +Filtering Operators +=================== + +:attr:`Equal To ` + +The default filtering operator. + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) #year == 2012 + +In addition to simple equal to queries, cqlengine also supports querying with other operators by appending a ``__`` to the field name on the filtering call + +:attr:`in (__in) ` + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__in=[2011, 2012]) + + +:attr:`> (__gt) ` + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__gt=2010) # year > 2010 + + # or the nicer syntax + + q.filter(Automobile.year > 2010) + +:attr:`>= (__gte) ` + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__gte=2010) # year >= 2010 + + # or the nicer syntax + + q.filter(Automobile.year >= 2010) + +:attr:`< (__lt) ` + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__lt=2012) # year < 2012 + + # or... + + q.filter(Automobile.year < 2012) + +:attr:`<= (__lte) ` + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__lte=2012) # year <= 2012 + + q.filter(Automobile.year <= 2012) + +:attr:`CONTAINS (__contains) ` + +The CONTAINS operator is available for all collection types (List, Set, Map). + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q.filter(options__contains='backup camera').allow_filtering() + +Note that we need to use allow_filtering() since the *options* column has no secondary index. + +:attr:`LIKE (__like) ` + +The LIKE operator is available for text columns that have a SASI secondary index. + +.. code-block:: python + + q = Automobile.objects.filter(model__like='%Civic%').allow_filtering() + +:attr:`IS NOT NULL (IsNotNull(column_name)) ` + +The IS NOT NULL operator is not yet supported for C*. + +.. code-block:: python + + q = Automobile.objects.filter(IsNotNull('model')) + +Limitations: + +- Currently, cqlengine does not support SASI index creation. To use this feature, you need to create the SASI index using the core driver. +- Queries using LIKE must use allow_filtering() since the *model* column has no standard secondary index. Note that the server will use the SASI index properly when executing the query. + +TimeUUID Functions +================== + +In addition to querying using regular values, there are two functions you can pass in when querying TimeUUID columns to help make filtering by them easier. Note that these functions don't actually return a value, but instruct the cql interpreter to use the functions in it's query. + +.. class:: MinTimeUUID(datetime) + + returns the minimum time uuid value possible for the given datetime + +.. class:: MaxTimeUUID(datetime) + + returns the maximum time uuid value possible for the given datetime + +*Example* + +.. code-block:: python + + class DataStream(Model): + id = columns.UUID(partition_key=True) + time = columns.TimeUUID(primary_key=True) + data = columns.Bytes() + + min_time = datetime(1982, 1, 1) + max_time = datetime(1982, 3, 9) + + DataStream.filter(time__gt=functions.MinTimeUUID(min_time), time__lt=functions.MaxTimeUUID(max_time)) + +Token Function +============== + +Token functon may be used only on special, virtual column pk__token, representing token of partition key (it also works for composite partition keys). +Cassandra orders returned items by value of partition key token, so using cqlengine.Token we can easy paginate through all table rows. + +See http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun + +*Example* + +.. code-block:: python + + class Items(Model): + id = columns.Text(primary_key=True) + data = columns.Bytes() + + query = Items.objects.all().limit(10) + + first_page = list(query); + last = first_page[-1] + next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk))) + +QuerySets are immutable +======================= + +When calling any method that changes a queryset, the method does not actually change the queryset object it's called on, but returns a new queryset object with the attributes of the original queryset, plus the attributes added in the method call. + +*Example* + +.. code-block:: python + + #this produces 3 different querysets + #q does not change after it's initial definition + q = Automobiles.objects.filter(year=2012) + tesla2012 = q.filter(manufacturer='Tesla') + honda2012 = q.filter(manufacturer='Honda') + +Ordering QuerySets +================== + +Since Cassandra is essentially a distributed hash table on steroids, the order you get records back in will not be particularly predictable. + +However, you can set a column to order on with the ``.order_by(column_name)`` method. + +*Example* + +.. code-block:: python + + #sort ascending + q = Automobiles.objects.all().order_by('year') + #sort descending + q = Automobiles.objects.all().order_by('-year') + +*Note: Cassandra only supports ordering on a clustering key. In other words, to support ordering results, your model must have more than one primary key, and you must order on a primary key, excluding the first one.* + +*For instance, given our Automobile model, year is the only column we can order on.* + +Values Lists +============ + +There is a special QuerySet's method ``.values_list()`` - when called, QuerySet returns lists of values instead of model instances. It may significantly speedup things with lower memory footprint for large responses. +Each tuple contains the value from the respective field passed into the ``values_list()`` call — so the first item is the first field, etc. For example: + +.. code-block:: python + + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) + + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] + +Per Query Timeouts +=================== + +By default all queries are executed with the timeout defined in `~cqlengine.connection.setup()` +The examples below show how to specify a per-query timeout. +A timeout is specified in seconds and can be an int, float or None. +None means no timeout. + + +.. code-block:: python + + class Row(Model): + id = columns.Integer(primary_key=True) + name = columns.Text() + + +Fetch all objects with a timeout of 5 seconds + +.. code-block:: python + + Row.objects().timeout(5).all() + +Create a single row with a 50ms timeout + +.. code-block:: python + + Row(id=1, name='Jon').timeout(0.05).create() + +Delete a single row with no timeout + +.. code-block:: python + + Row(id=1).timeout(None).delete() + +Update a single row with no timeout + +.. code-block:: python + + Row(id=1).timeout(None).update(name='Blake') + +Batch query timeouts + +.. code-block:: python + + with BatchQuery(timeout=10) as b: + Row(id=1, name='Jon').create() + + +NOTE: You cannot set both timeout and batch at the same time, batch will use the timeout defined in it's constructor. +Setting the timeout on the model is meaningless and will raise an AssertionError. + + +.. _ttl-change: + +Default TTL and Per Query TTL +============================= + +Model default TTL now relies on the *default_time_to_live* feature, introduced in Cassandra 2.0. It is not handled anymore in the CQLEngine Model (cassandra-driver >=3.6). You can set the default TTL of a table like this: + +Example: + +.. code-block:: python + + class User(Model): + __options__ = {'default_time_to_live': 20} + + user_id = columns.UUID(primary_key=True) + ... + +You can set TTL per-query if needed. Here are a some examples: + +Example: + +.. code-block:: python + + class User(Model): + __options__ = {'default_time_to_live': 20} + + user_id = columns.UUID(primary_key=True) + ... + + user = User.objects.create(user_id=1) # Default TTL 20 will be set automatically on the server + + user.ttl(30).update(age=21) # Update the TTL to 30 + User.objects.ttl(10).create(user_id=1) # TTL 10 + User(user_id=1, age=21).ttl(10).save() # TTL 10 + + +Named Tables +=================== + +Named tables are a way of querying a table without creating an class. They're useful for querying system tables or exploring an unfamiliar database. + + +.. code-block:: python + + from cassandra.cqlengine.connection import setup + setup("127.0.0.1", "cqlengine_test") + + from cassandra.cqlengine.named import NamedTable + user = NamedTable("cqlengine_test", "user") + user.objects() + user.objects()[0] + + # {u'pk': 1, u't': datetime.datetime(2014, 6, 26, 17, 10, 31, 774000)} diff --git a/docs/cqlengine/third_party.rst b/docs/cqlengine/third_party.rst new file mode 100644 index 0000000000..20c26df304 --- /dev/null +++ b/docs/cqlengine/third_party.rst @@ -0,0 +1,64 @@ +======================== +Third party integrations +======================== + + +Celery +------ + +Here's how, in substance, CQLengine can be plugged to `Celery +`_: + +.. code-block:: python + + from celery import Celery + from celery.signals import worker_process_init, beat_init + from cassandra.cqlengine import connection + from cassandra.cqlengine.connection import ( + cluster as cql_cluster, session as cql_session) + + def cassandra_init(**kwargs): + """ Initialize a clean Cassandra connection. """ + if cql_cluster is not None: + cql_cluster.shutdown() + if cql_session is not None: + cql_session.shutdown() + connection.setup() + + # Initialize worker context for both standard and periodic tasks. + worker_process_init.connect(cassandra_init) + beat_init.connect(cassandra_init) + + app = Celery() + + +uWSGI +----- + +This is the code required for proper connection handling of CQLengine for a +`uWSGI `_-run application: + +.. code-block:: python + + from cassandra.cqlengine import connection + from cassandra.cqlengine.connection import ( + cluster as cql_cluster, session as cql_session) + + try: + from uwsgidecorators import postfork + except ImportError: + # We're not in a uWSGI context, no need to hook Cassandra session + # initialization to the postfork event. + pass + else: + @postfork + def cassandra_init(**kwargs): + """ Initialize a new Cassandra session in the context. + + Ensures that a new session is returned for every new request. + """ + if cql_cluster is not None: + cql_cluster.shutdown() + if cql_session is not None: + cql_session.shutdown() + connection.setup() diff --git a/docs/cqlengine/upgrade_guide.rst b/docs/cqlengine/upgrade_guide.rst new file mode 100644 index 0000000000..5b0ab39360 --- /dev/null +++ b/docs/cqlengine/upgrade_guide.rst @@ -0,0 +1,155 @@ +======================== +Upgrade Guide +======================== + +This is an overview of things that changed as the cqlengine project was merged into +cassandra-driver. While efforts were taken to preserve the API and most functionality exactly, +conversion to this package will still require certain minimal updates (namely, imports). + +**THERE IS ONE FUNCTIONAL CHANGE**, described in the first section below. + +Functional Changes +================== +List Prepend Reversing +---------------------- +Legacy cqlengine included a workaround for a Cassandra bug in which prepended list segments were +reversed (`CASSANDRA-8733 `_). As of +this integration, this workaround is removed. The first released integrated version emits +a warning when prepend is used. Subsequent versions will have this warning removed. + +Date Column Type +---------------- +The Date column type in legacy cqlengine used a ``timestamp`` CQL type and truncated the time. +Going forward, the :class:`~.columns.Date` type represents a ``date`` for Cassandra 2.2+ +(`PYTHON-245 `_). +Users of the legacy functionality should convert models to use :class:`~.columns.DateTime` (which +uses ``timestamp`` internally), and use the build-in ``datetime.date`` for input values. + +Remove cqlengine +================ +To avoid confusion or mistakes using the legacy package in your application, it +is prudent to remove the cqlengine package when upgrading to the integrated version. + +The driver setup script will warn if the legacy package is detected during install, +but it will not prevent side-by-side installation. + +Organization +============ +Imports +------- +cqlengine is now integrated as a sub-package of the driver base package 'cassandra'. +Upgrading will require adjusting imports to cqlengine. For example:: + + from cassandra.cqlengine import columns + +is now:: + + from cassandra.cqlengine import columns + +Package-Level Aliases +--------------------- +Legacy cqlengine defined a number of aliases at the package level, which became redundant +when the package was integrated for a driver. These have been removed in favor of absolute +imports, and referring to cannonical definitions. For example, ``cqlengine.ONE`` was an alias +of ``cassandra.ConsistencyLevel.ONE``. In the integrated package, only the +:class:`cassandra.ConsistencyLevel` remains. + +Additionally, submodule aliases are removed from cqlengine in favor of absolute imports. + +These aliases are removed, and not deprecated because they should be straightforward to iron out +at module load time. + +Exceptions +---------- +The legacy cqlengine.exceptions module had a number of Exception classes that were variously +common to the package, or only used in specific modules. Common exceptions were relocated to +cqlengine, and specialized exceptions were placed in the module that raises them. Below is a +listing of updated locations: + +============================ ========== +Exception class New module +============================ ========== +CQLEngineException cassandra.cqlengine +ModelException cassandra.cqlengine.models +ValidationError cassandra.cqlengine +UndefinedKeyspaceException cassandra.cqlengine.connection +LWTException cassandra.cqlengine.query +IfNotExistsWithCounterColumn cassandra.cqlengine.query +============================ ========== + +UnicodeMixin Consolidation +-------------------------- +``class UnicodeMixin`` was defined in several cqlengine modules. This has been consolidated +to a single definition in the cqlengine package init file. This is not technically part of +the API, but noted here for completeness. + +API Deprecations +================ +This upgrade served as a good juncture to deprecate certain API features and invite users to upgrade +to new ones. The first released version does not change functionality -- only introduces deprecation +warnings. Future releases will remove these features in favor of the alternatives. + +Float/Double Overload +--------------------- +Previously there was no ``Double`` column type. Doubles were modeled by specifying ``Float(double_precision=True)``. +This inititializer parameter is now deprecated. Applications should use :class:`~.columns.Double` for CQL ``double``, and :class:`~.columns.Float` +for CQL ``float``. + +Schema Management +----------------- +``cassandra.cqlengine.management.create_keyspace`` is deprecated. Instead, use the new replication-strategy-specific +functions that accept explicit options for known strategies: + +- :func:`~.create_keyspace_simple` +- :func:`~.create_keyspace_network_topology` + +``cassandra.cqlengine.management.delete_keyspace`` is deprecated in favor of a new function, :func:`~.drop_keyspace`. The +intent is simply to make the function match the CQL verb it invokes. + +Model Inheritance +----------------- +The names for class attributes controlling model inheritance are changing. Changes are as follows: + +- Replace 'polymorphic_key' in the base class Column definition with :attr:`~.discriminator_column` +- Replace the '__polymorphic_key__' class attribute the derived classes with :attr:`~.__discriminator_value__` + +The functionality is unchanged -- the intent here is to make the names and language around these attributes more precise. +For now, the old names are just deprecated, and the mapper will emit warnings if they are used. The old names +will be removed in a future version. + +The example below shows a simple translation: + +Before:: + + class Pet(Model): + __table_name__ = 'pet' + owner_id = UUID(primary_key=True) + pet_id = UUID(primary_key=True) + pet_type = Text(polymorphic_key=True) + name = Text() + + class Cat(Pet): + __polymorphic_key__ = 'cat' + + class Dog(Pet): + __polymorphic_key__ = 'dog' + +After:: + + class Pet(models.Model): + __table_name__ = 'pet' + owner_id = UUID(primary_key=True) + pet_id = UUID(primary_key=True) + pet_type = Text(discriminator_column=True) + name = Text() + + class Cat(Pet): + __discriminator_value__ = 'cat' + + class Dog(Pet): + __discriminator_value__ = 'dog' + + +TimeUUID.from_datetime +---------------------- +This function is deprecated in favor of the core utility function :func:`~.uuid_from_time`. diff --git a/docs/dates_and_times.rst b/docs/dates_and_times.rst new file mode 100644 index 0000000000..7a89f77437 --- /dev/null +++ b/docs/dates_and_times.rst @@ -0,0 +1,87 @@ +Working with Dates and Times +============================ + +This document is meant to provide on overview of the assumptions and limitations of the driver time handling, the +reasoning behind it, and describe approaches to working with these types. + +timestamps (Cassandra DateType) +------------------------------- + +Timestamps in Cassandra are timezone-naive timestamps encoded as millseconds since UNIX epoch. Clients working with +timestamps in this database usually find it easiest to reason about them if they are always assumed to be UTC. To quote the +pytz documentation, "The preferred way of dealing with times is to always work in UTC, converting to localtime only when +generating output to be read by humans." The driver adheres to this tenant, and assumes UTC is always in the database. The +driver attempts to make this correct on the way in, and assumes no timezone on the way out. + +Write Path +~~~~~~~~~~ +When inserting timestamps, the driver handles serialization for the write path as follows: + +If the input is a ``datetime.datetime``, the serialization is normalized by starting with the ``utctimetuple()`` of the +value. + +- If the ``datetime`` object is timezone-aware, the timestamp is shifted, and represents the UTC timestamp equivalent. +- If the ``datetime`` object is timezone-naive, this results in no shift -- any ``datetime`` with no timezone information is assumed to be UTC + +Note the second point above applies even to "local" times created using ``now()``:: + + >>> d = datetime.now() + + >>> print(d.tzinfo) + None + + +These do not contain timezone information intrinsically, so they will be assumed to be UTC and not shifted. When generating +timestamps in the application, it is clearer to use ``datetime.utcnow()`` to be explicit about it. + +If the input for a timestamp is numeric, it is assumed to be a epoch-relative millisecond timestamp, as specified in the +CQL spec -- no scaling or conversion is done. + +Read Path +~~~~~~~~~ +The driver always assumes persisted timestamps are UTC and makes no attempt to localize them. Returned values are +timezone-naive ``datetime.datetime``. We follow this approach because the datetime API has deficiencies around daylight +saving time, and the defacto package for handling this is a third-party package (we try to minimize external dependencies +and not make decisions for the integrator). + +The decision for how to handle timezones is left to the application. For the most part it is straightforward to apply +localization to the ``datetime``\s returned by queries. One prevalent method is to use pytz for localization:: + + import pytz + user_tz = pytz.timezone('US/Central') + timestamp_naive = row.ts + timestamp_utc = pytz.utc.localize(timestamp_naive) + timestamp_presented = timestamp_utc.astimezone(user_tz) + +This is the most robust approach (likely refactored into a function). If it is deemed too cumbersome to apply for all call +sites in the application, it is possible to patch the driver with custom deserialization for this type. However, doing +this depends depends some on internal APIs and what extensions are present, so we will only mention the possibility, and +not spell it out here. + +date, time (Cassandra DateType) +------------------------------- +Date and time in Cassandra are idealized markers, much like ``datetime.date`` and ``datetime.time`` in the Python standard +library. Unlike these Python implementations, the Cassandra encoding supports much wider ranges. To accommodate these +ranges without overflow, this driver returns these data in custom types: :class:`.util.Date` and :class:`.util.Time`. + +Write Path +~~~~~~~~~~ +For simple (not prepared) statements, the input values for each of these can be either a string literal or an encoded +integer. See `Working with dates `_ +or `Working with time `_ for details +on the encoding or string formats. + +For prepared statements, the driver accepts anything that can be used to construct the :class:`.util.Date` or +:class:`.util.Time` classes. See the linked API docs for details. + +Read Path +~~~~~~~~~ +The driver always returns custom types for ``date`` and ``time``. + +The driver returns :class:`.util.Date` for ``date`` in order to accommodate the wider range of values without overflow. +For applications working within the supported range of [``datetime.MINYEAR``, ``datetime.MAXYEAR``], these are easily +converted to standard ``datetime.date`` insances using :meth:`.Date.date`. + +The driver returns :class:`.util.Time` for ``time`` in order to retain nanosecond precision stored in the database. +For applications not concerned with this level of precision, these are easily converted to standard ``datetime.time`` +insances using :meth:`.Time.time`. diff --git a/docs/execution_profiles.rst b/docs/execution_profiles.rst new file mode 100644 index 0000000000..0965d77f3d --- /dev/null +++ b/docs/execution_profiles.rst @@ -0,0 +1,156 @@ +Execution Profiles +================== + +Execution profiles aim at making it easier to execute requests in different ways within +a single connected ``Session``. Execution profiles are being introduced to deal with the exploding number of +configuration options, especially as the database platform evolves more complex workloads. + +The legacy configuration remains intact, but legacy and Execution Profile APIs +cannot be used simultaneously on the same client ``Cluster``. Legacy configuration +will be removed in the next major release (4.0). + +An execution profile and its parameters should be unique across ``Cluster`` instances. +For example, an execution profile and its ``LoadBalancingPolicy`` should +not be applied to more than one ``Cluster`` instance. + +This document explains how Execution Profiles relate to existing settings, and shows how to use the new profiles for +request execution. + +Mapping Legacy Parameters to Profiles +------------------------------------- + +Execution profiles can inherit from :class:`.cluster.ExecutionProfile`, and currently provide the following options, +previously input from the noted attributes: + +- load_balancing_policy - :attr:`.Cluster.load_balancing_policy` +- request_timeout - :attr:`.Session.default_timeout`, optional :meth:`.Session.execute` parameter +- retry_policy - :attr:`.Cluster.default_retry_policy`, optional :attr:`.Statement.retry_policy` attribute +- consistency_level - :attr:`.Session.default_consistency_level`, optional :attr:`.Statement.consistency_level` attribute +- serial_consistency_level - :attr:`.Session.default_serial_consistency_level`, optional :attr:`.Statement.serial_consistency_level` attribute +- row_factory - :attr:`.Session.row_factory` attribute + +When using the new API, these parameters can be defined by instances of :class:`.cluster.ExecutionProfile`. + +Using Execution Profiles +------------------------ +Default +~~~~~~~ + +.. code:: python + + from cassandra.cluster import Cluster + cluster = Cluster() + session = cluster.connect() + local_query = 'SELECT rpc_address FROM system.local' + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.1') + + +The default execution profile is built from Cluster parameters and default Session attributes. This profile matches existing default +parameters. + +Initializing cluster with profiles +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + from cassandra.cluster import ExecutionProfile + from cassandra.policies import WhiteListRoundRobinPolicy + + node1_profile = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + node2_profile = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2'])) + + profiles = {'node1': node1_profile, 'node2': node2_profile} + session = Cluster(execution_profiles=profiles).connect() + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile='node1')[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.1') + Row(rpc_address='127.0.0.1') + + +.. code:: python + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile='node2')[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.2') + + +.. code:: python + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.1') + +Note that, even when custom profiles are injected, the default ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` is still +present. To override the default, specify a policy with the :data:`~.cluster.EXEC_PROFILE_DEFAULT` key. + +.. code:: python + + from cassandra.cluster import EXEC_PROFILE_DEFAULT + profile = ExecutionProfile(request_timeout=30) + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile}) + + +Adding named profiles +~~~~~~~~~~~~~~~~~~~~~ + +New profiles can be added constructing from scratch, or deriving from default: + +.. code:: python + + locked_execution = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + node1_profile = 'node1_whitelist' + cluster.add_execution_profile(node1_profile, locked_execution) + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile=node1_profile)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.1') + Row(rpc_address='127.0.0.1') + +See :meth:`.Cluster.add_execution_profile` for details and optional parameters. + +Passing a profile instance without mapping +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We also have the ability to pass profile instances to be used for execution, but not added to the mapping: + +.. code:: python + + from cassandra.query import tuple_factory + + tmp = session.execution_profile_clone_update('node1', request_timeout=100, row_factory=tuple_factory) + + print(session.execute(local_query, execution_profile=tmp)[0]) + print(session.execute(local_query, execution_profile='node1')[0]) + +.. parsed-literal:: + + ('127.0.0.1',) + Row(rpc_address='127.0.0.1') + +The new profile is a shallow copy, so the ``tmp`` profile shares a load balancing policy with one managed by the cluster. +If reference objects are to be updated in the clone, one would typically set those attributes to a new instance. diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 0000000000..194d5520e8 --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,83 @@ +Frequently Asked Questions +========================== + +See also :doc:`cqlengine FAQ ` + +Why do connections or IO operations timeout in my WSGI application? +------------------------------------------------------------------- +Depending on your application process model, it may be forking after driver Session is created. Most IO reactors do not handle this, and problems will manifest as timeouts. + +To avoid this, make sure to create sessions per process, after the fork. Using uWSGI and Flask for example: + +.. code-block:: python + + from flask import Flask + from uwsgidecorators import postfork + from cassandra.cluster import Cluster + + session = None + prepared = None + + @postfork + def connect(): + global session, prepared + session = Cluster().connect() + prepared = session.prepare("SELECT release_version FROM system.local WHERE key=?") + + app = Flask(__name__) + + @app.route('/') + def server_version(): + row = session.execute(prepared, ('local',))[0] + return row.release_version + +uWSGI provides a ``postfork`` hook you can use to create sessions and prepared statements after the child process forks. + +How do I trace a request? +------------------------- +Request tracing can be turned on for any request by setting ``trace=True`` in :meth:`.Session.execute_async`. View the results by waiting on the future, then :meth:`.ResponseFuture.get_query_trace`. +Since tracing is done asynchronously to the request, this method polls until the trace is complete before querying data. + +.. code-block:: python + + >>> future = session.execute_async("SELECT * FROM system.local", trace=True) + >>> result = future.result() + >>> trace = future.get_query_trace() + >>> for e in trace.events: + >>> print(e.source_elapsed, e.description) + + 0:00:00.000077 Parsing select * from system.local + 0:00:00.000153 Preparing statement + 0:00:00.000309 Computing ranges to query + 0:00:00.000368 Submitting range requests on 1 ranges with a concurrency of 1 (279.77142 rows per range expected) + 0:00:00.000422 Submitted 1 concurrent range requests covering 1 ranges + 0:00:00.000480 Executing seq scan across 1 sstables for (min(-9223372036854775808), min(-9223372036854775808)) + 0:00:00.000669 Read 1 live and 0 tombstone cells + 0:00:00.000755 Scanned 1 rows and matched 1 + +``trace`` is a :class:`QueryTrace` object. + +How do I determine the replicas for a query? +---------------------------------------------- +With prepared statements, the replicas are obtained by ``routing_key``, based on current cluster token metadata: + +.. code-block:: python + + >>> prepared = session.prepare("SELECT * FROM example.t WHERE key=?") + >>> bound = prepared.bind((1,)) + >>> replicas = cluster.metadata.get_replicas(bound.keyspace, bound.routing_key) + >>> for h in replicas: + >>> print(h.address) + 127.0.0.1 + 127.0.0.2 + +``replicas`` is a list of :class:`Host` objects. + +How does the driver manage request retries? +------------------------------------------- +By default, retries are managed by the :attr:`.Cluster.default_retry_policy` set on the session Cluster. It can also +be specialized per statement by setting :attr:`.Statement.retry_policy`. + +Retries are presently attempted on the same coordinator, but this may change in the future. + +Please see :class:`.policies.RetryPolicy` for further details. diff --git a/docs/geo_types.rst b/docs/geo_types.rst new file mode 100644 index 0000000000..f8750d687c --- /dev/null +++ b/docs/geo_types.rst @@ -0,0 +1,39 @@ +DSE Geometry Types +================== +This section shows how to query and work with the geometric types provided by DSE. + +These types are enabled implicitly by creating the Session from :class:`cassandra.cluster.Cluster`. +This module implicitly registers these types for use in the driver. This extension provides +some simple representative types in :mod:`cassandra.util` for inserting and retrieving data:: + + from cassandra.cluster import Cluster + from cassandra.util import Point, LineString, Polygon + session = Cluster().connect() + + session.execute("INSERT INTO ks.geo (k, point, line, poly) VALUES (%s, %s, %s, %s)", + 0, Point(1, 2), LineString(((1, 2), (3, 4))), Polygon(((1, 2), (3, 4), (5, 6)))) + +Queries returning geometric types return the :mod:`dse.util` types. Note that these can easily be used to construct +types from third-party libraries using the common attributes:: + + from shapely.geometry import LineString + shapely_linestrings = [LineString(res.line.coords) for res in session.execute("SELECT line FROM ks.geo")] + +For prepared statements, shapely geometry types can be used interchangeably with the built-in types because their +defining attributes are the same:: + + from shapely.geometry import Point + prepared = session.prepare("UPDATE ks.geo SET point = ? WHERE k = ?") + session.execute(prepared, (0, Point(1.2, 3.4))) + +In order to use shapely types in a CQL-interpolated (non-prepared) query, one must update the encoder with those types, specifying +the same string encoder as set for the internal types:: + + from cassandra import util + from shapely.geometry import Point, LineString, Polygon + + encoder_func = session.encoder.mapping[util.Point] + for t in (Point, LineString, Polygon): + session.encoder.mapping[t] = encoder_func + + session.execute("UPDATE ks.geo SET point = %s where k = %s", (0, Point(1.2, 3.4))) diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000000..432e42ec4f --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,502 @@ +Getting Started +=============== + +First, make sure you have the driver properly :doc:`installed `. + +Connecting to a Cluster +----------------------- +Before we can start executing any queries against a Cassandra cluster we need to setup +an instance of :class:`~.Cluster`. As the name suggests, you will typically have one +instance of :class:`~.Cluster` for each Cassandra cluster you want to interact +with. + +First, make sure you have the Cassandra driver properly :doc:`installed `. + +Connecting to Astra ++++++++++++++++++++ + +If you are a DataStax `Astra `_ user, +here is how to connect to your cluster: + +1. Download the secure connect bundle from your Astra account. +2. Connect to your cluster with + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + auth_provider = PlainTextAuthProvider(username='user', password='pass') + cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider) + session = cluster.connect() + +See `Astra `_ and :doc:`cloud` for more details. + +Connecting to Cassandra ++++++++++++++++++++++++ +The simplest way to create a :class:`~.Cluster` is like this: + +.. code-block:: python + + from cassandra.cluster import Cluster + + cluster = Cluster() + +This will attempt to connection to a Cassandra instance on your +local machine (127.0.0.1). You can also specify a list of IP +addresses for nodes in your cluster: + +.. code-block:: python + + from cassandra.cluster import Cluster + + cluster = Cluster(['192.168.0.1', '192.168.0.2']) + +The set of IP addresses we pass to the :class:`~.Cluster` is simply +an initial set of contact points. After the driver connects to one +of these nodes it will *automatically discover* the rest of the +nodes in the cluster and connect to them, so you don't need to list +every node in your cluster. + +If you need to use a non-standard port, use SSL, or customize the driver's +behavior in some other way, this is the place to do it: + +.. code-block:: python + + from cassandra.cluster import Cluster + cluster = Cluster(['192.168.0.1', '192.168.0.2'], port=..., ssl_context=...) + +Instantiating a :class:`~.Cluster` does not actually connect us to any nodes. +To establish connections and begin executing queries we need a +:class:`~.Session`, which is created by calling :meth:`.Cluster.connect()`: + +.. code-block:: python + + cluster = Cluster() + session = cluster.connect() + +Session Keyspace +---------------- +The :meth:`~.Cluster.connect()` method takes an optional ``keyspace`` argument +which sets the default keyspace for all queries made through that :class:`~.Session`: + +.. code-block:: python + + cluster = Cluster() + session = cluster.connect('mykeyspace') + +You can always change a Session's keyspace using :meth:`~.Session.set_keyspace` or +by executing a ``USE `` query: + +.. code-block:: python + + session.set_keyspace('users') + # or you can do this instead + session.execute('USE users') + +Execution Profiles +------------------ +Profiles are passed in by ``execution_profiles`` dict. + +In this case we can construct the base ``ExecutionProfile`` passing all attributes: + +.. code-block:: python + + from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT + from cassandra.policies import WhiteListRoundRobinPolicy, DowngradingConsistencyRetryPolicy + from cassandra.query import tuple_factory + + profile = ExecutionProfile( + load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1']), + retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.LOCAL_QUORUM, + serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL, + request_timeout=15, + row_factory=tuple_factory + ) + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile}) + session = cluster.connect() + + print(session.execute("SELECT release_version FROM system.local").one()) + +Users are free to setup additional profiles to be used by name: + +.. code-block:: python + + profile_long = ExecutionProfile(request_timeout=30) + cluster = Cluster(execution_profiles={'long': profile_long}) + session = cluster.connect() + session.execute(statement, execution_profile='long') + +Also, parameters passed to ``Session.execute`` or attached to ``Statement``\s are still honored as before. + +Executing Queries +----------------- +Now that we have a :class:`.Session` we can begin to execute queries. The simplest +way to execute a query is to use :meth:`~.Session.execute()`: + +.. code-block:: python + + rows = session.execute('SELECT name, age, email FROM users') + for user_row in rows: + print(user_row.name, user_row.age, user_row.email) + +This will transparently pick a Cassandra node to execute the query against +and handle any retries that are necessary if the operation fails. + +By default, each row in the result set will be a +`namedtuple `_. +Each row will have a matching attribute for each column defined in the schema, +such as ``name``, ``age``, and so on. You can also treat them as normal tuples +by unpacking them or accessing fields by position. The following three +examples are equivalent: + +.. code-block:: python + + rows = session.execute('SELECT name, age, email FROM users') + for row in rows: + print(row.name, row.age, row.email) + +.. code-block:: python + + rows = session.execute('SELECT name, age, email FROM users') + for (name, age, email) in rows: + print(name, age, email) + +.. code-block:: python + + rows = session.execute('SELECT name, age, email FROM users') + for row in rows: + print(row[0], row[1], row[2]) + +If you prefer another result format, such as a ``dict`` per row, you +can change the :attr:`~.Session.row_factory` attribute. + +As mentioned in our `Drivers Best Practices Guide `_, +it is highly recommended to use `Prepared statements <#prepared-statement>`_ for your +frequently run queries. + +.. _prepared-statement: + +Prepared Statements +------------------- +Prepared statements are queries that are parsed by Cassandra and then saved +for later use. When the driver uses a prepared statement, it only needs to +send the values of parameters to bind. This lowers network traffic +and CPU utilization within Cassandra because Cassandra does not have to +re-parse the query each time. + +To prepare a query, use :meth:`.Session.prepare()`: + +.. code-block:: python + + user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") + + users = [] + for user_id in user_ids_to_query: + user = session.execute(user_lookup_stmt, [user_id]) + users.append(user) + +:meth:`~.Session.prepare()` returns a :class:`~.PreparedStatement` instance +which can be used in place of :class:`~.SimpleStatement` instances or literal +string queries. It is automatically prepared against all nodes, and the driver +handles re-preparing against new nodes and restarted nodes when necessary. + +Note that the placeholders for prepared statements are ``?`` characters. This +is different than for simple, non-prepared statements (although future versions +of the driver may use the same placeholders for both). + +Passing Parameters to CQL Queries +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Althought it is not recommended, you can also pass parameters to non-prepared +statements. The driver supports two forms of parameter place-holders: positional +and named. + +Positional parameters are used with a ``%s`` placeholder. For example, +when you execute: + +.. code-block:: python + + session.execute( + """ + INSERT INTO users (name, credits, user_id) + VALUES (%s, %s, %s) + """, + ("John O'Reilly", 42, uuid.uuid1()) + ) + +It is translated to the following CQL query:: + + INSERT INTO users (name, credits, user_id) + VALUES ('John O''Reilly', 42, 2644bada-852c-11e3-89fb-e0b9a54a6d93) + +Note that you should use ``%s`` for all types of arguments, not just strings. +For example, this would be **wrong**: + +.. code-block:: python + + session.execute("INSERT INTO USERS (name, age) VALUES (%s, %d)", ("bob", 42)) # wrong + +Instead, use ``%s`` for the age placeholder. + +If you need to use a literal ``%`` character, use ``%%``. + +**Note**: you must always use a sequence for the second argument, even if you are +only passing in a single variable: + +.. code-block:: python + + session.execute("INSERT INTO foo (bar) VALUES (%s)", "blah") # wrong + session.execute("INSERT INTO foo (bar) VALUES (%s)", ("blah")) # wrong + session.execute("INSERT INTO foo (bar) VALUES (%s)", ("blah", )) # right + session.execute("INSERT INTO foo (bar) VALUES (%s)", ["blah"]) # right + + +Note that the second line is incorrect because in Python, single-element tuples +require a comma. + +Named place-holders use the ``%(name)s`` form: + +.. code-block:: python + + session.execute( + """ + INSERT INTO users (name, credits, user_id, username) + VALUES (%(name)s, %(credits)s, %(user_id)s, %(name)s) + """, + {'name': "John O'Reilly", 'credits': 42, 'user_id': uuid.uuid1()} + ) + +Note that you can repeat placeholders with the same name, such as ``%(name)s`` +in the above example. + +Only data values should be supplied this way. Other items, such as keyspaces, +table names, and column names should be set ahead of time (typically using +normal string formatting). + +.. _type-conversions: + +Type Conversions +^^^^^^^^^^^^^^^^ +For non-prepared statements, Python types are cast to CQL literals in the +following way: + +.. table:: + + +--------------------+-------------------------+ + | Python Type | CQL Literal Type | + +====================+=========================+ + | ``None`` | ``NULL`` | + +--------------------+-------------------------+ + | ``bool`` | ``boolean`` | + +--------------------+-------------------------+ + | ``float`` | | ``float`` | + | | | ``double`` | + +--------------------+-------------------------+ + | | ``int`` | | ``int`` | + | | ``long`` | | ``bigint`` | + | | | ``varint`` | + | | | ``smallint`` | + | | | ``tinyint`` | + | | | ``counter`` | + +--------------------+-------------------------+ + | ``decimal.Decimal``| ``decimal`` | + +--------------------+-------------------------+ + | | ``str`` | | ``ascii`` | + | | ``unicode`` | | ``varchar`` | + | | | ``text`` | + +--------------------+-------------------------+ + | | ``buffer`` | ``blob`` | + | | ``bytearray`` | | + +--------------------+-------------------------+ + | ``date`` | ``date`` | + +--------------------+-------------------------+ + | ``datetime`` | ``timestamp`` | + +--------------------+-------------------------+ + | ``time`` | ``time`` | + +--------------------+-------------------------+ + | | ``list`` | ``list`` | + | | ``tuple`` | | + | | generator | | + +--------------------+-------------------------+ + | | ``set`` | ``set`` | + | | ``frozenset`` | | + +--------------------+-------------------------+ + | | ``dict`` | ``map`` | + | | ``OrderedDict`` | | + +--------------------+-------------------------+ + | ``uuid.UUID`` | | ``timeuuid`` | + | | | ``uuid`` | + +--------------------+-------------------------+ + + +Asynchronous Queries +^^^^^^^^^^^^^^^^^^^^ +The driver supports asynchronous query execution through +:meth:`~.Session.execute_async()`. Instead of waiting for the query to +complete and returning rows directly, this method almost immediately +returns a :class:`~.ResponseFuture` object. There are two ways of +getting the final result from this object. + +The first is by calling :meth:`~.ResponseFuture.result()` on it. If +the query has not yet completed, this will block until it has and +then return the result or raise an Exception if an error occurred. +For example: + +.. code-block:: python + + from cassandra import ReadTimeout + + query = "SELECT * FROM users WHERE user_id=%s" + future = session.execute_async(query, [user_id]) + + # ... do some other work + + try: + rows = future.result() + user = rows[0] + print(user.name, user.age) + except ReadTimeout: + log.exception("Query timed out:") + +This works well for executing many queries concurrently: + +.. code-block:: python + + # build a list of futures + futures = [] + query = "SELECT * FROM users WHERE user_id=%s" + for user_id in ids_to_fetch: + futures.append(session.execute_async(query, [user_id]) + + # wait for them to complete and use the results + for future in futures: + rows = future.result() + print(rows[0].name) + +Alternatively, instead of calling :meth:`~.ResponseFuture.result()`, +you can attach callback and errback functions through the +:meth:`~.ResponseFuture.add_callback()`, +:meth:`~.ResponseFuture.add_errback()`, and +:meth:`~.ResponseFuture.add_callbacks()`, methods. If you have used +Twisted Python before, this is designed to be a lightweight version of +that: + +.. code-block:: python + + def handle_success(rows): + user = rows[0] + try: + process_user(user.name, user.age, user.id) + except Exception: + log.error("Failed to process user %s", user.id) + # don't re-raise errors in the callback + + def handle_error(exception): + log.error("Failed to fetch user info: %s", exception) + + + future = session.execute_async(query) + future.add_callbacks(handle_success, handle_error) + +There are a few important things to remember when working with callbacks: + * **Exceptions that are raised inside the callback functions will be logged and then ignored.** + * Your callback will be run on the event loop thread, so any long-running + operations will prevent other requests from being handled + + +Setting a Consistency Level +--------------------------- +The consistency level used for a query determines how many of the +replicas of the data you are interacting with need to respond for +the query to be considered a success. + +By default, :attr:`.ConsistencyLevel.LOCAL_ONE` will be used for all queries. +You can specify a different default by setting the :attr:`.ExecutionProfile.consistency_level` +for the execution profile with key :data:`~.cluster.EXEC_PROFILE_DEFAULT`. +To specify a different consistency level per request, wrap queries +in a :class:`~.SimpleStatement`: + +.. code-block:: python + + from cassandra import ConsistencyLevel + from cassandra.query import SimpleStatement + + query = SimpleStatement( + "INSERT INTO users (name, age) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + session.execute(query, ('John', 42)) + +Setting a Consistency Level with Prepared Statements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +To specify a consistency level for prepared statements, you have two options. + +The first is to set a default consistency level for every execution of the +prepared statement: + +.. code-block:: python + + from cassandra import ConsistencyLevel + + cluster = Cluster() + session = cluster.connect("mykeyspace") + user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") + user_lookup_stmt.consistency_level = ConsistencyLevel.QUORUM + + # these will both use QUORUM + user1 = session.execute(user_lookup_stmt, [user_id1])[0] + user2 = session.execute(user_lookup_stmt, [user_id2])[0] + +The second option is to create a :class:`~.BoundStatement` from the +:class:`~.PreparedStatement` and binding parameters and set a consistency +level on that: + +.. code-block:: python + + # override the QUORUM default + user3_lookup = user_lookup_stmt.bind([user_id3]) + user3_lookup.consistency_level = ConsistencyLevel.ALL + user3 = session.execute(user3_lookup) + +Speculative Execution +^^^^^^^^^^^^^^^^^^^^^ + +Speculative execution is a way to minimize latency by preemptively executing several +instances of the same query against different nodes. For more details about this +technique, see `Speculative Execution with DataStax Drivers `_. + +To enable speculative execution: + +* Configure a :class:`~.policies.SpeculativeExecutionPolicy` with the ExecutionProfile +* Mark your query as idempotent, which mean it can be applied multiple + times without changing the result of the initial application. + See `Query Idempotence `_ for more details. + + +Example: + +.. code-block:: python + + from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT + from cassandra.policies import ConstantSpeculativeExecutionPolicy + from cassandra.query import SimpleStatement + + # Configure the speculative execution policy + ep = ExecutionProfile( + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(delay=.5, max_attempts=10) + ) + cluster = Cluster(..., execution_profiles={EXEC_PROFILE_DEFAULT: ep}) + session = cluster.connect() + + # Mark the query idempotent + query = SimpleStatement( + "UPDATE my_table SET list_col = [1] WHERE pk = 1", + is_idempotent=True + ) + + # Execute. A new query will be sent to the server every 0.5 second + # until we receive a response, for a max number attempts of 10. + session.execute(query) diff --git a/docs/graph.rst b/docs/graph.rst new file mode 100644 index 0000000000..47dc53d38d --- /dev/null +++ b/docs/graph.rst @@ -0,0 +1,434 @@ +DataStax Graph Queries +====================== + +The driver executes graph queries over the Cassandra native protocol. Use +:meth:`.Session.execute_graph` or :meth:`.Session.execute_graph_async` for +executing gremlin queries in DataStax Graph. + +The driver defines three Execution Profiles suitable for graph execution: + +* :data:`~.cluster.EXEC_PROFILE_GRAPH_DEFAULT` +* :data:`~.cluster.EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT` +* :data:`~.cluster.EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT` + +See :doc:`getting_started` and :doc:`execution_profiles` +for more detail on working with profiles. + +In DSE 6.8.0, the Core graph engine has been introduced and is now the default. It +provides a better unified multi-model, performance and scale. This guide +is for graphs that use the core engine. If you work with previous versions of +DSE or existing graphs, see :doc:`classic_graph`. + +Getting Started with Graph and the Core Engine +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, we need to create a graph in the system. To access the system API, we +use the system execution profile :: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + + cluster = Cluster() + session = cluster.connect() + + graph_name = 'movies' + session.execute_graph("system.graph(name).create()", {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + + +Graphs that use the core engine only support GraphSON3. Since they are Cassandra tables under +the hood, we can automatically configure the execution profile with the proper options +(row_factory and graph_protocol) when executing queries. You only need to make sure that +the `graph_name` is set and GraphSON3 will be automatically used:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + + graph_name = 'movies' + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name=graph_name)) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + session.execute_graph("g.addV(...)") + + +Note that this graph engine detection is based on the metadata. You might experience +some query errors if the graph has been newly created and is not yet in the metadata. This +would result to a badly configured execution profile. If you really want to avoid that, +configure your execution profile explicitly:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions, GraphProtocol, graph_graphson3_row_factory + + graph_name = 'movies' + ep_graphson3 = GraphExecutionProfile( + row_factory=graph_graphson3_row_factory, + graph_options=GraphOptions( + graph_protocol=GraphProtocol.GRAPHSON_3_0, + graph_name=graph_name)) + + cluster = Cluster(execution_profiles={'core': ep_graphson3}) + session = cluster.connect() + session.execute_graph("g.addV(...)", execution_profile='core') + + +We are ready to configure our graph schema. We will create a simple one for movies:: + + # A Vertex represents a "thing" in the world. + # Create the genre vertex + query = """ + schema.vertexLabel('genre') + .partitionBy('genreId', Int) + .property('name', Text) + .create() + """ + session.execute_graph(query) + + # Create the person vertex + query = """ + schema.vertexLabel('person') + .partitionBy('personId', Int) + .property('name', Text) + .create() + """ + session.execute_graph(query) + + # Create the movie vertex + query = """ + schema.vertexLabel('movie') + .partitionBy('movieId', Int) + .property('title', Text) + .property('year', Int) + .property('country', Text) + .create() + """ + session.execute_graph(query) + + # An edge represents a relationship between two vertices + # Create our edges + queries = """ + schema.edgeLabel('belongsTo').from('movie').to('genre').create(); + schema.edgeLabel('actor').from('movie').to('person').create(); + """ + session.execute_graph(queries) + + # Indexes to execute graph requests efficiently + + # If you have a node with the search workload enabled (solr), use the following: + indexes = """ + schema.vertexLabel('genre').searchIndex() + .by("name") + .create(); + + schema.vertexLabel('person').searchIndex() + .by("name") + .create(); + + schema.vertexLabel('movie').searchIndex() + .by('title') + .by("year") + .create(); + """ + session.execute_graph(indexes) + + # Otherwise, use secondary indexes: + indexes = """ + schema.vertexLabel('genre') + .secondaryIndex('by_genre') + .by('name') + .create() + + schema.vertexLabel('person') + .secondaryIndex('by_name') + .by('name') + .create() + + schema.vertexLabel('movie') + .secondaryIndex('by_title') + .by('title') + .create() + """ + session.execute_graph(indexes) + +Add some edge indexes (materialized views):: + + indexes = """ + schema.edgeLabel('belongsTo') + .from('movie') + .to('genre') + .materializedView('movie__belongsTo__genre_by_in_genreId') + .ifNotExists() + .partitionBy(IN, 'genreId') + .clusterBy(OUT, 'movieId', Asc) + .create() + + schema.edgeLabel('actor') + .from('movie') + .to('person') + .materializedView('movie__actor__person_by_in_personId') + .ifNotExists() + .partitionBy(IN, 'personId') + .clusterBy(OUT, 'movieId', Asc) + .create() + """ + session.execute_graph(indexes) + +Next, we'll add some data:: + + session.execute_graph(""" + g.addV('genre').property('genreId', 1).property('name', 'Action').next(); + g.addV('genre').property('genreId', 2).property('name', 'Drama').next(); + g.addV('genre').property('genreId', 3).property('name', 'Comedy').next(); + g.addV('genre').property('genreId', 4).property('name', 'Horror').next(); + """) + + session.execute_graph(""" + g.addV('person').property('personId', 1).property('name', 'Mark Wahlberg').next(); + g.addV('person').property('personId', 2).property('name', 'Leonardo DiCaprio').next(); + g.addV('person').property('personId', 3).property('name', 'Iggy Pop').next(); + """) + + session.execute_graph(""" + g.addV('movie').property('movieId', 1).property('title', 'The Happening'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 2).property('title', 'The Italian Job'). + property('year', 2003).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 3).property('title', 'Revolutionary Road'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 4).property('title', 'The Man in the Iron Mask'). + property('year', 1998).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 5).property('title', 'Dead Man'). + property('year', 1995).property('country', 'United States').next(); + """) + +Now that our genre, actor and movie vertices are added, we'll create the relationships (edges) between them:: + + session.execute_graph(""" + genre_horror = g.V().hasLabel('genre').has('name', 'Horror').id().next(); + genre_drama = g.V().hasLabel('genre').has('name', 'Drama').id().next(); + genre_action = g.V().hasLabel('genre').has('name', 'Action').id().next(); + + leo = g.V().hasLabel('person').has('name', 'Leonardo DiCaprio').id().next(); + mark = g.V().hasLabel('person').has('name', 'Mark Wahlberg').id().next(); + iggy = g.V().hasLabel('person').has('name', 'Iggy Pop').id().next(); + + the_happening = g.V().hasLabel('movie').has('title', 'The Happening').id().next(); + the_italian_job = g.V().hasLabel('movie').has('title', 'The Italian Job').id().next(); + rev_road = g.V().hasLabel('movie').has('title', 'Revolutionary Road').id().next(); + man_mask = g.V().hasLabel('movie').has('title', 'The Man in the Iron Mask').id().next(); + dead_man = g.V().hasLabel('movie').has('title', 'Dead Man').id().next(); + + g.addE('belongsTo').from(__.V(the_happening)).to(__.V(genre_horror)).next(); + g.addE('belongsTo').from(__.V(the_italian_job)).to(__.V(genre_action)).next(); + g.addE('belongsTo').from(__.V(rev_road)).to(__.V(genre_drama)).next(); + g.addE('belongsTo').from(__.V(man_mask)).to(__.V(genre_drama)).next(); + g.addE('belongsTo').from(__.V(man_mask)).to(__.V(genre_action)).next(); + g.addE('belongsTo').from(__.V(dead_man)).to(__.V(genre_drama)).next(); + + g.addE('actor').from(__.V(the_happening)).to(__.V(mark)).next(); + g.addE('actor').from(__.V(the_italian_job)).to(__.V(mark)).next(); + g.addE('actor').from(__.V(rev_road)).to(__.V(leo)).next(); + g.addE('actor').from(__.V(man_mask)).to(__.V(leo)).next(); + g.addE('actor').from(__.V(dead_man)).to(__.V(iggy)).next(); + """) + +We are all set. You can now query your graph. Here are some examples:: + + # Find all movies of the genre Drama + for r in session.execute_graph(""" + g.V().has('genre', 'name', 'Drama').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of the same genre than the movie 'Dead Man' + for r in session.execute_graph(""" + g.V().has('movie', 'title', 'Dead Man').out('belongsTo').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of Mark Wahlberg + for r in session.execute_graph(""" + g.V().has('person', 'name', 'Mark Wahlberg').in('actor').valueMap();"""): + print(r) + +To see a more graph examples, see `DataStax Graph Examples `_. + +Graph Types for the Core Engine +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Here are the supported graph types with their python representations: + +============ ================= +DSE Graph Python Driver +============ ================= +text str +boolean bool +bigint long +int int +smallint int +varint long +double float +float float +uuid UUID +bigdecimal Decimal +duration Duration (cassandra.util) +inet str or IPV4Address/IPV6Address (if available) +timestamp datetime.datetime +date datetime.date +time datetime.time +polygon Polygon +point Point +linestring LineString +blob bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +list list +map dict +set set or list + (Can return a list due to numerical values returned by Java) +tuple tuple +udt class or namedtuple +============ ================= + +Named Parameters +~~~~~~~~~~~~~~~~ + +Named parameters are passed in a dict to :meth:`.cluster.Session.execute_graph`:: + + result_set = session.execute_graph('[a, b]', {'a': 1, 'b': 2}, execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + [r.value for r in result_set] # [1, 2] + +All python types listed in `Graph Types for the Core Engine`_ can be passed as named parameters and will be serialized +automatically to their graph representation: + +Example:: + + session.execute_graph(""" + g.addV('person'). + property('name', text_value). + property('age', integer_value). + property('birthday', timestamp_value). + property('house_yard', polygon_value).next() + """, { + 'text_value': 'Mike Smith', + 'integer_value': 34, + 'timestamp_value': datetime.datetime(1967, 12, 30), + 'polygon_value': Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + }) + + +As with all Execution Profile parameters, graph options can be set in the cluster default (as shown in the first example) +or specified per execution:: + + ep = session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, + graph_options=GraphOptions(graph_name='something-else')) + session.execute_graph(statement, execution_profile=ep) + +CQL collections, Tuple and UDT +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is a very interesting feature of the core engine: we can use all CQL data types, including +list, map, set, tuple and udt. Here is an example using all these types:: + + query = """ + schema.type('address') + .property('address', Text) + .property('city', Text) + .property('state', Text) + .create(); + """ + session.execute_graph(query) + + # It works the same way than normal CQL UDT, so we + # can create an udt class and register it + class Address(object): + def __init__(self, address, city, state): + self.address = address + self.city = city + self.state = state + + session.cluster.register_user_type(graph_name, 'address', Address) + + query = """ + schema.vertexLabel('person') + .partitionBy('personId', Int) + .property('address', typeOf('address')) + .property('friends', listOf(Text)) + .property('skills', setOf(Text)) + .property('scores', mapOf(Text, Int)) + .property('last_workout', tupleOf(Text, Date)) + .create() + """ + session.execute_graph(query) + + # insertion example + query = """ + g.addV('person') + .property('personId', pid) + .property('address', address) + .property('friends', friends) + .property('skills', skills) + .property('scores', scores) + .property('last_workout', last_workout) + .next() + """ + + session.execute_graph(query, { + 'pid': 3, + 'address': Address('42 Smith St', 'Quebec', 'QC'), + 'friends': ['Al', 'Mike', 'Cathy'], + 'skills': {'food', 'fight', 'chess'}, + 'scores': {'math': 98, 'french': 3}, + 'last_workout': ('CrossFit', datetime.date(2018, 11, 20)) + }) + +Limitations +----------- + +Since Python is not a strongly-typed language and the UDT/Tuple graphson representation is, you might +get schema errors when trying to write numerical data. Example:: + + session.execute_graph(""" + schema.vertexLabel('test_tuple').partitionBy('id', Int).property('t', tupleOf(Text, Bigint)).create() + """) + + session.execute_graph(""" + g.addV('test_tuple').property('id', 0).property('t', t) + """, + {'t': ('Test', 99))} + ) + + # error: [Invalid query] message="Value component 1 is of type int, not bigint" + +This is because the server requires the client to include a GraphSON schema definition +with every UDT or tuple query. In the general case, the driver can't determine what Graph type +is meant by, e.g., an int value, and so it can't serialize the value with the correct type in the schema. +The driver provides some numerical type-wrapper factories that you can use to specify types: + +* :func:`~.to_int` +* :func:`~.to_bigint` +* :func:`~.to_smallint` +* :func:`~.to_float` +* :func:`~.to_double` + +Here's the working example of the case above:: + + from cassandra.graph import to_bigint + + session.execute_graph(""" + g.addV('test_tuple').property('id', 0).property('t', t) + """, + {'t': ('Test', to_bigint(99))} + ) + +Continuous Paging +~~~~~~~~~~~~~~~~~ + +This is another nice feature that comes with the core engine: continuous paging with +graph queries. If all nodes of the cluster are >= DSE 6.8.0, it is automatically +enabled under the hood to get the best performance. If you want to explicitly +enable/disable it, you can do it through the execution profile:: + + # Disable it + ep = GraphExecutionProfile(..., continuous_paging_options=None)) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + + # Enable with a custom max_pages option + ep = GraphExecutionProfile(..., + continuous_paging_options=ContinuousPagingOptions(max_pages=10))) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) diff --git a/docs/graph_fluent.rst b/docs/graph_fluent.rst new file mode 100644 index 0000000000..8d5ad5377d --- /dev/null +++ b/docs/graph_fluent.rst @@ -0,0 +1,415 @@ +DataStax Graph Fluent API +========================= + +The fluent API adds graph features to the core driver: + +* A TinkerPop GraphTraversalSource builder to execute traversals on a DSE cluster +* The ability to execution traversal queries explicitly using execute_graph +* GraphSON serializers for all DSE Graph types. +* DSE Search predicates + +The Graph fluent API depends on Apache TinkerPop and is not installed by default. Make sure +you have the Graph requirements are properly :ref:`installed `. + +You might be interested in reading the :doc:`DataStax Graph Getting Started documentation ` to +understand the basics of creating a graph and its schema. + +Graph Traversal Queries +~~~~~~~~~~~~~~~~~~~~~~~ + +The driver provides :meth:`.Session.execute_graph`, which allows users to execute traversal +query strings. Here is a simple example:: + + session.execute_graph("g.addV('genre').property('genreId', 1).property('name', 'Action').next();") + +Since graph queries can be very complex, working with strings is not very convenient and is +hard to maintain. This fluent API allows you to build Gremlin traversals and write your graph +queries directly in Python. These native traversal queries can be executed explicitly, with +a `Session` object, or implicitly:: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + + # Create an execution profile, using GraphSON3 for Core graphs + ep_graphson3 = DseGraph.create_execution_profile( + 'my_core_graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep_graphson3}) + session = cluster.connect() + + # Execute a fluent graph query + g = DseGraph.traversal_source(session=session) + g.addV('genre').property('genreId', 1).property('name', 'Action').next() + + # implicit execution caused by iterating over results + for v in g.V().has('genre', 'name', 'Drama').in_('belongsTo').valueMap(): + print(v) + +These :ref:`Python types ` are also supported transparently:: + + g.addV('person').property('name', 'Mike').property('birthday', datetime(1984, 3, 11)). \ + property('house_yard', Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + +More readings about Gremlin: + +* `DataStax Drivers Fluent API `_ +* `gremlin-python documentation `_ + +Configuring a Traversal Execution Profile +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The fluent api takes advantage of *configuration profiles* to allow +different execution configurations for the various query handlers. Graph traversal +execution requires a custom execution profile to enable Gremlin-bytecode as +query language. With Core graphs, it is important to use GraphSON3. Here is how +to accomplish this configuration: + +.. code-block:: python + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + + # Using GraphSON3 as graph protocol is a requirement with Core graphs. + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + + # For Classic graphs, GraphSON1, GraphSON2 and GraphSON3 (DSE 6.8+) are supported. + ep_classic = DseGraph.create_execution_profile('classic_graph_name') # default is GraphSON2 + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep, 'classic': ep_classic}) + session = cluster.connect() + + g = DseGraph.traversal_source(session) # Build the GraphTraversalSource + print(g.V().toList()) # Traverse the Graph + +Note that the execution profile created with :meth:`DseGraph.create_execution_profile <.datastax.graph.fluent.DseGraph.create_execution_profile>` cannot +be used for any groovy string queries. + +If you want to change execution property defaults, please see the :doc:`Execution Profile documentation ` +for a more generalized discussion of the API. Graph traversal queries use the same execution profile defined for DSE graph. If you +need to change the default properties, please refer to the :doc:`DSE Graph query documentation page ` + +Explicit Graph Traversal Execution with a DSE Session +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Traversal queries can be executed explicitly using `session.execute_graph` or `session.execute_graph_async`. These functions +return results as DSE graph types. If you are familiar with DSE queries or need async execution, you might prefer that way. +Below is an example of explicit execution. For this example, assume the schema has been generated as above: + +.. code-block:: python + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + from pprint import pprint + + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + + g = DseGraph.traversal_source(session=session) + +Convert a traversal to a bytecode query for classic graphs:: + + addV_query = DseGraph.query_from_traversal( + g.addV('genre').property('genreId', 1).property('name', 'Action'), + graph_protocol=GraphProtocol.GRAPHSON_3_0 + ) + v_query = DseGraph.query_from_traversal( + g.V(), + graph_protocol=GraphProtocol.GRAPHSON_3_0) + + for result in session.execute_graph(addV_query): + pprint(result.value) + for result in session.execute_graph(v_query): + pprint(result.value) + +Converting a traversal to a bytecode query for core graphs require some more work, because we +need the cluster context for UDT and tuple types: + +.. code-block:: python + context = { + 'cluster': cluster, + 'graph_name': 'the_graph_for_the_query' + } + addV_query = DseGraph.query_from_traversal( + g.addV('genre').property('genreId', 1).property('name', 'Action'), + graph_protocol=GraphProtocol.GRAPHSON_3_0, + context=context + ) + + for result in session.execute_graph(addV_query): + pprint(result.value) + +Implicit Graph Traversal Execution with TinkerPop +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Using the :class:`DseGraph <.datastax.graph.fluent.DseGraph>` class, you can build a GraphTraversalSource +that will execute queries on a DSE session without explicitly passing anything to +that session. We call this *implicit execution* because the `Session` is not +explicitly involved. Everything is managed internally by TinkerPop while +traversing the graph and the results are TinkerPop types as well. + +Synchronous Example +------------------- + +.. code-block:: python + + # Build the GraphTraversalSource + g = DseGraph.traversal_source(session) + # implicitly execute the query by traversing the TraversalSource + g.addV('genre').property('genreId', 1).property('name', 'Action').next() + + # blocks until the query is completed and return the results + results = g.V().toList() + pprint(results) + +Asynchronous Exemple +-------------------- + +You can execute a graph traversal query asynchronously by using `.promise()`. It returns a +python `Future `_. + +.. code-block:: python + + # Build the GraphTraversalSource + g = DseGraph.traversal_source(session) + # implicitly execute the query by traversing the TraversalSource + g.addV('genre').property('genreId', 1).property('name', 'Action').next() # not async + + # get a future and wait + future = g.V().promise() + results = list(future.result()) + pprint(results) + + # or set a callback + def cb(f): + results = list(f.result()) + pprint(results) + future = g.V().promise() + future.add_done_callback(cb) + # do other stuff... + +Specify the Execution Profile explicitly +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you don't want to change the default graph execution profile (`EXEC_PROFILE_GRAPH_DEFAULT`), you can register a new +one as usual and use it explicitly. Here is an example: + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + cluster = Cluster() + ep = DseGraph.create_execution_profile('graph_name', graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster.add_execution_profile('graph_traversal', ep) + session = cluster.connect() + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal(g.V()) + session.execute_graph(query, execution_profile='graph_traversal') + +You can also create multiple GraphTraversalSources and use them with +the same execution profile (for different graphs): + +.. code-block:: python + + g_movies = DseGraph.traversal_source(session, graph_name='movies', ep) + g_series = DseGraph.traversal_source(session, graph_name='series', ep) + + print(g_movies.V().toList()) # Traverse the movies Graph + print(g_series.V().toList()) # Traverse the series Graph + +Batch Queries +~~~~~~~~~~~~~ + +DSE Graph supports batch queries using a :class:`TraversalBatch <.datastax.graph.fluent.query.TraversalBatch>` object +instantiated with :meth:`DseGraph.batch <.datastax.graph.fluent.DseGraph.batch>`. A :class:`TraversalBatch <.datastax.graph.fluent.query.TraversalBatch>` allows +you to execute multiple graph traversals in a single atomic transaction. A +traversal batch is executed with :meth:`.Session.execute_graph` or using +:meth:`TraversalBatch.execute <.datastax.graph.fluent.query.TraversalBatch.execute>` if bounded to a DSE session. + +Either way you choose to execute the traversal batch, you need to configure +the execution profile accordingly. Here is a example:: + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={'graphson3': ep}) + session = cluster.connect() + + g = DseGraph.traversal_source() + +To execute the batch using :meth:`.Session.execute_graph`, you need to convert +the batch to a GraphStatement:: + + batch = DseGraph.batch() + + batch.add( + g.addV('genre').property('genreId', 1).property('name', 'Action')) + batch.add( + g.addV('genre').property('genreId', 2).property('name', 'Drama')) # Don't use `.next()` with a batch + + graph_statement = batch.as_graph_statement(graph_protocol=GraphProtocol.GRAPHSON_3_0) + graph_statement.is_idempotent = True # configure any Statement parameters if needed... + session.execute_graph(graph_statement, execution_profile='graphson3') + +To execute the batch using :meth:`TraversalBatch.execute <.datastax.graph.fluent.query.TraversalBatch.execute>`, you need to bound the batch to a DSE session:: + + batch = DseGraph.batch(session, 'graphson3') # bound the session and execution profile + + batch.add( + g.addV('genre').property('genreId', 1).property('name', 'Action')) + batch.add( + g.addV('genre').property('genreId', 2).property('name', 'Drama')) # Don't use `.next()` with a batch + + batch.execute() + +DSL (Domain Specific Languages) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DSL are very useful to write better domain-specific APIs and avoiding +code duplication. Let's say we have a graph of `People` and we produce +a lot of statistics based on age. All graph traversal queries of our +application would look like:: + + g.V().hasLabel("people").has("age", P.gt(21))... + + +which is not really verbose and quite annoying to repeat in a code base. Let's create a DSL:: + + from gremlin_python.process.graph_traversal import GraphTraversal, GraphTraversalSource + + class MyAppTraversal(GraphTraversal): + + def younger_than(self, age): + return self.has("age", P.lt(age)) + + def older_than(self, age): + return self.has("age", P.gt(age)) + + + class MyAppTraversalSource(GraphTraversalSource): + + def __init__(self, *args, **kwargs): + super(MyAppTraversalSource, self).__init__(*args, **kwargs) + self.graph_traversal = MyAppTraversal + + def people(self): + return self.get_graph_traversal().V().hasLabel("people") + +Now, we can use our DSL that is a lot cleaner:: + + from cassandra.datastax.graph.fluent import DseGraph + + # ... + g = DseGraph.traversal_source(session=session, traversal_class=MyAppTraversalsource) + + g.people().younger_than(21)... + g.people().older_than(30)... + +To see a more complete example of DSL, see the `Python killrvideo DSL app `_ + +Search +~~~~~~ + +DSE Graph can use search indexes that take advantage of DSE Search functionality for +efficient traversal queries. Here are the list of additional search predicates: + +Text tokenization: + +* :meth:`token <.datastax.graph.fluent.predicates.Search.token>` +* :meth:`token_prefix <.datastax.graph.fluent.predicates.Search.token_prefix>` +* :meth:`token_regex <.datastax.graph.fluent.predicates.Search.token_regex>` +* :meth:`token_fuzzy <.datastax.graph.fluent.predicates.Search.token_fuzzy>` + +Text match: + +* :meth:`prefix <.datastax.graph.fluent.predicates.Search.prefix>` +* :meth:`regex <.datastax.graph.fluent.predicates.Search.regex>` +* :meth:`fuzzy <.datastax.graph.fluent.predicates.Search.fuzzy>` +* :meth:`phrase <.datastax.graph.fluent.predicates.Search.phrase>` + +Geo: + +* :meth:`inside <.datastax.graph.fluent.predicates.Geo.inside>` + +Create search indexes +--------------------- + +For text tokenization: + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('text_field').asText().add()") + +For text match: + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('text_field').asString().add()") + + +For geospatial: + +You can create a geospatial index on Point and LineString fields. + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('point_field').add()") + + +Using search indexes +-------------------- + +Token: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Search + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','text_field', Search.token_regex('Hello.+World')).values('text_field')) + session.execute_graph(query) + +Text: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Search + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','text_field', Search.prefix('Hello')).values('text_field')) + session.execute_graph(query) + +Geospatial: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Geo + from cassandra.util import Distance + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','point_field', Geo.inside(Distance(46, 71, 100)).values('point_field')) + session.execute_graph(query) + + +For more details, please refer to the official `DSE Search Indexes Documentation `_ diff --git a/docs/index.rst b/docs/index.rst index 9571e31093..0f0e9edad6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,17 +1,112 @@ -Python Cassandra Driver -======================= +DataStax Python Driver for Apache Cassandra® +============================================ +A Python client driver for `Apache Cassandra® `_. +This driver works exclusively with the Cassandra Query Language v3 (CQL3) +and Cassandra's native protocol. Cassandra 2.1+ is supported, including DSE 4.7+. -Contents: +The driver supports Python 3.10 through 3.14. + +This driver is open source under the +`Apache v2 License `_. +The source code for this driver can be found on `GitHub `_. + +**Note:** DataStax products do not support big-endian systems. + +Contents +-------- +:doc:`installation` + How to install the driver. + +:doc:`getting_started` + A guide through the first steps of connecting to Cassandra and executing queries + +:doc:`execution_profiles` + An introduction to a more flexible way of configuring request execution + +:doc:`lwt` + Working with results of conditional requests + +:doc:`object_mapper` + Introduction to the integrated object mapper, cqlengine + +:doc:`performance` + Tips for getting good performance. + +:doc:`query_paging` + Notes on paging large query results + +:doc:`security` + An overview of the security features of the driver + +:doc:`upgrading` + A guide to upgrading versions of the driver + +:doc:`user_defined_types` + Working with Cassandra 2.1's user-defined types + +:doc:`dates_and_times` + Some discussion on the driver's approach to working with timestamp, date, time types + +:doc:`cloud` + A guide to connecting to Datastax Astra + +:doc:`column_encryption` + Transparent client-side per-column encryption and decryption + +:doc:`geo_types` + Working with DSE geometry types + +:doc:`graph` + Graph queries with the Core engine + +:doc:`classic_graph` + Graph queries with the Classic engine + +:doc:`graph_fluent` + DataStax Graph Fluent API + +:doc:`CHANGELOG` + Log of changes to the driver, organized by version. + +:doc:`faq` + A collection of Frequently Asked Questions + +:doc:`api/index` + The API documentation. .. toctree:: - :maxdepth: 2 + :hidden: api/index + installation + getting_started + upgrading + execution_profiles + performance + query_paging + lwt + security + user_defined_types + object_mapper + geo_types + graph + classic_graph + graph_fluent + dates_and_times + cloud + faq + +Getting Help +------------ +Visit the :doc:`FAQ section ` in this documentation. + +Please send questions to the `mailing list `_. -Indices and Tables -================== +Alternatively, you can use the `DataStax Community `_. -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` +Reporting Issues +---------------- +Please report any bugs and make any feature requests on the +`JIRA `_ issue tracker. +If you would like to contribute, please feel free to open a pull request. diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000000..a0a5e25dab --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,240 @@ +Installation +============ + +Supported Platforms +------------------- +Python 3.10 through 3.14 are supported. Both CPython (the standard Python +implementation) and `PyPy `_ are supported and tested. + +Linux, OSX, and Windows are supported. + +Installation through pip +------------------------ +`pip `_ is the suggested tool for installing +packages. It will handle installing all Python dependencies for the driver at +the same time as the driver itself. To install the driver*:: + + pip install cassandra-driver + +You can use ``pip install --pre cassandra-driver`` if you need to install a beta version. + +***Note**: if intending to use optional extensions, install the `dependencies <#optional-non-python-dependencies>`_ first. The driver may need to be reinstalled if dependencies are added after the initial installation. + +Verifying your Installation +--------------------------- +To check if the installation was successful, you can run:: + + python -c 'import cassandra; print(cassandra.__version__)' + +This command should print something like ``3.30.0``. + +.. _installation-datastax-graph: + +(*Optional*) DataStax Graph +--------------------------- +The driver provides an optional fluent graph API that depends on Apache TinkerPop (gremlinpython). It is +not installed by default. To be able to build Gremlin traversals, you need to install +the `graph` extra:: + + pip install cassandra-driver[graph] + +See :doc:`graph_fluent` for more details about this API. + +(*Optional*) Compression Support +-------------------------------- +Compression can optionally be used for communication between the driver and +Cassandra. There are currently two supported compression algorithms: +snappy (in Cassandra 1.2+) and LZ4 (only in Cassandra 2.0+). If either is +available for the driver and Cassandra also supports it, it will +be used automatically. + +For lz4 support:: + + pip install lz4 + +For snappy support:: + + pip install python-snappy + +(If using a Debian Linux derivative such as Ubuntu, it may be easier to +just run ``apt-get install python-snappy``.) + +(*Optional*) Metrics Support +---------------------------- +The driver has built-in support for capturing :attr:`.Cluster.metrics` about +the queries you run. Note that the ``scales`` module is required to +support metrics. This module is available from Pypi and can be installed with:: + + pip install scales + +*Optional:* Column-Level Encryption (CLE) Support +-------------------------------------------------- +The driver has built-in support for client-side encryption and +decryption of data. For more, see :doc:`column_encryption`. + +CLE depends on the Python `cryptography `_ module. +When installing Python driver 3.27.0. the `cryptography` module is +also downloaded and installed. +If you are using Python driver 3.28.0 or later and want to use CLE, you must +install the `cryptography `_ module. + +You can install this module along with the driver by specifying the `cle` extra:: + + pip install cassandra-driver[cle] + +Alternatively, you can also install the module directly via `pip`:: + + pip install cryptography + +Any version of cryptography >= 35.0 will work for the CLE feature. You can find additional +details at `PYTHON-1351 `_ + +Speeding Up Installation +^^^^^^^^^^^^^^^^^^^^^^^^ +By default, installing the driver through ``pip`` uses a pre-compiled, platform-specific wheel when available. +If using a source distribution rather than a wheel, Cython is used to compile certain parts of the driver. +This makes those hot paths faster at runtime, but the Cython compilation +process can take a long time -- as long as 10 minutes in some environments. + +In environments where performance is less important, it may be worth it to +:ref:`disable Cython as documented below `. + +Cython also supports concurrent builds of native extensions. The ``build-concurrency`` key in the +``tool.cassandra-driver`` table of pyproject.toml is an integer value which specifies the number of +concurrent builds Cython may execute. The value for this key must be a non-negative integer; the default is zero, +indicating no concurrent builds. Note that Cython's concurrent builds use the standard ``multiprocessing`` package +so this library must be availble is concurrent builds are used. + +OSX Installation Error +^^^^^^^^^^^^^^^^^^^^^^ +If you're installing on OSX and have XCode 5.1 installed, you may see an error like this:: + + clang: error: unknown argument: '-mno-fused-madd' [-Wunused-command-line-argument-hard-error-in-future] + +To fix this, re-run the installation with an extra compilation flag: + +.. code-block:: bash + + ARCHFLAGS=-Wno-error=unused-command-line-argument-hard-error-in-future pip install cassandra-driver + +.. _windows_build: + +Windows Installation Notes +-------------------------- +Installing the driver with extensions in Windows sometimes presents some challenges. A few notes about common +hang-ups: + +Setup requires a compiler. When using Python 2, this is as simple as installing `this package `_ +(this link is also emitted during install if setuptools is unable to find the resources it needs). Depending on your +system settings, this package may install as a user-specific application. Make sure to install for everyone, or at least +as the user that will be building the Python environment. + +It is also possible to run the build with your compiler of choice. Just make sure to have your environment setup with +the proper paths. Make sure the compiler target architecture matches the bitness of your Python runtime. +Perhaps the easiest way to do this is to run the build/install from a Visual Studio Command Prompt (a +shortcut installed with Visual Studio that sources the appropriate environment and presents a shell). + +Manual Installation +------------------- +You can always install the driver directly from a source checkout or tarball. +When installing manually, ensure the python dependencies are already +installed. You can find the list of dependencies in +`requirements.txt `_. + +Once the dependencies are installed, simply run:: + + pip install . + +(*Optional*) Non-python Dependencies +------------------------------------ +The driver has several **optional** features that have non-Python dependencies. + +C Extensions +^^^^^^^^^^^^ +By default, a number of extensions are compiled, providing faster hashing +for token-aware routing with the ``Murmur3Partitioner``, +`libev `_ event loop integration, +and Cython optimized extensions. + +Some or all of these native extensions can be disabled by changing the corresponding +key in the ``tool.cassandra-driver`` table of pyproject.toml to ``false``. Please consult +the ``build-murmur3-extension``, ``build-libev-extension`` and ``build-cython-extensions`` +keys (respectively) to disable these extensions. + +To compile the extensions, ensure that GCC and the Python headers are available. + +On Ubuntu and Debian, this can be accomplished by running:: + + $ sudo apt-get install gcc python-dev + +On RedHat and RedHat-based systems like CentOS and Fedora:: + + $ sudo yum install gcc python-devel + +On OS X, homebrew installations of Python should provide the necessary headers. + +See :ref:`windows_build` for notes on configuring the build environment on Windows. + +.. _cython-extensions: + +Cython-based Extensions +~~~~~~~~~~~~~~~~~~~~~~~ +By default, this package uses `Cython `_ to optimize core modules and build custom extensions. +This is not a hard requirement, but is enabled by default to build extensions offering better performance than the +pure Python implementation. + +This process does take some time, however, so if you wish to build without generating these extensions using +Cython you can do so by changing the ``build-cython-extensions`` key in the ``tool.cassandra-driver`` table of pyproject.toml. +By default this key is set to ``true``; simply changing it to ``false`` will disable all Cython functionality. + +Supported Event Loops +^^^^^^^^^^^^^^^^^^^^^ +The ``asyncore`` and ``libev`` event loops are proven production-grade event loops. Python 3.12 removed +asyncore from the runtime but this event loop can still be used in newer versions of Python via the +`pyasyncore `_ package. + +The ``asyncio`` event loop is generally functional but still somewhat experimental and not recommended +for production systems. We anticipate significant improvements to this event loop (including hopefully +making this event loop the default going forward) in 3.31.0. + +The ``gevent``, ``eventlet`` and ``Twisted`` event loops have been deprecated in 3.30.0 and will be removed +completely in 3.31.0. + +libev support +^^^^^^^^^^^^^ +If you're on Linux, you should be able to install libev +through a package manager. For example, on Debian/Ubuntu:: + + $ sudo apt-get install libev4 libev-dev + +On RHEL/CentOS/Fedora:: + + $ sudo yum install libev libev-devel + +If you're on Mac OS X, you should be able to install libev +through `Homebrew `_. For example, on Mac OS X:: + + $ brew install libev + +The libev extension can now be built for Windows as of Python driver version 3.29.2. You can +install libev using any Windows package manager. For example, to install using `vcpkg `_:: + + $ vcpkg install libev + +If successful, you should be able to build and install the extension +(just using ``pip install`` or ``pip install -e``) and then use +the libev event loop by doing the following: + +.. code-block:: python + + >>> from cassandra.io.libevreactor import LibevConnection + >>> from cassandra.cluster import Cluster + + >>> cluster = Cluster() + >>> cluster.connection_class = LibevConnection + >>> session = cluster.connect() + +(*Optional*) Configuring SSL +----------------------------- +Andrew Mussey has published a thorough guide on +`Using SSL with the DataStax Python driver `_. diff --git a/docs/lwt.rst b/docs/lwt.rst new file mode 100644 index 0000000000..2cc272f350 --- /dev/null +++ b/docs/lwt.rst @@ -0,0 +1,91 @@ +Lightweight Transactions (Compare-and-set) +========================================== + +Lightweight Transactions (LWTs) are mostly pass-through CQL for the driver. However, +the server returns some specialized results indicating the outcome and optional state +preceding the transaction. + +For pertinent execution parameters, see :attr:`.Statement.serial_consistency_level`. + +This section discusses working with specialized result sets returned by the server for LWTs, +and how to work with them using the driver. + + +Specialized Results +------------------- +The result returned from a LWT request is always a single row result. It will always have +prepended a special column named ``[applied]``. How this value appears in your results depends +on the row factory in use. See below for examples. + +The value of this ``[applied]`` column is boolean value indicating whether or not the transaction was applied. +If ``True``, it is the only column in the result. If ``False``, the additional columns depend on the LWT operation being +executed: + +- When using a ``UPDATE ... IF "col" = ...`` clause, the result will contain the ``[applied]`` column, plus the existing columns + and values for any columns in the ``IF`` clause (and thus the value that caused the transaction to fail). + +- When using ``INSERT ... IF NOT EXISTS``, the result will contain the ``[applied]`` column, plus all columns and values + of the existing row that rejected the transaction. + +- ``UPDATE .. IF EXISTS`` never has additional columns, regardless of ``[applied]`` status. + +How the ``[applied]`` column manifests depends on the row factory in use. Considering the following (initially empty) table:: + + CREATE TABLE test.t ( + k int PRIMARY KEY, + v int, + x int + ) + +... the following sections show the expected result for a number of example statements, using the three base row factories. + +named_tuple_factory (default) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The name ``[applied]`` is not a valid Python identifier, so the square brackets are actually removed +from the attribute for the resulting ``namedtuple``. The row always has a boolean column ``applied`` in position 0:: + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + Row(applied=True) + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + Row(applied=False, k=0, v=0, x=None) + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0") + Row(applied=True) + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0 AND x = 1") + Row(applied=False, v=1, x=2) + +tuple_factory +~~~~~~~~~~~~~ +This return type does not refer to names, but the boolean value ``applied`` is always present in position 0:: + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + (True,) + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + (False, 0, 0, None) + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0") + (True,) + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0 AND x = 1") + (False, 1, 2) + +dict_factory +~~~~~~~~~~~~ +The retuned ``dict`` contains the ``[applied]`` key:: + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + {u'[applied]': True} + + >>> session.execute("INSERT INTO t (k,v) VALUES (0,0) IF NOT EXISTS") + {u'x': 2, u'[applied]': False, u'v': 1} + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0") + {u'x': None, u'[applied]': False, u'k': 0, u'v': 0} + + >>> session.execute("UPDATE t SET v = 1, x = 2 WHERE k = 0 IF v =0 AND x = 1") + {u'[applied]': True} + + diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000..6be2277f78 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,190 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\cqlengine.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\cqlengine.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +:end diff --git a/docs/object_mapper.rst b/docs/object_mapper.rst new file mode 100644 index 0000000000..21d2954f4b --- /dev/null +++ b/docs/object_mapper.rst @@ -0,0 +1,105 @@ +Object Mapper +============= + +cqlengine is the Cassandra CQL 3 Object Mapper packaged with this driver + +:ref:`Jump to Getting Started ` + +Contents +-------- +:doc:`cqlengine/upgrade_guide` + For migrating projects from legacy cqlengine, to the integrated product + +:doc:`cqlengine/models` + Examples defining models, and mapping them to tables + +:doc:`cqlengine/queryset` + Overview of query sets and filtering + +:doc:`cqlengine/batches` + Working with batch mutations + +:doc:`cqlengine/connections` + Working with multiple sessions + +:ref:`API Documentation ` + Index of API documentation + +:doc:`cqlengine/third_party` + High-level examples in Celery and uWSGI + +:doc:`cqlengine/faq` + +.. toctree:: + :hidden: + + cqlengine/upgrade_guide + cqlengine/models + cqlengine/queryset + cqlengine/batches + cqlengine/connections + cqlengine/third_party + cqlengine/faq + +.. _getting-started: + +Getting Started +--------------- + +.. code-block:: python + + import uuid + from cassandra.cqlengine import columns + from cassandra.cqlengine import connection + from datetime import datetime + from cassandra.cqlengine.management import sync_table + from cassandra.cqlengine.models import Model + + #first, define a model + class ExampleModel(Model): + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) + example_type = columns.Integer(index=True) + created_at = columns.DateTime() + description = columns.Text(required=False) + + #next, setup the connection to your cassandra server(s)... + # see https://docs.datastax.com/en/developer/python-driver/latest/api/cassandra/cluster.html for options + # the list of hosts will be passed to create a Cluster() instance + connection.setup(['127.0.0.1'], "cqlengine", protocol_version=3) + + #...and create your CQL table + >>> sync_table(ExampleModel) + + #now we can create some rows: + >>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now()) + >>> em2 = ExampleModel.create(example_type=0, description="example2", created_at=datetime.now()) + >>> em3 = ExampleModel.create(example_type=0, description="example3", created_at=datetime.now()) + >>> em4 = ExampleModel.create(example_type=0, description="example4", created_at=datetime.now()) + >>> em5 = ExampleModel.create(example_type=1, description="example5", created_at=datetime.now()) + >>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now()) + >>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now()) + >>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now()) + + #and now we can run some queries against our table + >>> ExampleModel.objects.count() + 8 + >>> q = ExampleModel.objects(example_type=1) + >>> q.count() + 4 + >>> for instance in q: + >>> print(instance.description) + example5 + example6 + example7 + example8 + + #here we are applying additional filtering to an existing query + #query objects are immutable, so calling filter returns a new + #query object + >>> q2 = q.filter(example_id=em5.example_id) + + >>> q2.count() + 1 + >>> for instance in q2: + >>> print(instance.description) + example5 diff --git a/docs/performance.rst b/docs/performance.rst new file mode 100644 index 0000000000..f7a3f49e0f --- /dev/null +++ b/docs/performance.rst @@ -0,0 +1,45 @@ +Performance Notes +================= +The Python driver for Cassandra offers several methods for executing queries. +You can synchronously block for queries to complete using +:meth:`.Session.execute()`, you can obtain asynchronous request futures through +:meth:`.Session.execute_async()`, and you can attach a callback to the future +with :meth:`.ResponseFuture.add_callback()`. + +Examples of multiple request patterns can be found in the benchmark scripts included in the driver project. + +The choice of execution pattern will depend on the application context. For applications dealing with multiple +requests in a given context, the recommended pattern is to use concurrent asynchronous +requests with callbacks. For many use cases, you don't need to implement this pattern yourself. +:meth:`cassandra.concurrent.execute_concurrent` and :meth:`cassandra.concurrent.execute_concurrent_with_args` +provide this pattern with a synchronous API and tunable concurrency. + +Due to the GIL and limited concurrency, the driver can become CPU-bound pretty quickly. The sections below +discuss further runtime and design considerations for mitigating this limitation. + +PyPy +---- +`PyPy `_ is an alternative Python runtime which uses a JIT compiler to +reduce CPU consumption. This leads to a huge improvement in the driver performance, +more than doubling throughput for many workloads. + +Cython Extensions +----------------- +`Cython `_ is an optimizing compiler and language that can be used to compile the core files and +optional extensions for the driver. Cython is not a strict dependency, but the extensions will be built by default. + +See :doc:`installation` for details on controlling this build. + +multiprocessing +--------------- +All of the patterns discussed above may be used over multiple processes using the +`multiprocessing `_ +module. Multiple processes will scale better than multiple threads, so if high throughput is your goal, +consider this option. + +Be sure to **never share any** :class:`~.Cluster`, :class:`~.Session`, +**or** :class:`~.ResponseFuture` **objects across multiple processes**. These +objects should all be created after forking the process, not before. + +For further discussion and simple examples using the driver with ``multiprocessing``, +see `this blog post `_. diff --git a/docs/query_paging.rst b/docs/query_paging.rst new file mode 100644 index 0000000000..23ee2c1129 --- /dev/null +++ b/docs/query_paging.rst @@ -0,0 +1,95 @@ +.. _query-paging: + +Paging Large Queries +==================== +Cassandra 2.0+ offers support for automatic query paging. Starting with +version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is greater than +:const:`2` (it is by default), queries returning large result sets will be +automatically paged. + +Controlling the Page Size +------------------------- +By default, :attr:`.Session.default_fetch_size` controls how many rows will +be fetched per page. This can be overridden per-query by setting +:attr:`~.fetch_size` on a :class:`~.Statement`. By default, each page +will contain at most 5000 rows. + +Handling Paged Results +---------------------- +Whenever the number of result rows for are query exceed the page size, an +instance of :class:`~.PagedResult` will be returned instead of a normal +list. This class implements the iterator interface, so you can treat +it like a normal iterator over rows:: + + from cassandra.query import SimpleStatement + query = "SELECT * FROM users" # users contains 100 rows + statement = SimpleStatement(query, fetch_size=10) + for user_row in session.execute(statement): + process_user(user_row) + +Whenever there are no more rows in the current page, the next page will +be fetched transparently. However, note that it *is* possible for +an :class:`Exception` to be raised while fetching the next page, just +like you might see on a normal call to ``session.execute()``. + +If you use :meth:`.Session.execute_async()` along with, +:meth:`.ResponseFuture.result()`, the first page will be fetched before +:meth:`~.ResponseFuture.result()` returns, but latter pages will be +transparently fetched synchronously while iterating the result. + +Handling Paged Results with Callbacks +------------------------------------- +If callbacks are attached to a query that returns a paged result, +the callback will be called once per page with a normal list of rows. + +Use :attr:`.ResponseFuture.has_more_pages` and +:meth:`.ResponseFuture.start_fetching_next_page()` to continue fetching +pages. For example:: + + class PagedResultHandler(object): + + def __init__(self, future): + self.error = None + self.finished_event = Event() + self.future = future + self.future.add_callbacks( + callback=self.handle_page, + errback=self.handle_err) + + def handle_page(self, rows): + for row in rows: + process_row(row) + + if self.future.has_more_pages: + self.future.start_fetching_next_page() + else: + self.finished_event.set() + + def handle_error(self, exc): + self.error = exc + self.finished_event.set() + + future = session.execute_async("SELECT * FROM users") + handler = PagedResultHandler(future) + handler.finished_event.wait() + if handler.error: + raise handler.error + +Resume Paged Results +-------------------- + +You can resume the pagination when executing a new query by using the :attr:`.ResultSet.paging_state`. This can be useful if you want to provide some stateless pagination capabilities to your application (ie. via http). For example:: + + from cassandra.query import SimpleStatement + query = "SELECT * FROM users" + statement = SimpleStatement(query, fetch_size=10) + results = session.execute(statement) + + # save the paging_state somewhere and return current results + web_session['paging_state'] = results.paging_state + + + # resume the pagination sometime later... + statement = SimpleStatement(query, fetch_size=10) + ps = web_session['paging_state'] + results = session.execute(statement, paging_state=ps) diff --git a/docs/security.rst b/docs/security.rst new file mode 100644 index 0000000000..6dd2624c24 --- /dev/null +++ b/docs/security.rst @@ -0,0 +1,421 @@ +.. _security: + +Security +======== +The two main security components you will use with the +Python driver are Authentication and SSL. + +Authentication +-------------- +Versions 2.0 and higher of the driver support a SASL-based +authentication mechanism when :attr:`~.Cluster.protocol_version` +is set to 2 or higher. To use this authentication, set +:attr:`~.Cluster.auth_provider` to an instance of a subclass +of :class:`~cassandra.auth.AuthProvider`. When working +with Cassandra's ``PasswordAuthenticator``, you can use +the :class:`~cassandra.auth.PlainTextAuthProvider` class. + +For example, suppose Cassandra is setup with its default +'cassandra' user with a password of 'cassandra': + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + auth_provider = PlainTextAuthProvider(username='cassandra', password='cassandra') + cluster = Cluster(auth_provider=auth_provider, protocol_version=2) + + + +Custom Authenticators +^^^^^^^^^^^^^^^^^^^^^ +If you're using something other than Cassandra's ``PasswordAuthenticator``, +:class:`~.SaslAuthProvider` is provided for generic SASL authentication mechanisms, +utilizing the ``pure-sasl`` package. +If these do not suit your needs, you may need to create your own subclasses of +:class:`~.AuthProvider` and :class:`~.Authenticator`. You can use the Sasl classes +as example implementations. + +Protocol v1 Authentication +^^^^^^^^^^^^^^^^^^^^^^^^^^ +When working with Cassandra 1.2 (or a higher version with +:attr:`~.Cluster.protocol_version` set to ``1``), you will not pass in +an :class:`~.AuthProvider` instance. Instead, you should pass in a +function that takes one argument, the IP address of a host, and returns +a dict of credentials with a ``username`` and ``password`` key: + +.. code-block:: python + + from cassandra.cluster import Cluster + + def get_credentials(host_address): + return {'username': 'joe', 'password': '1234'} + + cluster = Cluster(auth_provider=get_credentials, protocol_version=1) + +SSL +--- +SSL should be used when client encryption is enabled in Cassandra. + +To give you as much control as possible over your SSL configuration, our SSL +API takes a user-created `SSLContext` instance from the Python standard library. +These docs will include some examples for how to achieve common configurations, +but the `ssl.SSLContext `_ documentation +gives a more complete description of what is possible. + +To enable SSL with version 3.17.0 and higher, you will need to set :attr:`.Cluster.ssl_context` to a +``ssl.SSLContext`` instance to enable SSL. Optionally, you can also set :attr:`.Cluster.ssl_options` +to a dict of options. These will be passed as kwargs to ``ssl.SSLContext.wrap_socket()`` +when new sockets are created. + +If you create your SSLContext using `ssl.create_default_context `_, +be aware that SSLContext.check_hostname is set to True by default, so the hostname validation will be done +by Python and not the driver. For this reason, we need to set the server_hostname at best effort, which is the +resolved ip address. If this validation needs to be done against the FQDN, consider enabling it using the ssl_options +as described in the following examples or implement your own :class:`~.connection.EndPoint` and +:class:`~.connection.EndPointFactory`. + + +The following examples assume you have generated your Cassandra certificate and +keystore files with these intructions: + +* `Setup SSL Cert `_ + +It might be also useful to learn about the different levels of identity verification to understand the examples: + +* `Using SSL in DSE drivers `_ + +SSL with Twisted or Eventlet +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Twisted and Eventlet both use an alternative SSL implementation called pyOpenSSL, so if your `Cluster`'s connection class is +:class:`~cassandra.io.twistedreactor.TwistedConnection` or :class:`~cassandra.io.eventletreactor.EventletConnection`, you must pass a +`pyOpenSSL context `_ instead. +An example is provided in these docs, and more details can be found in the +`documentation `_. +pyOpenSSL is not installed by the driver and must be installed separately. + +SSL Configuration Examples +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Here, we'll describe the server and driver configuration necessary to set up SSL to meet various goals, such as the client verifying the server and the server verifying the client. We'll also include Python code demonstrating how to use servers and drivers configured in these ways. + +.. _ssl-no-identify-verification: + +No identity verification +++++++++++++++++++++++++ + +No identity verification at all. Note that this is not recommended for for production deployments. + +The Cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: false + +The driver configuration: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS + + ssl_context = SSLContext(PROTOCOL_TLS) + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context) + session = cluster.connect() + +.. _ssl-client-verifies-server: + +Client verifies server +++++++++++++++++++++++ + +Ensure the python driver verifies the identity of the server. + +The Cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: false + +For the driver configuration, it's very important to set `ssl_context.verify_mode` +to `CERT_REQUIRED`. Otherwise, the loaded verify certificate will have no effect: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED + + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations('/path/to/rootca.crt') + ssl_context.verify_mode = CERT_REQUIRED + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context) + session = cluster.connect() + +Additionally, you can also force the driver to verify the `hostname` of the server by passing additional options to `ssl_context.wrap_socket` via the `ssl_options` kwarg: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED + + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations('/path/to/rootca.crt') + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.check_hostname = True + ssl_options = {'server_hostname': '127.0.0.1'} + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context, ssl_options=ssl_options) + session = cluster.connect() + +.. _ssl-server-verifies-client: + +Server verifies client +++++++++++++++++++++++ + +If Cassandra is configured to verify clients (``require_client_auth``), you need to generate +SSL key and certificate files. + +The cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: true + truststore: /path/to/dse-truststore.jks + truststore_password: myStorePass + +The Python ``ssl`` APIs require the certificate in PEM format. First, create a certificate +conf file: + +.. code-block:: bash + + cat > gen_client_cert.conf <`_ +for more details about ``SSLContext`` configuration. + +**Server verifies client and client verifies server using Twisted and pyOpenSSL** + +.. code-block:: python + + from OpenSSL import SSL, crypto + from cassandra.cluster import Cluster + from cassandra.io.twistedreactor import TwistedConnection + + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok) + ssl_context.use_certificate_file('/path/to/client.crt_signed') + ssl_context.use_privatekey_file('/path/to/client.key') + ssl_context.load_verify_locations('/path/to/rootca.crt') + + cluster = Cluster( + contact_points=['127.0.0.1'], + connection_class=TwistedConnection, + ssl_context=ssl_context, + ssl_options={'check_hostname': True} + ) + session = cluster.connect() + + +Connecting using Eventlet would look similar except instead of importing and using ``TwistedConnection``, you would +import and use ``EventletConnection``, including the appropriate monkey-patching. + +Versions 3.16.0 and lower +^^^^^^^^^^^^^^^^^^^^^^^^^ + +To enable SSL you will need to set :attr:`.Cluster.ssl_options` to a +dict of options. These will be passed as kwargs to ``ssl.wrap_socket()`` +when new sockets are created. Note that this use of ssl_options will be +deprecated in the next major release. + +By default, a ``ca_certs`` value should be supplied (the value should be +a string pointing to the location of the CA certs file), and you probably +want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match +Cassandra's default protocol. + +For example: + +.. code-block:: python + + from cassandra.cluster import Cluster + from ssl import PROTOCOL_TLS, CERT_REQUIRED + + ssl_opts = { + 'ca_certs': '/path/to/my/ca.certs', + 'ssl_version': PROTOCOL_TLS, + 'cert_reqs': CERT_REQUIRED # Certificates are required and validated + } + cluster = Cluster(ssl_options=ssl_opts) + +This is only an example to show how to pass the ssl parameters. Consider reading +the `python ssl documentation `_ for +your configuration. For further reading, Andrew Mussey has published a thorough guide on +`Using SSL with the DataStax Python driver `_. + +SSL with Twisted +++++++++++++++++ + +In case the twisted event loop is used pyOpenSSL must be installed or an exception will be risen. Also +to set the ``ssl_version`` and ``cert_reqs`` in ``ssl_opts`` the appropriate constants from pyOpenSSL are expected. + +DSE Authentication +------------------ +When authenticating against DSE, the Cassandra driver provides two auth providers that work both with legacy kerberos and Cassandra authenticators, +as well as the new DSE Unified Authentication. This allows client to configure this auth provider independently, +and in advance of any server upgrade. These auth providers are configured in the same way as any previous implementation:: + + from cassandra.auth import DSEGSSAPIAuthProvider + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"]) + cluster = Cluster(auth_provider=auth_provider) + session = cluster.connect() + +Implementations are :attr:`.DSEPlainTextAuthProvider`, :class:`.DSEGSSAPIAuthProvider` and :class:`.SaslAuthProvider`. + +DSE Unified Authentication +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With DSE (>=5.1), unified Authentication allows you to: + +* Proxy Login: Authenticate using a fixed set of authentication credentials but allow authorization of resources based another user id. +* Proxy Execute: Authenticate using a fixed set of authentication credentials but execute requests based on another user id. + +Proxy Login ++++++++++++ + +Proxy login allows you to authenticate with a user but act as another one. You need to ensure the authenticated user has the permission to use the authorization of resources of the other user. ie. this example will allow the `server` user to authenticate as usual but use the authorization of `user1`: + +.. code-block:: text + + GRANT PROXY.LOGIN on role user1 to server + +then you can do the proxy authentication.... + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import SaslAuthProvider + + sasl_kwargs = { + "service": 'dse', + "mechanism":"PLAIN", + "username": 'server', + 'password': 'server', + 'authorization_id': 'user1' + } + + auth_provider = SaslAuthProvider(**sasl_kwargs) + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute(...) # all requests will be executed as 'user1' + +If you are using kerberos, you can use directly :class:`.DSEGSSAPIAuthProvider` and pass the authorization_id, like this: + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import DSEGSSAPIAuthProvider + + # Ensure the kerberos ticket of the server user is set with the kinit utility. + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="server@DATASTAX.COM", + authorization_id='user1@DATASTAX.COM') + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute(...) # all requests will be executed as 'user1' + + +Proxy Execute ++++++++++++++ + +Proxy execute allows you to execute requests as another user than the authenticated one. You need to ensure the authenticated user has the permission to use the authorization of resources of the specified user. ie. this example will allow the `server` user to execute requests as `user1`: + +.. code-block:: text + + GRANT PROXY.EXECUTE on role user1 to server + +then you can do a proxy execute... + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import DSEPlainTextAuthProvider, + + auth_provider = DSEPlainTextAuthProvider('server', 'server') + + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute('select * from k.t;', execute_as='user1') # the request will be executed as 'user1' + +Please see the `official documentation `_ for more details on the feature and configuration process. diff --git a/docs/themes/custom/static/custom.css_t b/docs/themes/custom/static/custom.css_t new file mode 100644 index 0000000000..c3460e75a5 --- /dev/null +++ b/docs/themes/custom/static/custom.css_t @@ -0,0 +1,26 @@ +@import url("alabaster.css"); + +div.document { + width: 1200px; +} + +div.sphinxsidebar h1.logo a { + font-size: 24px; +} + +code.descname { + color: #4885ed; +} + +th.field-name { + min-width: 100px; + color: #3cba54; +} + +div.versionmodified { + font-weight: bold +} + +div.versionadded { + font-weight: bold +} diff --git a/docs/themes/custom/theme.conf b/docs/themes/custom/theme.conf new file mode 100644 index 0000000000..b0fbb6961e --- /dev/null +++ b/docs/themes/custom/theme.conf @@ -0,0 +1,11 @@ +[theme] +inherit = alabaster +stylesheet = custom.css +pygments_style = friendly + +[options] +description = Python driver for Cassandra +github_user = datastax +github_repo = python-driver +github_button = true +github_type = star \ No newline at end of file diff --git a/docs/upgrading.rst b/docs/upgrading.rst new file mode 100644 index 0000000000..5ea51440da --- /dev/null +++ b/docs/upgrading.rst @@ -0,0 +1,426 @@ +Upgrading +========= + +.. toctree:: + :maxdepth: 1 + +Upgrading to 3.30.0 +------------------- +Version 3.30.0 of the Python driver is the first release since the driver's donation +to the Apache Software Foundation (ASF). + +Supported Python Versions +^^^^^^^^^^^^^^^^^^^^^^^^^ +An individual version of the Python driver aims to officially support all Python runtimes +that are not end-of-life (EOL) at the time of that version's release. For 3.30.0 this policy +entails support for Python 3.10 through Python 3.14. The driver will likely continue to work +reasonably well on older Python runtimes but only these versions are officially supported. + +Conversion to pyproject.toml +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +With this release we are moving away from the install and configuration process based on +setup.py and towards the use of pyproject.toml. As a result of this change (and in keeping +with the typical use of pyproject.toml) configuration of a driver build is now declarative. All +build options should be specified in pyproject.toml and overrides via command-line flags or +environment variables are no longer supported. Please consult +`CASSPYTHON-7 `_ for additional details. + +Event Loop Deprecation +^^^^^^^^^^^^^^^^^^^^^^ +With this release the eventlet, gevent and Twisted event loops are considered deprecated. Use +of these event loops in this version will generate a warning to this effect. We are planning on +removing these event loops in their entirety in the next minor release. Please consult +`CASSPYTHON-12 `_ for additional details. + +Removal of Win32 Wheels +^^^^^^^^^^^^^^^^^^^^^^^ +As of this release we will no longer be offering wheels for Win32 platforms. Wheels for other +Windows platforms will continue to be deployed to PyPI. Please consult +`CASSPYTHON-5 `_ for additional details. + +Change to DRIVER_NAME in STARTUP Messages +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The DRIVER_NAME property in STARTUP messages has been changed in this release to allow +administrators to clearly distinguish between uses of the previous DataStax Python drivers +and this driver. Any monitoring/management applications which were monitoring driver usage +based on this string should be aware of this change and update accordingly. Please consult +`CASSPYTHON-17 `_ for additional details. + + +Upgrading from dse-driver +------------------------- +Since 3.21.0, cassandra-driver fully supports DataStax products. dse-driver and +dse-graph users should now migrate to cassandra-driver to benefit from latest bug fixes +and new features. The upgrade to this new unified driver version is straightforward +with no major API changes. + +Installation +^^^^^^^^^^^^ + +Only the `cassandra-driver` package should be installed. `dse-driver` and `dse-graph` +are not required anymore:: + + pip install cassandra-driver + +If you need the Graph *Fluent* API (features provided by dse-graph):: + + pip install cassandra-driver[graph] + +See :doc:`installation` for more details. + +Import from the cassandra module +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +There is no `dse` module, so you should import from the `cassandra` module. You +need to change only the first module of your import statements, not the submodules. + +.. code-block:: python + + from dse.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from dse.auth import PlainTextAuthProvider + from dse.policies import WhiteListRoundRobinPolicy + + # becomes + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.auth import PlainTextAuthProvider + from cassandra.policies import WhiteListRoundRobinPolicy + +Also note that the cassandra.hosts module doesn't exist in cassandra-driver. This +module is named cassandra.pool. + +dse-graph +^^^^^^^^^ + +dse-graph features are now built-in in cassandra-driver. The only change you need +to do is your import statements: + +.. code-block:: python + + from dse_graph import .. + from dse_graph.query import .. + + # becomes + + from cassandra.datastax.graph.fluent import .. + from cassandra.datastax.graph.fluent.query import .. + +See :mod:`~.datastax.graph.fluent`. + +Session.execute and Session.execute_async API +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Although it is not common to use this API with positional arguments, it is +important to be aware that the `host` and `execute_as` parameters have had +their positional order swapped. This is only because `execute_as` was added +in dse-driver before `host`. + +See :meth:`.Session.execute`. + +Deprecations +^^^^^^^^^^^^ + +These changes are optional, but recommended: + +* Importing from `cassandra.graph` is deprecated. Consider importing from `cassandra.datastax.graph`. +* Use :class:`~.policies.DefaultLoadBalancingPolicy` instead of DSELoadBalancingPolicy. + +Upgrading to 3.0 +---------------- +Version 3.0 of the DataStax Python driver for Apache Cassandra +adds support for Cassandra 3.0 while maintaining support for +previously supported versions. In addition to substantial internal rework, +there are several updates to the API that integrators will need +to consider: + +Default consistency is now ``LOCAL_ONE`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Previous value was ``ONE``. The new value is introduced to mesh with the default +DC-aware load balancing policy and to match other drivers. + +Execution API Updates +^^^^^^^^^^^^^^^^^^^^^ +Result return normalization +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +`PYTHON-368 `_ + +Previously results would be returned as a ``list`` of rows for result rows +up to ``fetch_size``, and ``PagedResult`` afterward. This could break +application code that assumed one type and got another. + +Now, all results are returned as an iterable :class:`~.ResultSet`. + +The preferred way to consume results of unknown size is to iterate through +them, letting automatic paging occur as they are consumed. + +.. code-block:: python + + results = session.execute("SELECT * FROM system.local") + for row in results: + process(row) + +If the expected size of the results is known, it is still possible to +materialize a list using the iterator: + +.. code-block:: python + + results = session.execute("SELECT * FROM system.local") + row_list = list(results) + +For backward compatibility, :class:`~.ResultSet` supports indexing. When +accessed at an index, a `~.ResultSet` object will materialize all its pages: + +.. code-block:: python + + results = session.execute("SELECT * FROM system.local") + first_result = results[0] # materializes results, fetching all pages + +This can send requests and load (possibly large) results into memory, so +`~.ResultSet` will log a warning on implicit materialization. + +Trace information is not attached to executed Statements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +`PYTHON-318 `_ + +Previously trace data was attached to Statements if tracing was enabled. This +could lead to confusion if the same statement was used for multiple executions. + +Now, trace data is associated with the ``ResponseFuture`` and ``ResultSet`` +returned for each query: + +:meth:`.ResponseFuture.get_query_trace()` + +:meth:`.ResponseFuture.get_all_query_traces()` + +:meth:`.ResultSet.get_query_trace()` + +:meth:`.ResultSet.get_all_query_traces()` + +Binding named parameters now ignores extra names +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +`PYTHON-178 `_ + +Previously, :meth:`.BoundStatement.bind()` would raise if a mapping +was passed with extra names not found in the prepared statement. + +Behavior in 3.0+ is to ignore extra names. + +blist removed as soft dependency +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +`PYTHON-385 `_ + +Previously the driver had a soft dependency on ``blist sortedset``, using +that where available and using an internal fallback where possible. + +Now, the driver never chooses the ``blist`` variant, instead returning the +internal :class:`.util.SortedSet` for all ``set`` results. The class implements +all standard set operations, so no integration code should need to change unless +it explicitly checks for ``sortedset`` type. + +Metadata API Updates +^^^^^^^^^^^^^^^^^^^^ +`PYTHON-276 `_, `PYTHON-408 `_, `PYTHON-400 `_, `PYTHON-422 `_ + +Cassandra 3.0 brought a substantial overhaul to the internal schema metadata representation. +This version of the driver supports that metadata in addition to the legacy version. Doing so +also brought some changes to the metadata model. + +The present API is documented: :any:`cassandra.metadata`. Changes highlighted below: + +* All types are now exposed as CQL types instead of types derived from the internal server implementation +* Some metadata attributes have changed names to match current nomenclature (for example, :attr:`.Index.kind` in place of ``Index.type``). +* Some metadata attributes removed + + * ``TableMetadata.keyspace`` reference replaced with :attr:`.TableMetadata.keyspace_name` + * ``ColumnMetadata.index`` is removed table- and keyspace-level mappings are still maintained + +Several deprecated features are removed +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +`PYTHON-292 `_ + +* ``ResponseFuture.result`` timeout parameter is removed, use ``Session.execute`` timeout instead (`031ebb0 `_) +* ``Cluster.refresh_schema`` removed, use ``Cluster.refresh_*_metadata`` instead (`419fcdf `_) +* ``Cluster.submit_schema_refresh`` removed (`574266d `_) +* ``cqltypes`` time/date functions removed, use ``util`` entry points instead (`bb984ee `_) +* ``decoder`` module removed (`e16a073 `_) +* ``TableMetadata.keyspace`` attribute replaced with ``keyspace_name`` (`cc94073 `_) +* ``cqlengine.columns.TimeUUID.from_datetime`` removed, use ``util`` variant instead (`96489cc `_) +* ``cqlengine.columns.Float(double_precision)`` parameter removed, use ``columns.Double`` instead (`a2d3a98 `_) +* ``cqlengine`` keyspace management functions are removed in favor of the strategy-specific entry points (`4bd5909 `_) +* ``cqlengine.Model.__polymorphic_*__`` attributes removed, use ``__discriminator*`` attributes instead (`9d98c8e `_) +* ``cqlengine.statements`` will no longer warn about list list prepend behavior (`79efe97 `_) + + +Upgrading to 2.1 from 2.0 +------------------------- +Version 2.1 of the DataStax Python driver for Apache Cassandra +adds support for Cassandra 2.1 and version 3 of the native protocol. + +Cassandra 1.2, 2.0, and 2.1 are all supported. However, 1.2 only +supports protocol version 1, and 2.0 only supports versions 1 and +2, so some features may not be available. + +Using the v3 Native Protocol +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +By default, the driver will attempt to use version 2 of the +native protocol. To use version 3, you must explicitly +set the :attr:`~.Cluster.protocol_version`: + +.. code-block:: python + + from cassandra.cluster import Cluster + + cluster = Cluster(protocol_version=3) + +Note that protocol version 3 is only supported by Cassandra 2.1+. + +In future releases, the driver may default to using protocol version +3. + +Working with User-Defined Types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Cassandra 2.1 introduced the ability to define new types:: + + USE KEYSPACE mykeyspace; + + CREATE TYPE address (street text, city text, zip int); + +The driver generally expects you to use instances of a specific +class to represent column values of this type. You can let the +driver know what class to use with :meth:`.Cluster.register_user_type`: + +.. code-block:: python + + cluster = Cluster() + + class Address(object): + + def __init__(self, street, city, zipcode): + self.street = street + self.city = text + self.zipcode = zipcode + + cluster.register_user_type('mykeyspace', 'address', Address) + +When inserting data for ``address`` columns, you should pass in +instances of ``Address``. When querying data, ``address`` column +values will be instances of ``Address``. + +If no class is registered for a user-defined type, query results +will use a ``namedtuple`` class and data may only be inserted +though prepared statements. + +See :ref:`udts` for more details. + +Customizing Encoders for Non-prepared Statements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Starting with version 2.1 of the driver, it is possible to customize +how Python types are converted to CQL literals when working with +non-prepared statements. This is done on a per-:class:`~.Session` +basis through :attr:`.Session.encoder`: + +.. code-block:: python + + cluster = Cluster() + session = cluster.connect() + session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple + +See :ref:`type-conversions` for the table of default CQL literal conversions. + +Using Client-Side Protocol-Level Timestamps +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +With version 3 of the native protocol, timestamps may be supplied by the +client at the protocol level. (Normally, if they are not specified within +the CQL query itself, a timestamp is generated server-side.) + +When :attr:`~.Cluster.protocol_version` is set to 3 or higher, the driver +will automatically use client-side timestamps with microsecond precision +unless :attr:`.Session.use_client_timestamp` is changed to :const:`False`. +If a timestamp is specified within the CQL query, it will override the +timestamp generated by the driver. + +Upgrading to 2.0 from 1.x +------------------------- +Version 2.0 of the DataStax Python driver for Apache Cassandra +includes some notable improvements over version 1.x. This version +of the driver supports Cassandra 1.2, 2.0, and 2.1. However, not +all features may be used with Cassandra 1.2, and some new features +in 2.1 are not yet supported. + +Using the v2 Native Protocol +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +By default, the driver will attempt to use version 2 of Cassandra's +native protocol. You can explicitly set the protocol version to +2, though: + +.. code-block:: python + + from cassandra.cluster import Cluster + + cluster = Cluster(protocol_version=2) + +When working with Cassandra 1.2, you will need to +explicitly set the :attr:`~.Cluster.protocol_version` to 1: + +.. code-block:: python + + from cassandra.cluster import Cluster + + cluster = Cluster(protocol_version=1) + +Automatic Query Paging +^^^^^^^^^^^^^^^^^^^^^^ +Version 2 of the native protocol adds support for automatic query +paging, which can make dealing with large result sets much simpler. + +See :ref:`query-paging` for full details. + +Protocol-Level Batch Statements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +With version 1 of the native protocol, batching of statements required +using a `BATCH cql query `_. +With version 2 of the native protocol, you can now batch statements at +the protocol level. This allows you to use many different prepared +statements within a single batch. + +See :class:`~.query.BatchStatement` for details and usage examples. + +SASL-based Authentication +^^^^^^^^^^^^^^^^^^^^^^^^^ +Also new in version 2 of the native protocol is SASL-based authentication. +See the section on :ref:`security` for details and examples. + +Lightweight Transactions +^^^^^^^^^^^^^^^^^^^^^^^^ +`Lightweight transactions `_ are another new feature. To use lightweight transactions, add ``IF`` clauses +to your CQL queries and set the :attr:`~.Statement.serial_consistency_level` +on your statements. + +Calling Cluster.shutdown() +^^^^^^^^^^^^^^^^^^^^^^^^^^ +In order to fix some issues around garbage collection and unclean interpreter +shutdowns, version 2.0 of the driver requires you to call :meth:`.Cluster.shutdown()` +on your :class:`~.Cluster` objects when you are through with them. +This helps to guarantee a clean shutdown. + +Deprecations +^^^^^^^^^^^^ +The following functions have moved from ``cassandra.decoder`` to ``cassandra.query``. +The original functions have been left in place with a :exc:`DeprecationWarning` for +now: + +* :attr:`cassandra.decoder.tuple_factory` has moved to + :attr:`cassandra.query.tuple_factory` +* :attr:`cassandra.decoder.named_tuple_factory` has moved to + :attr:`cassandra.query.named_tuple_factory` +* :attr:`cassandra.decoder.dict_factory` has moved to + :attr:`cassandra.query.dict_factory` +* :attr:`cassandra.decoder.ordered_dict_factory` has moved to + :attr:`cassandra.query.ordered_dict_factory` + +Dependency Changes +^^^^^^^^^^^^^^^^^^ +The following dependencies have officially been made optional: + +* ``scales`` +* ``blist`` diff --git a/docs/user_defined_types.rst b/docs/user_defined_types.rst new file mode 100644 index 0000000000..32c03e37e8 --- /dev/null +++ b/docs/user_defined_types.rst @@ -0,0 +1,118 @@ +.. _udts: + +User Defined Types +================== +Cassandra 2.1 introduced user-defined types (UDTs). You can create a +new type through ``CREATE TYPE`` statements in CQL:: + + CREATE TYPE address (street text, zip int); + +Version 2.1 of the Python driver adds support for user-defined types. + +Registering a UDT +----------------- +You can tell the Python driver to return columns of a specific UDT as +instances of a class or a dict by registering them with your :class:`~.Cluster` +instance through :meth:`.Cluster.register_user_type`: + + +Map a Class to a UDT +++++++++++++++++++++ + +.. code-block:: python + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location frozen
)") + + # create a class to map to the "address" UDT + class Address(object): + + def __init__(self, street, zipcode): + self.street = street + self.zipcode = zipcode + + cluster.register_user_type('mykeyspace', 'address', Address) + + # insert a row using an instance of Address + session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", + (0, Address("123 Main St.", 78723))) + + # results will include Address instances + results = session.execute("SELECT * FROM users") + row = results[0] + print(row.id, row.location.street, row.location.zipcode) + +Map a dict to a UDT ++++++++++++++++++++ + +.. code-block:: python + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location frozen
)") + + cluster.register_user_type('mykeyspace', 'address', dict) + + # insert a row using a prepared statement and a tuple + insert_statement = session.prepare("INSERT INTO mykeyspace.users (id, location) VALUES (?, ?)") + session.execute(insert_statement, [0, ("123 Main St.", 78723)]) + + # results will include dict instances + results = session.execute("SELECT * FROM users") + row = results[0] + print(row.id, row.location['street'], row.location['zipcode']) + +Using UDTs Without Registering Them +----------------------------------- +Although it is recommended to register your types with +:meth:`.Cluster.register_user_type`, the driver gives you some options +for working with unregistered UDTS. + +When you use prepared statements, the driver knows what data types to +expect for each placeholder. This allows you to pass any object you +want for a UDT, as long as it has attributes that match the field names +for the UDT: + +.. code-block:: python + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location frozen
)") + + class Foo(object): + + def __init__(self, street, zipcode, otherstuff): + self.street = street + self.zipcode = zipcode + self.otherstuff = otherstuff + + insert_statement = session.prepare("INSERT INTO users (id, location) VALUES (?, ?)") + + # since we're using a prepared statement, we don't *have* to register + # a class to map to the UDT to insert data. The object just needs to have + # "street" and "zipcode" attributes (which Foo does): + session.execute(insert_statement, [0, Foo("123 Main St.", 78723, "some other stuff")]) + + # when we query data, UDT columns that don't have a class registered + # will be returned as namedtuples: + results = session.execute("SELECT * FROM users") + first_row = results[0] + address = first_row.location + print(address) # prints "Address(street='123 Main St.', zipcode=78723)" + street = address.street + zipcode = address.street + +As shown in the code example, inserting data for UDT columns without registering +a class works fine for prepared statements. However, **you must register a +class to insert UDT columns with unprepared statements**.\* You can still query +UDT columns without registered classes using unprepared statements, they will +simply return ``namedtuple`` instances (just like prepared statements do). + +\* this applies to *parameterized* unprepared statements, in which the driver will be formatting parameters -- not statements with interpolated UDT literals. diff --git a/doxyfile b/doxyfile new file mode 100644 index 0000000000..d453557e22 --- /dev/null +++ b/doxyfile @@ -0,0 +1,2339 @@ +# Doxyfile 1.8.8 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the config file +# that follow. The default is UTF-8 which is also the encoding used for all text +# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv +# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv +# for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = "Python Driver" + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify an logo or icon that is included in +# the documentation. The maximum height of the logo should not exceed 55 pixels +# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo +# to the output directory. + +PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = + +# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub- +# directories (in 2 levels) under the output directory of each output format and +# will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, +# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), +# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, +# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), +# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, +# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, +# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, +# Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = NO + +# If the REPEAT_BRIEF tag is set to YES doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + +ABBREVIATE_BRIEF = + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = NO + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES, then doxygen will produce a +# new page for each member. If set to NO, the documentation of a member will be +# part of the file/class/namespace that contains it. +# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:\n" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". You can put \n's in the value part of an alias to insert +# newlines. + +ALIASES = "test_assumptions=\par Test Assumptions\n" \ + "note=\par Note\n" \ + "test_category=\par Test Category\n" \ + "jira_ticket=\par JIRA Ticket\n" \ + "expected_result=\par Expected Result\n" \ + "since=\par Since\n" \ + "param=\par Parameters\n" \ + "return=\par Return\n" \ + "expected_errors=\par Expected Errors\n" + +# This tag can be used to specify a number of word-keyword mappings (TCL only). +# A mapping has the form "name=value". For example adding "class=itcl::class" +# will allow you to use the command class in the itcl::class meaning. + +TCL_SUBST = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = YES + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, Javascript, +# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: +# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: +# Fortran. In the later case the parser tries to guess whether the code is fixed +# or free formatted code, this is the default for Fortran type files), VHDL. For +# instance to make doxygen treat .inc files as Fortran files (default is PHP), +# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# +# Note For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. + +EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See http://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by by putting a % sign in front of the word +# or globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = NO + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES, then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + +DISTRIBUTE_GROUP_DOC = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PACKAGE tag is set to YES all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. When set to YES local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO these classes will be included in the various overviews. This option has +# no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# (class|struct|union) declarations. If set to NO these declarations will be +# included in the documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file +# names in lower-case letters. If set to YES upper-case letters are also +# allowed. This is useful if you have classes or files whose names only differ +# in case and if your file system supports case sensitive file names. Windows +# and Mac users are advised to set this option to NO. +# The default value is: system dependent. + +CASE_SENSE_NAMES = YES + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable ( YES) or disable ( NO) the +# todo list. This list is created by putting \todo commands in the +# documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable ( YES) or disable ( NO) the +# test list. This list is created by putting \test commands in the +# documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable ( YES) or disable ( NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable ( YES) or disable ( NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES the list +# will mention the files that were used to generate the documentation. +# The default value is: YES. + +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + +SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error ( stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES, then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as not documenting some parameters +# in a documented function, or documenting parameters that don't exist or using +# markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO doxygen will only warn about wrong or incomplete parameter +# documentation, but not about the absence of documentation. +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. +# Note: If this tag is empty the current directory is searched. + +INPUT = ./tests + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: http://www.gnu.org/software/libiconv) for the list of +# possible encodings. +# The default value is: UTF-8. + +INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank the +# following patterns are tested:*.c, *.cc, *.cxx, *.cpp, *.c++, *.java, *.ii, +# *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp, +# *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown, +# *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf, +# *.qsf, *.as and *.js. + +FILE_PATTERNS = *.py + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = YES + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + +EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# AClass::ANamespace, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = @Test + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + +EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + +IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. + +INPUT_FILTER = "python /usr/local/bin/doxypy.py" + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. + +FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER ) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = YES + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# function all documented functions referencing it will be listed. +# The default value is: NO. + +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES, then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see http://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + +VERBATIM_HEADERS = YES + +#--------------------------------------------------------------------------- +# Configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + +ALPHABETICAL_INDEX = YES + +# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in +# which the alphabetical index list will be split. +# Minimum value: 1, maximum value: 20, default value: 5. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +COLS_IN_ALPHA_INDEX = 5 + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefor more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra stylesheet files is of importance (e.g. the last +# stylesheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the stylesheet and background images according to +# this color. Hue is specified as an angle on a colorwheel, see +# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use grayscales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to NO can help when comparing the output of multiple runs. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: http://developer.apple.com/tools/xcode/), introduced with +# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a +# Makefile in the HTML output directory. Running make will produce the docset in +# that directory and running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html +# for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on +# Windows. +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler ( hhc.exe). If non-empty +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated ( +# YES) or that it should be included in the master .chm file ( NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index ( hhk), content ( hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated ( +# YES) or a normal table of contents ( NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- +# folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location of Qt's +# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the +# generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + +ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom stylesheets (see HTML_EXTRA_STYLESHEET) one can +# further fine-tune the look of the index. As an example, the default style +# sheet generated by doxygen has an example that shows how to put an image at +# the root of the tree instead of the PROJECT_NAME. Since the tree basically has +# the same information as the tab index, you could consider setting +# DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_TREEVIEW = YES + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + +ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + +TREEVIEW_WIDTH = 250 + +# When the EXT_LINKS_IN_WINDOW option is set to YES doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +EXT_LINKS_IN_WINDOW = NO + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_TRANSPARENT = YES + +# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see +# http://www.mathjax.org) which uses client side Javascript for the rendering +# instead of using prerendered bitmaps. Use this if you do not have LaTeX +# installed or if you want to formulas look prettier in the HTML output. When +# enabled you may also need to install MathJax separately and configure the path +# to it using the MATHJAX_RELPATH option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +USE_MATHJAX = NO + +# When MathJax is enabled you can set the default output format to be used for +# the MathJax output. See the MathJax site (see: +# http://docs.mathjax.org/en/latest/output.html) for more details. +# Possible values are: HTML-CSS (which is slower, but has the best +# compatibility), NativeMML (i.e. MathML) and SVG. +# The default value is: HTML-CSS. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_FORMAT = HTML-CSS + +# When MathJax is enabled you need to specify the location relative to the HTML +# output directory using the MATHJAX_RELPATH option. The destination directory +# should contain the MathJax.js script. For instance, if the mathjax directory +# is located at the same level as the HTML output directory, then +# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax +# Content Delivery Network so you can quickly see the result without installing +# MathJax. However, it is strongly recommended to install a local copy of +# MathJax from http://www.mathjax.org before deployment. +# The default value is: http://cdn.mathjax.org/mathjax/latest. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest + +# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax +# extension names that should be enabled during MathJax rendering. For example +# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_EXTENSIONS = + +# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces +# of code that will be used on startup of the MathJax code. See the MathJax site +# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# example see the documentation. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_CODEFILE = + +# When the SEARCHENGINE tag is enabled doxygen will generate a search box for +# the HTML output. The underlying search engine uses javascript and DHTML and +# should work on any modern browser. Note that when using HTML help +# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) +# there is already a search function so this one should typically be disabled. +# For large projects the javascript based search engine can be slow, then +# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to +# search using the keyboard; to jump to the search box use + S +# (what the is depends on the OS and browser, but it is typically +# , /