diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 3a3d93171a..a1a6c854c7 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -24,11 +24,11 @@ jobs: permissions: id-token: write steps: - - uses: actions/download-artifact@v8 + - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: path: dist merge-multiple: true - - uses: pypa/gh-action-pypi-publish@release/v1 + - uses: pypa/gh-action-pypi-publish@cef2210092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0 with: skip-existing: true diff --git a/.github/workflows/call_jira_sync.yml b/.github/workflows/call_jira_sync.yml index 14f517df40..0855246f48 100644 --- a/.github/workflows/call_jira_sync.yml +++ b/.github/workflows/call_jira_sync.yml @@ -11,7 +11,7 @@ permissions: jobs: jira-sync: - uses: scylladb/github-automation/.github/workflows/main_pr_events_jira_sync.yml@main + uses: scylladb/github-automation/.github/workflows/main_pr_events_jira_sync.yml@83115dc2553dbf968e73271e97fc7aac16b8145a # main 2026-05-20 with: caller_action: ${{ github.event.action }} secrets: diff --git a/.github/workflows/docs-pages.yml b/.github/workflows/docs-pages.yml index 9d14b9c4d8..a413e3317e 100644 --- a/.github/workflows/docs-pages.yml +++ b/.github/workflows/docs-pages.yml @@ -24,14 +24,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.repository.default_branch }} persist-credentials: false fetch-depth: 0 - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: working-directory: docs enable-cache: true diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml index f0aa64d628..1881c227ed 100644 --- a/.github/workflows/docs-pr.yml +++ b/.github/workflows/docs-pr.yml @@ -31,13 +31,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false fetch-depth: 0 - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: working-directory: docs enable-cache: true diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index fde1ab3e1d..5e76d6bbb4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -38,7 +38,7 @@ jobs: if: "!contains(github.event.pull_request.labels.*.name, 'disable-integration-tests')" runs-on: ubuntu-24.04 env: - SCYLLA_VERSION: release:2025.2 + SCYLLA_VERSION: release:2026.1 strategy: fail-fast: false matrix: @@ -56,10 +56,10 @@ jobs: event_loop_manager: "asyncore" steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up JDK ${{ matrix.java-version }} - uses: actions/setup-java@v5 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5.2.0 with: java-version: ${{ matrix.java-version }} distribution: 'adopt' @@ -68,7 +68,7 @@ jobs: run: sudo apt-get install libev4 libev-dev - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: python-version: ${{ matrix.python-version }} @@ -78,7 +78,7 @@ jobs: run: uv sync - name: Cache Scylla download - uses: actions/cache@v5 + uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: path: ~/.ccm/repository key: scylla-${{ env.SCYLLA_VERSION }}-${{ runner.os }} diff --git a/.github/workflows/lib-build.yml b/.github/workflows/lib-build.yml index 21dcc0604f..f6959ddfec 100644 --- a/.github/workflows/lib-build.yml +++ b/.github/workflows/lib-build.yml @@ -77,11 +77,11 @@ jobs: include: ${{ fromJson(needs.prepare-matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout tag ${{ inputs.target_tag }} if: inputs.target_tag != '' - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ inputs.target_tag }} @@ -96,7 +96,7 @@ jobs: echo "CIBW_BEFORE_TEST_WINDOWS=(exit 0)" >> $GITHUB_ENV; - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: python-version: ${{ inputs.python-version }} @@ -111,7 +111,7 @@ jobs: - name: Install Conan if: runner.os == 'Windows' - uses: turtlebrowser/get-conan@main + uses: turtlebrowser/get-conan@c171f295f3f507360ee018736a6608731aa2109d # v1.2 - name: Configure libev for Windows if: runner.os == 'Windows' @@ -147,7 +147,7 @@ jobs: run: | CIBW_BUILD="cp3*" cibuildwheel --archs aarch64 --output-dir wheelhouse - - uses: actions/upload-artifact@v7 + - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: wheels-${{ matrix.target }}-${{ matrix.os }} path: ./wheelhouse/*.whl @@ -156,17 +156,17 @@ jobs: name: Build source distribution runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: python-version: ${{ inputs.python-version }} - name: Build sdist run: uv build --sdist - - uses: actions/upload-artifact@v7 + - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: source-dist path: dist/*.tar.gz diff --git a/.github/workflows/publish-manually.yml b/.github/workflows/publish-manually.yml index 2f15c6ecda..5b9298fb7f 100644 --- a/.github/workflows/publish-manually.yml +++ b/.github/workflows/publish-manually.yml @@ -58,11 +58,11 @@ jobs: permissions: id-token: write steps: - - uses: actions/download-artifact@v8 + - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: path: dist merge-multiple: true - - uses: pypa/gh-action-pypi-publish@release/v1 + - uses: pypa/gh-action-pypi-publish@cef2210092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0 with: skip-existing: true diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3ae00a7ee8..39a8aca069 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,28 @@ +3.29.10 +======= +May 10, 2026 + +Features +-------- +* Fast-path ``lookup_casstype()`` for simple type names +* Add ``Session.wait_for_schema_agreement`` + +Bug Fixes +--------- +* Fix CQL injection in ``Connection.set_keyspace_blocking`` and ``Connection.set_keyspace_async`` +* Fix libev shutdown crashes by correcting atexit registration +* Handle ``None`` ``control_connection_timeout`` in ``wait_for_schema_agreement`` +* Clean up failed heartbeat sends +* Fix ``ExponentialBackoffRetryPolicy.__init__`` super() call +* Correct ``clustering_key`` to ``clustering`` in column kind filter +* Fix inverted cooldown check in ``_get_shard_aware_endpoint`` + +Others +------ +* Deprecate ``ControlConnection.wait_for_schema_agreement`` +* Add timeout and in-flight observability to ``OperationTimedOut`` +* Drop per-query connection log + 3.29.9 ====== March 18, 2026 diff --git a/benchmarks/base.py b/benchmarks/base.py index d9cd004474..3922eefad5 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -97,7 +97,7 @@ def setup(options): try: session.execute(""" CREATE KEYSPACE %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' } """ % options.keyspace) log.debug("Setting keyspace...") diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 46de7daaf0..1286f20e9b 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -23,7 +23,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 29, 9) +__version_info__ = (3, 29, 10) __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..1181c6f686 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -20,16 +20,18 @@ import atexit import datetime +from enum import Enum from binascii import hexlify from collections import defaultdict from collections.abc import Mapping -from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures +from concurrent.futures import Future, ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy from functools import partial, reduce, wraps from itertools import groupby, count, chain +import enum import json import logging -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Tuple from warnings import warn from random import random import re @@ -214,6 +216,14 @@ def __init__(self, message, errors): self.errors = errors +class SchemaAgreementScope(str, Enum): + """Scope selectors for :meth:`.Session.wait_for_schema_agreement`.""" + + RACK = 'rack' + DC = 'dc' + CLUSTER = 'cluster' + + def _future_completed(future): """ Helper for run_in_executor() """ exc = future.exception() @@ -505,8 +515,9 @@ def __init__(self, load_balancing_policy=None, retry_policy=None, class ProfileManager(object): - def __init__(self): + def __init__(self, pools_allowed: bool=True): self.profiles = dict() + self.pools_allowed = pools_allowed def _profiles_without_explicit_lbps(self): names = (profile_name for @@ -518,6 +529,8 @@ def _profiles_without_explicit_lbps(self): ) def distance(self, host): + if not self.pools_allowed: + return HostDistance.IGNORED distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \ HostDistance.LOCAL if HostDistance.LOCAL in distances else \ @@ -533,10 +546,14 @@ def check_supported(self): p.load_balancing_policy.check_supported() def on_up(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_down(host) @@ -610,6 +627,31 @@ class _ConfigMode(object): PROFILES = 2 +class ControlConnectionQueryFallback(enum.Enum): + """ + Controls how application queries use the control connection when node pools + are unavailable. + + ``Disabled`` requires a usable node pool for application queries. If the + driver cannot establish one during session startup, it raises + :class:`NoHostAvailable`. + + ``Fallback`` still attempts to create node pools, but allows application + queries to fall back to the control connection when no usable node pool is + available. Session startup is allowed to proceed even if the initial pool + attempts all fail. + + ``SkipPoolCreation`` disables node-pool creation for the session and uses + the control-connection fallback path for application queries. + + The fallback path is not used for requests targeted to an explicit host. + """ + + Disabled = "Disabled" + Fallback = "Fallback" + SkipPoolCreation = "SkipPoolCreation" + + class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. @@ -930,6 +972,16 @@ def default_retry_policy(self, policy): If set to :const:`None`, there will be no timeout for these queries. """ + allow_control_connection_query_fallback: ControlConnectionQueryFallback = ControlConnectionQueryFallback.Disabled + """ + Controls whether application queries may fall back to the control connection. + + ``Disabled`` keeps the old behavior. + ``Fallback`` enables control-connection fallback when no usable node pools exist. + ``SkipPoolCreation`` skips node-pool creation and uses the control connection fallback path. + This fallback is still not used for requests targeted to an explicit host. + """ + idle_heartbeat_interval = 30 """ Interval, in seconds, on which to heartbeat idle connections. This helps @@ -1216,7 +1268,8 @@ def __init__(self, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, application_info:Optional[ApplicationInfoBase]=None, - client_routes_config:Optional[ClientRoutesConfig]=None + client_routes_config:Optional[ClientRoutesConfig]=None, + allow_control_connection_query_fallback:Optional[ControlConnectionQueryFallback]=ControlConnectionQueryFallback.Disabled ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1234,6 +1287,10 @@ def __init__(self, if port < 1 or port > 65535: raise ValueError("Invalid port number (%s) (1-65535)" % port) + if not isinstance(allow_control_connection_query_fallback, ControlConnectionQueryFallback): + raise TypeError( + "allow_control_connection_query_fallback must be a ControlConnectionQueryFallback value") + if connection_class is not None: self.connection_class = connection_class @@ -1395,7 +1452,8 @@ def __init__(self, else: self.timestamp_generator = MonotonicTimestampGenerator() - self.profile_manager = ProfileManager() + self.profile_manager = ProfileManager( + pools_allowed=allow_control_connection_query_fallback != ControlConnectionQueryFallback.SkipPoolCreation) self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( self.load_balancing_policy, self.default_retry_policy, @@ -1464,6 +1522,7 @@ def __init__(self, self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout + self.allow_control_connection_query_fallback = allow_control_connection_query_fallback self.metadata_request_timeout = self.control_connection_timeout if metadata_request_timeout is None else metadata_request_timeout self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout @@ -1806,7 +1865,8 @@ def get_all_pools(self): return pools def is_shard_aware(self): - return bool(self.get_all_pools()[0].host.sharding_info) + pools = self.get_all_pools() + return bool(pools and pools[0].host.sharding_info) def shard_aware_stats(self): if self.is_shard_aware(): @@ -1911,7 +1971,7 @@ def on_up(self, host): """ Intended for internal use only. """ - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return log.debug("Waiting to acquire lock for handling up status of node %s", host) @@ -2019,7 +2079,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. """ - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return with host.lock: @@ -2624,20 +2684,24 @@ def __init__(self, cluster, hosts, keyspace=None): # create connection pools in parallel self._initial_connect_futures = set() - for host in hosts: - future = self.add_or_renew_pool(host, is_host_addition=False) - if future: - self._initial_connect_futures.add(future) - - futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) - while futures.not_done and not any(f.result() for f in futures.done): - futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) - - if not any(f.result() for f in self._initial_connect_futures): - msg = "Unable to connect to any servers" - if self.keyspace: - msg += " using keyspace '%s'" % self.keyspace - raise NoHostAvailable(msg, [h.address for h in hosts]) + fallback_mode = self.cluster.allow_control_connection_query_fallback + if fallback_mode is not ControlConnectionQueryFallback.SkipPoolCreation: + for host in hosts: + future = self.add_or_renew_pool(host, is_host_addition=False) + if future: + self._initial_connect_futures.add(future) + + futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + while futures.not_done and not any(f.result() for f in futures.done): + futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) + + # Only Disabled requires an initial pool to come up. + if not any(f.result() for f in self._initial_connect_futures) and \ + fallback_mode is ControlConnectionQueryFallback.Disabled: + msg = "Unable to connect to any servers" + if self.keyspace: + msg += " using keyspace '%s'" % self.keyspace + raise NoHostAvailable(msg, [h.address for h in hosts]) self.session_id = uuid.uuid4() @@ -3236,6 +3300,9 @@ def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return None + distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None @@ -3306,6 +3373,9 @@ def update_created_pools(self): For internal use only. """ + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return set() + futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) @@ -3374,6 +3444,185 @@ def pool_finished_setting_keyspace(pool, host_errors): for pool in tuple(self._pools.values()): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) + def wait_for_schema_agreement(self, wait_time: Optional[float] = None, + scope: SchemaAgreementScope = SchemaAgreementScope.CLUSTER) -> bool: + """ + Wait for connected hosts in the selected scope to report the same + schema version from ``system.local``. + + By default, the timeout for this operation is governed by + :attr:`~.Cluster.max_schema_agreement_wait` and + :attr:`~.Cluster.control_connection_timeout`. + + Passing ``wait_time`` here overrides + :attr:`~.Cluster.max_schema_agreement_wait`. If provided, ``wait_time`` + must be greater than 0. + + ``scope`` determines which connected hosts participate in the check. + Pass :attr:`SchemaAgreementScope.RACK`, :attr:`SchemaAgreementScope.DC`, + or :attr:`SchemaAgreementScope.CLUSTER`. + The default is :attr:`SchemaAgreementScope.CLUSTER`. ``RACK`` narrows + the check to connected hosts in the local rack only. ``DC`` checks + connected hosts in the local datacenter. ``CLUSTER`` queries every + connected host across all datacenters. + + :param wait_time: Override for + :attr:`~.Cluster.max_schema_agreement_wait`, should be positive + number. + :param scope: Restricts the check to connected hosts in the local rack, + local datacenter, or whole connected cluster. + :returns: ``True`` when the selected connected hosts agree on schema, + otherwise ``False``. + :raises ValueError: If ``wait_time`` is provided and is not greater + than 0. + :raises ValueError: If ``scope`` is not one of the schema agreement + scope values. + """ + + if wait_time is not None and wait_time <= 0: + raise ValueError("wait_time must be greater than 0") + + total_timeout = wait_time if wait_time is not None else self.cluster.max_schema_agreement_wait + if total_timeout <= 0: + raise ValueError("total_timeout must be greater than 0") + + deadline = time.time() + total_timeout + schema_mismatches = None + scope_label = 'local rack' if scope is SchemaAgreementScope.RACK else ( + 'local datacenter' if scope is SchemaAgreementScope.DC else 'cluster') + + while time.time() < deadline: + schema_mismatches = self._get_schema_mismatches_for_scope(deadline, scope) + if schema_mismatches is None: + return True + + log.debug("[session] Connected hosts in the %s still disagree on schema, trying again", scope_label) + remaining = deadline - time.time() + if remaining > 0: + time.sleep(min(0.2, remaining)) + + log.warning("[session] Connected hosts in the %s are reporting a schema disagreement: %s", + scope_label, schema_mismatches) + return False + + def _get_schema_mismatches_for_scope(self, deadline: float, + scope: SchemaAgreementScope) -> Optional[Dict[Any, Any]]: + hosts = self._get_schema_agreement_hosts(scope) + mismatches = defaultdict(list) + errors = {} + scope_label = 'local rack' if scope is SchemaAgreementScope.RACK else ( + 'local datacenter' if scope is SchemaAgreementScope.DC else 'cluster') + + if not hosts: + errors[scope.value] = ConnectionException( + "No connected hosts available in the %s" % (scope_label,) + ) + return {'unavailable': errors} + + metadata_request_timeout = self.cluster.control_connection._metadata_request_timeout + query = maybe_add_timeout_to_query(ControlConnection._SELECT_SCHEMA_LOCAL, metadata_request_timeout) + + schema_version_futures = [] + for host in hosts: + try: + schema_version_future = self._query_local_schema_version(host, query, deadline) + except Exception as exc: + errors[host.endpoint] = exc + continue + + schema_version_futures.append((host, schema_version_future)) + + if schema_version_futures: + # Start all host queries first, then wait for the whole batch. + remaining = max(0.0, deadline - time.time()) + if remaining > 0: + wait_futures([future for _, future in schema_version_futures], timeout=remaining) + + for host, future in schema_version_futures: + if future.done(): + try: + rows = future.result() + except Exception as exc: + errors[host.endpoint] = exc + continue + + row = rows.one() + schema_version = getattr(row, "schema_version", None) if row is not None else None + mismatches[schema_version].append(host.endpoint) + else: + errors[host.endpoint] = OperationTimedOut(last_host=host, timeout=max(0.0, deadline - time.time())) + + if len(mismatches) == 1 and None not in mismatches and not errors: + log.debug("[session] Connected hosts in the %s agree on schema", scope_label) + return None + + if errors: + mismatches['unavailable'] = errors + return dict(mismatches) + + def _get_schema_agreement_hosts(self, scope: SchemaAgreementScope) -> Tuple[Host, ...]: + if scope is SchemaAgreementScope.RACK: + allowed_distances = (HostDistance.LOCAL_RACK,) + elif scope is SchemaAgreementScope.DC: + allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL) + else: + allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE) + + return tuple( + host for host, pool in tuple(self._pools.items()) + if host.is_up + and not pool.is_shutdown + and self._profile_manager.distance(host) in allowed_distances) + + def _query_local_schema_version(self, host: Host, query: str, deadline: float) -> Future: + remaining = max(0.0, deadline - time.time()) + try: + response_future = self.execute_async( + query, + timeout=self._schema_agreement_query_timeout(remaining), + host=host, + ) + except OperationTimedOut as timeout: + log.debug("[session] Timed out waiting for schema version from %s: %s", host, timeout) + raise + except Exception as exc: + log.debug("[session] Error querying schema version from %s: %s", host, exc) + raise + + # execute_async returns cassandra.cluster.ResponseFuture, which does not have bulk waiting logic for it. + # That is why _query_local_schema_version returns concurrent.futures.Future + # so that schema agreement logic could use concurrent.futures.wait_futures to wait on them. + # schema_version_future is an adapter between cassandra.cluster.ResponseFuture and concurrent.futures.Future + # to make things work + schema_version_future = Future() + + def _set_result(result, result_future=schema_version_future, response_future=response_future): + if result_future.done(): + return + try: + result_future.set_result(ResultSet(response_future, result)) + except Exception as exc: + result_future.set_exception(exc) + + def _set_exception(exc, result_future=schema_version_future): + if result_future.done(): + return + result_future.set_exception(exc) + + try: + response_future.add_callbacks(_set_result, _set_exception) + except Exception as exc: + log.debug("[session] Error registering schema version callback from %s: %s", host, exc) + raise + + return schema_version_future + + def _schema_agreement_query_timeout(self, remaining: float) -> float: + control_timeout = self.cluster.control_connection._timeout + if control_timeout is None: + return max(0.0, remaining) + return max(0.0, min(control_timeout, remaining)) + def user_type_registered(self, keyspace, user_type, klass): """ Called by the parent Cluster instance when the user registers a new @@ -3786,7 +4035,7 @@ def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_w if self._cluster.is_shutdown: return False - agreed = self.wait_for_schema_agreement(connection, + agreed = self._wait_for_schema_agreement(connection=connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) @@ -4079,7 +4328,30 @@ def _handle_schema_change(self, event): self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): + """ + Wait for schema agreement from the control connection's metadata view. + + This method is intended for internal metadata refresh flows. External + callers should use :meth:`.Session.wait_for_schema_agreement` instead. + + The control connection observes schema agreement from its own + perspective, which may include hosts the session is not using, and it + may fail when the control connection itself is transiently unhealthy. + That can produce false positives or failures that do not reflect + whether a session can safely proceed. + + .. deprecated:: 3.30.0 + Use :meth:`.Session.wait_for_schema_agreement` instead. + """ + warn("ControlConnection.wait_for_schema_agreement is deprecated and will be removed in 4.0. " + "Use Session.wait_for_schema_agreement instead. " + "This method is for internal metadata refresh use only.", + DeprecationWarning, stacklevel=2) + return self._wait_for_schema_agreement(connection=connection, + preloaded_results=preloaded_results, + wait_time=wait_time) + def _wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait if total_timeout <= 0: return True @@ -4439,6 +4711,7 @@ class ResponseFuture(object): _spec_execution_plan = NoSpeculativeExecutionPlan() _continuous_paging_session = None _host = None + _control_connection_query_attempted = False _TABLET_ROUTING_CTYPE = None _warned_timeout = False @@ -4459,6 +4732,7 @@ def __init__(self, session, message, query, timeout, metrics=None, prepared_stat self._callback_lock = Lock() self._start_time = start_time or time.time() self._host = host + self._control_connection_query_attempted = False self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() @@ -4537,11 +4811,22 @@ def _on_timeout(self, _attempts=0): self._connection.orphaned_threshold_reached = True pool.return_connection(self._connection, stream_was_orphaned=True) + elif self._connection.is_control_connection: + with self._connection.lock: + self._connection.orphaned_request_ids.add(self._req_id) + if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + self._connection.orphaned_threshold_reached = True errors = self._errors if not errors: if self.is_schema_agreed: - key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' + if self._current_host is None: + key = 'no host queried before timeout' + elif self._connection is not None and self._connection.is_control_connection: + control_host = self.session.cluster.get_control_connection_host() + key = str(control_host.endpoint) if control_host is not None else str(self._connection.endpoint) + else: + key = str(self._current_host.endpoint) errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} else: connection = self.session.cluster.control_connection._connection @@ -4599,14 +4884,110 @@ def send_request(self, error_no_hosts=True): self._on_timeout() return True if error_no_hosts: + if self._fallback_to_control_connection(): + req_id = self._query_control_connection() + if req_id is not None: + self._req_id = req_id + return True + self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) return False + def _has_usable_node_pool(self): + try: + pools = tuple(self.session._pools.values()) + except (AttributeError, TypeError): + return False + + return any(pool and not pool.is_shutdown for pool in pools) + + def _fallback_to_control_connection(self): + fallback_mode = self.session.cluster.allow_control_connection_query_fallback + if fallback_mode is ControlConnectionQueryFallback.Disabled: + return False + if self._host or self._control_connection_query_attempted: + return False + if fallback_mode is ControlConnectionQueryFallback.SkipPoolCreation: + return True + return not self._has_usable_node_pool() + + def _borrow_control_connection(self, connection): + with connection.lock: + if connection.in_flight >= connection.max_request_id: + raise NoConnectionsAvailable("All request IDs are currently in use") + connection.in_flight += 1 + return connection.get_request_id() + + def _release_control_connection_request(self, connection, request_id): + with connection.lock: + connection.in_flight -= 1 + connection.request_ids.append(request_id) + connection._requests.pop(request_id, None) + + def _handle_control_connection_response(self, connection, cb, response): + with connection.lock: + connection.in_flight -= 1 + cb(response) + + def _query_control_connection(self, message=None, cb=None, connection=None, host=None): + self._control_connection_query_attempted = True + + if message is None: + message = self.message + + if connection is None: + control_connection = self.session.cluster.control_connection + connection = control_connection._connection if control_connection else None + if not connection: + self._errors['control connection'] = ConnectionException("Control connection is not connected") + return None + + if host is None: + host = self.session.cluster.get_control_connection_host() or connection.endpoint + self._current_host = host + + request_id = None + request_sent = False + try: + request_id = self._borrow_control_connection(connection) + self._connection = connection + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + if cb is None: + cb = partial(self._set_result, host, connection, None) + cb = partial(self._handle_control_connection_response, connection, cb) + + log.debug("No usable node pools; falling back to control connection for host %s", host) + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) + request_sent = True + self.attempted_hosts.append(host) + return request_id + except NoConnectionsAvailable as exc: + log.debug("Control connection is at capacity") + self._errors[host] = exc + except ConnectionBusy as exc: + log.debug("Control connection is busy") + self._errors[host] = exc + except Exception as exc: + log.debug("Error querying control connection", exc_info=True) + self._errors[host] = exc + if self._metrics is not None: + self._metrics.on_connection_error() + finally: + if request_id is not None and not request_sent: + self._release_control_connection_request(connection, request_id) + + return None + def _query(self, host, message=None, cb=None): if message is None: message = self.message + self._control_connection_query_attempted = False + pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -4717,12 +5098,17 @@ def start_fetching_next_page(self): self._event.clear() self._final_result = _NOT_SET self._final_exception = None + self._control_connection_query_attempted = False self._start_timer() self.send_request() def _reprepare(self, prepare_message, host, connection, pool): cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) - request_id = self._query(host, prepare_message, cb=cb) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(prepare_message, cb=cb, + connection=connection, host=host) + else: + request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() @@ -4761,6 +5147,8 @@ def _set_result(self, host, connection, pool, response): if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) + if connection is not None: + connection.keyspace = response.new_keyspace # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event @@ -4937,10 +5325,13 @@ def _execute_after_prepare(self, host, connection, pool, response): new_metadata_id = response.result_metadata_id if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id - + # use self._query to re-use the same host and # at the same time properly borrow the connection - request_id = self._query(host) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(connection=connection, host=host) + else: + request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() @@ -5053,6 +5444,11 @@ def _retry_task(self, reuse_connection, host): # to retry the operation return + if self._control_connection_query_attempted: + self._control_connection_query_attempted = False + self.send_request() + return + if reuse_connection and self._query(host) is not None: return diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..f07160e385 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1816,7 +1816,19 @@ def __init__(self, connection, owner): with connection.lock: if connection.in_flight < connection.max_request_id: connection.in_flight += 1 - connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + request_id = connection.get_request_id() + try: + connection.send_msg(OptionsMessage(), request_id, self._options_callback) + except Exception as exc: + if connection.is_control_connection: + connection.in_flight -= 1 + # send_msg() registers the callback before writing to the socket, + # so a write failure must unwind that registration here. + connection._requests.pop(request_id, None) + if request_id not in connection.request_ids: + connection.request_ids.append(request_id) + self._exception = exc + self._event.set() else: self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") self._event.set() diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index d6dc44119a..684bc50b8a 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -56,7 +56,7 @@ def _get_context(keyspaces, connections): def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None): """ - Creates a keyspace with SimpleStrategy for replica placement + Creates a keyspace with NetworkTopologyStrategy for replica placement If the keyspace already exists, it will not be modified. @@ -66,11 +66,11 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True, connec *There are plans to guard schema-modifying functions with an environment-driven conditional.* :param str name: name of keyspace to create - :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` + :param int replication_factor: keyspace replication factor, used with :attr:`~.NetworkTopologyStrategy` :param bool durable_writes: Write log is bypassed if set to False :param list connections: List of connection names """ - _create_keyspace(name, durable_writes, 'SimpleStrategy', + _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', {'replication_factor': replication_factor}, connections=connections) diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py index 66e1d7295c..452667c8eb 100644 --- a/cassandra/io/asyncioreactor.py +++ b/cassandra/io/asyncioreactor.py @@ -23,8 +23,8 @@ asyncio.run_coroutine_threadsafe except AttributeError: raise ImportError( - 'Cannot use asyncioreactor without access to ' - 'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)' + "Cannot use asyncioreactor without access to " + "asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)" ) @@ -38,12 +38,12 @@ class AsyncioTimer(object): @property def end(self): - raise NotImplementedError('{} is not compatible with TimerManager and ' - 'does not implement .end()') + raise NotImplementedError( + "{} is not compatible with TimerManager and does not implement .end()" + ) def __init__(self, timeout, callback, loop): - delayed = self._call_delayed_coro(timeout=timeout, - callback=callback) + delayed = self._call_delayed_coro(timeout=timeout, callback=callback) self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop) @staticmethod @@ -63,17 +63,61 @@ def cancel(self): def finish(self): # connection.Timer method not implemented here because we can't inspect # the Handle returned from call_later - raise NotImplementedError('{} is not compatible with TimerManager and ' - 'does not implement .finish()') + raise NotImplementedError( + "{} is not compatible with TimerManager and does not implement .finish()" + ) + + +class _AsyncioProtocol(asyncio.Protocol): + """ + Protocol adapter for asyncio SSL connections. Bridges asyncio's + transport/protocol API back to AsyncioConnection's buffer processing. + """ + + def __init__(self, connection, loop_args=None): + self._connection = connection + self.transport = None + self.write_ready = asyncio.Event(**(loop_args or {})) + self.write_ready.set() + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + conn = self._connection + conn._iobuf.write(data) + if conn._iobuf.tell(): + conn.process_io_buffer() + + def pause_writing(self): + self.write_ready.clear() + + def resume_writing(self): + self.write_ready.set() + + def connection_lost(self, exc): + # Unblock any paused writer so shutdown does not hang + self.write_ready.set() + conn = self._connection + if exc: + log.debug("Connection %s lost: %s", conn, exc) + conn.defunct(exc) + else: + log.debug("Connection %s closed by server", conn) + conn.close() + + def eof_received(self): + return False class AsyncioConnection(Connection): """ - An experimental implementation of :class:`.Connection` that uses the - ``asyncio`` module in the Python standard library for its event loop. + An implementation of :class:`.Connection` that uses the ``asyncio`` + module in the Python standard library for its event loop. - Note that it requires ``asyncio`` features that were only introduced in the - 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. + Supports SSL connections via asyncio's native TLS transport, which + avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's + low-level socket methods (``sock_sendall``, ``sock_recv``). """ _loop = None @@ -88,26 +132,109 @@ class AsyncioConnection(Connection): def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._background_tasks = set() + self._transport = None + self._using_ssl = bool(self.ssl_context) self._connect_socket() self._socket.setblocking(0) loop_args = dict() if sys.version_info[0] == 3 and sys.version_info[1] < 10: - loop_args['loop'] = self._loop + loop_args["loop"] = self._loop + self._protocol = _AsyncioProtocol(self, loop_args) if self._using_ssl else None + self._ssl_ready = asyncio.Event(**loop_args) if self._using_ssl else None self._write_queue = asyncio.Queue(**loop_args) self._write_queue_lock = asyncio.Lock(**loop_args) # see initialize_reactor -- loop is running in a separate thread, so we # have to use a threadsafe call - self._read_watcher = asyncio.run_coroutine_threadsafe( - self.handle_read(), loop=self._loop - ) + if self._using_ssl: + # For SSL: set up asyncio transport/protocol, then start writer + self._read_watcher = asyncio.run_coroutine_threadsafe( + self._setup_ssl_and_run(), loop=self._loop + ) + else: + # For non-SSL: use low-level sock_sendall/sock_recv as before + self._read_watcher = asyncio.run_coroutine_threadsafe( + self.handle_read(), loop=self._loop + ) self._write_watcher = asyncio.run_coroutine_threadsafe( self.handle_write(), loop=self._loop ) self._send_options_message() + def _connect_socket(self): + """ + Override base class to skip SSL wrapping of the socket. + For SSL connections, the plain TCP socket is connected here, and TLS + is set up later via asyncio's native SSL transport in _setup_ssl_and_run(). + """ + sockerr = None + addresses = self._get_socket_addresses() + for af, socktype, proto, _, sockaddr in addresses: + try: + self._socket = self._socket_impl.socket(af, socktype, proto) + # Do NOT wrap with ssl_context here -- asyncio will handle TLS + self._socket.settimeout(self.connect_timeout) + self._initiate_connection(sockaddr) + self._socket.settimeout(None) + + local_addr = self._socket.getsockname() + log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr) + sockerr = None + break + except socket.error as err: + if self._socket: + self._socket.close() + self._socket = None + sockerr = err + + if sockerr: + raise socket.error( + sockerr.errno, + "Tried connecting to %s. Last error: %s" + % ([a[4] for a in addresses], sockerr.strerror or sockerr), + ) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + async def _setup_ssl_and_run(self): + """ + Upgrade the plain TCP connection to TLS using asyncio's native SSL + transport, then continuously read data via the protocol callbacks. + """ + try: + ssl_context = self.ssl_context + server_hostname = None + if self.ssl_options: + server_hostname = self.ssl_options.get("server_hostname", None) + if server_hostname is None: + # asyncio's create_connection requires server_hostname when + # ssl= is set. Use endpoint address for SNI/verification when + # check_hostname is enabled; otherwise pass "" to suppress SNI. + server_hostname = ( + self.endpoint.address if ssl_context.check_hostname else "" + ) + + transport, protocol = await self._loop.create_connection( + lambda: self._protocol, + sock=self._socket, + ssl=ssl_context, + server_hostname=server_hostname, + ) + self._transport = transport + + if self._check_hostname: + self._validate_hostname() + self._ssl_ready.set() + except Exception as exc: + log.debug("SSL setup failed for %s: %s", self, exc) + self.defunct(exc) + # Unblock handle_write so it can observe the defunct state and exit + self._ssl_ready.set() + return @classmethod def initialize_reactor(cls): @@ -126,8 +253,9 @@ def initialize_reactor(cls): cls._loop = asyncio.new_event_loop() # daemonize so the loop will be shut down on interpreter # shutdown - cls._loop_thread = Thread(target=cls._loop.run_forever, - daemon=True, name="asyncio_thread") + cls._loop_thread = Thread( + target=cls._loop.run_forever, daemon=True, name="asyncio_thread" + ) cls._loop_thread.start() @classmethod @@ -142,9 +270,7 @@ def close(self): # close from the loop thread to avoid races when removing file # descriptors - asyncio.run_coroutine_threadsafe( - self._close(), loop=self._loop - ) + asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop) async def _close(self): log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) @@ -152,7 +278,10 @@ async def _close(self): self._write_watcher.cancel() if self._read_watcher: self._read_watcher.cancel() - if self._socket: + if self._transport: + self._transport.close() + self._transport = None + elif self._socket: self._loop.remove_writer(self._socket.fileno()) self._loop.remove_reader(self._socket.fileno()) self._socket.close() @@ -172,15 +301,12 @@ def push(self, data): if len(data) > buff_size: chunks = [] for i in range(0, len(data), buff_size): - chunks.append(data[i:i + buff_size]) + chunks.append(data[i : i + buff_size]) else: chunks = [data] if self._loop_thread != threading.current_thread(): - asyncio.run_coroutine_threadsafe( - self._push_msg(chunks), - loop=self._loop - ) + asyncio.run_coroutine_threadsafe(self._push_msg(chunks), loop=self._loop) else: # avoid races/hangs by just scheduling this, not using threadsafe task = self._loop.create_task(self._push_msg(chunks)) @@ -194,13 +320,25 @@ async def _push_msg(self, chunks): for chunk in chunks: self._write_queue.put_nowait(chunk) - async def handle_write(self): + # For SSL connections, wait until the TLS handshake completes + if self._ssl_ready: + await self._ssl_ready.wait() + if self.is_defunct: + return while True: try: next_msg = await self._write_queue.get() if next_msg: - await self._loop.sock_sendall(self._socket, next_msg) + if self._transport: + # SSL: use asyncio transport (handles TLS transparently) + await self._protocol.write_ready.wait() + if self.is_closed or self.is_defunct or not self._transport: + return + self._transport.write(next_msg) + else: + # Non-SSL: use low-level socket API + await self._loop.sock_sendall(self._socket, next_msg) except socket.error as err: log.debug("Exception in send for %s: %s", self, err) self.defunct(err) @@ -223,8 +361,7 @@ async def handle_read(self): await asyncio.sleep(0) continue except socket.error as err: - log.debug("Exception during socket recv for %s: %s", - self, err) + log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop except asyncio.CancelledError: diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 51f03f3d97..44b7b63f67 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -48,6 +48,8 @@ Clusters and Sessions .. autoattribute:: control_connection_timeout + .. autoattribute:: allow_control_connection_query_fallback + .. autoattribute:: idle_heartbeat_interval .. autoattribute:: idle_heartbeat_timeout @@ -106,6 +108,9 @@ Clusters and Sessions .. automethod:: set_meta_refresh_enabled +.. autoclass:: ControlConnectionQueryFallback + :members: + .. autoclass:: ExecutionProfile (load_balancing_policy=, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=, speculative_execution_policy=None) :members: :exclude-members: consistency_level @@ -169,6 +174,8 @@ Clusters and Sessions .. automethod:: set_keyspace(keyspace) + .. automethod:: wait_for_schema_agreement + .. automethod:: get_execution_profile .. automethod:: execution_profile_clone_update diff --git a/docs/conf.py b/docs/conf.py index 87a38c6add..34ef31ccae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,11 +29,11 @@ '3.29.6-scylla', '3.29.7-scylla', '3.29.8-scylla', - '3.29.9-scylla', + '3.29.10-scylla', ] BRANCHES = ['master'] # Set the latest version. -LATEST_VERSION = '3.29.9-scylla' +LATEST_VERSION = '3.29.10-scylla' # Set which versions are not released yet. UNSTABLE_VERSIONS = ['master'] # Set which versions are deprecated diff --git a/docs/installation.rst b/docs/installation.rst index fbb9ac4043..6a4b38ea80 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -26,7 +26,7 @@ To check if the installation was successful, you can run:: python -c 'import cassandra; print(cassandra.__version__)' -It should print something like "3.29.9". +It should print something like "3.29.10". (*Optional*) Compression Support -------------------------------- @@ -190,7 +190,7 @@ through `Homebrew `_. For example, on Mac OS X:: $ brew install libev -The libev extension can now be built for Windows as of Python driver version 3.29.9. You can +The libev extension can now be built for Windows as of Python driver version 3.29.10. You can install libev using any Windows package manager. For example, to install using `vcpkg `_: $ vcpkg install libev diff --git a/docs/scylla-specific.rst b/docs/scylla-specific.rst index e9fe695f8f..4b28781f1c 100644 --- a/docs/scylla-specific.rst +++ b/docs/scylla-specific.rst @@ -91,7 +91,7 @@ New Error Types session = cluster.connect() session.execute(""" CREATE KEYSPACE IF NOT EXISTS keyspace1 - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} """) session.execute("USE keyspace1") diff --git a/docs/uv.lock b/docs/uv.lock index 56b0841403..515e37abba 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -1067,11 +1067,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.3" +version = "2.7.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, ] [[package]] diff --git a/examples/concurrent_executions/execute_async_with_queue.py b/examples/concurrent_executions/execute_async_with_queue.py index 72d2c101cb..794ac78818 100644 --- a/examples/concurrent_executions/execute_async_with_queue.py +++ b/examples/concurrent_executions/execute_async_with_queue.py @@ -31,7 +31,7 @@ session = cluster.connect() session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }")) session.execute("USE examples") session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") diff --git a/examples/concurrent_executions/execute_with_threads.py b/examples/concurrent_executions/execute_with_threads.py index e3c80f5d6b..70893bd5be 100644 --- a/examples/concurrent_executions/execute_with_threads.py +++ b/examples/concurrent_executions/execute_with_threads.py @@ -34,7 +34,7 @@ session = cluster.connect() session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1' }")) session.execute("USE examples") session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") diff --git a/examples/example_core.py b/examples/example_core.py index 01c766e109..ec41ca7fd5 100644 --- a/examples/example_core.py +++ b/examples/example_core.py @@ -36,7 +36,7 @@ def main(): log.info("creating keyspace...") session.execute(""" CREATE KEYSPACE IF NOT EXISTS %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' } """ % KEYSPACE) log.info("setting keyspace...") diff --git a/renovate.json b/renovate.json index 5db72dd6a9..d85ac38c01 100644 --- a/renovate.json +++ b/renovate.json @@ -2,5 +2,12 @@ "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "config:recommended" + ], + "packageRules": [ + { + "matchManagers": ["github-actions"], + "pinDigests": true, + "minimumReleaseAge": "90 days" + } ] } diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 6a809bded4..5701e5b3da 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -651,17 +651,17 @@ def setup_keyspace(ipformat=None, protocol_version=None, port=9042): ddl = ''' CREATE KEYSPACE test3rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'}''' execute_with_long_wait_retry(session, ddl) ddl = ''' CREATE KEYSPACE test2rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '2'}''' execute_with_long_wait_retry(session, ddl) ddl = ''' CREATE KEYSPACE test1rf - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}''' execute_with_long_wait_retry(session, ddl) ddl_3f = ''' @@ -707,6 +707,17 @@ def xfail_scylla_version_lt(reason, scylla_version, *args, **kwargs): return pytest.mark.xfail(current_version < Version(scylla_version), reason=reason, *args, **kwargs) +def get_tablets_disabled_ddl_suffix(scylla_version='2026.1'): + """ + Returns DDL option string for disabling tablets on ScyllaDB versions older than scylla_version. + Used to work around features not yet supported with tablets (e.g. MVs, secondary indexes, counters). + :param scylla_version: str, version from which tablets support the feature + """ + if SCYLLA_VERSION is not None and Version(get_scylla_version(SCYLLA_VERSION)) < Version(scylla_version): + return " AND tablets = {'enabled': false}" + return "" + + def skip_scylla_version_lt(reason, scylla_version): """ Skip tests on scylla versions older than the specified thresholds. @@ -774,7 +785,7 @@ def drop_keyspace(cls): @classmethod def create_keyspace(cls, rf): - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.ks_name, rf) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format(cls.ks_name, rf) execute_with_long_wait_retry(cls.session, ddl) @classmethod diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py index 78d5133e63..640c953285 100644 --- a/tests/integration/cqlengine/connections/test_connection.py +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -76,9 +76,9 @@ def setUpClass(cls): super(SeveralConnectionsTest, cls).setUpClass() cls.setup_cluster = TestCluster() cls.setup_session = cls.setup_cluster.connect() - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1) execute_with_long_wait_retry(cls.setup_session, ddl) - ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1) execute_with_long_wait_retry(cls.setup_session, ddl) @classmethod diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index cafe6ae9c9..98d71993fd 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -259,10 +259,8 @@ class SensitiveModel(Model): rows[-1] rows[-1:] - # ignore DeprecationWarning('The loop argument is deprecated since Python 3.8, and scheduled for removal in Python 3.10.') - relevant_warnings = [warn for warn in w if "The loop argument is deprecated" not in str(warn.message)] + warning_messages = [str(warn.message) for warn in w] - assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[0].message) - assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[1].message) - assert "ModelQuerySet indexing with negative indices support will be removed in 4.0." in str(relevant_warnings[2].message) - assert "ModelQuerySet slicing with negative indices support will be removed in 4.0." in str(relevant_warnings[3].message) + assert sum("__table_name_case_sensitive__ will be removed in 4.0." in message for message in warning_messages) == 2 + assert sum("ModelQuerySet indexing with negative indices support will be removed in 4.0." in message for message in warning_messages) == 1 + assert sum("ModelQuerySet slicing with negative indices support will be removed in 4.0." in message for message in warning_messages) == 1 diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 24a6802b47..66ba8b973a 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -27,7 +27,7 @@ from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage -from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30, requires_collection_indexes +from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30, requires_collection_indexes, get_tablets_disabled_ddl_suffix, execute_with_long_wait_retry import pytest @@ -280,6 +280,12 @@ def test_get_multipleobjects_exception(self): class TestNamedWithMV(BasicSharedKeyspaceUnitTestCase): + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}{2}".format( + cls.ks_name, rf, get_tablets_disabled_ddl_suffix()) + execute_with_long_wait_retry(cls.session, ddl) + @classmethod def setUpClass(cls): super(TestNamedWithMV, cls).setUpClass() diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index beb10f02c0..04d75555f5 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -187,7 +187,7 @@ def test_write_failures_from_coordinator(self): self._perform_cql_statement( """ CREATE KEYSPACE testksfail - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) # create table diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py index ab8d125ab1..5cada34d8b 100644 --- a/tests/integration/long/test_policies.py +++ b/tests/integration/long/test_policies.py @@ -48,7 +48,7 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ep}) session = cluster.connect() - session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'SimpleStrategy','replication_factor': 3};") + session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'NetworkTopologyStrategy','replication_factor': 3};") session.execute("CREATE TABLE test_retry_policy_cas.t (id int PRIMARY KEY, data text);") session.execute('INSERT INTO test_retry_policy_cas.t ("id", "data") VALUES (%(0)s, %(1)s)', {'0': 42, '1': 'testing'}) diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index f892acba52..d60ff775c4 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -57,7 +57,7 @@ def test_recreates(self): log.debug(drop) execute_until_pass(session, drop) - create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 3}}".format(keyspace) + create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': 3}}".format(keyspace) log.debug(create) execute_until_pass(session, create) @@ -82,7 +82,7 @@ def test_for_schema_disagreements_different_keyspaces(self): session = self.session for i in range(30): - execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i)) + execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': 1}}".format(i)) execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i)) for j in range(100): @@ -100,10 +100,10 @@ def test_for_schema_disagreements_same_keyspace(self): for i in range(30): try: - execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}") except AlreadyExists: execute_until_pass(session, "DROP KEYSPACE test") - execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}") execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)") @@ -132,7 +132,7 @@ def test_for_schema_disagreement_attribute(self): cluster = TestCluster(max_schema_agreement_wait=0.001) session = cluster.connect(wait_for_all_pools=True) - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}") self.check_and_wait_for_agreement(session, rs, False) rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", consistency_level=ConsistencyLevel.ALL)) @@ -144,7 +144,7 @@ def test_for_schema_disagreement_attribute(self): # These should have schema agreement cluster = TestCluster(max_schema_agreement_wait=100) session = cluster.connect() - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}") self.check_and_wait_for_agreement(session, rs, True) rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", consistency_level=ConsistencyLevel.ALL)) @@ -158,4 +158,4 @@ def check_and_wait_for_agreement(self, session, rs, exepected): time.sleep(1) assert rs.response_future.is_schema_agreed == exepected if not rs.response_future.is_schema_agreed: - session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) + session.wait_for_schema_agreement(wait_time=1000) diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index 56dc6a5c2d..0170f56fa1 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -116,7 +116,7 @@ def validate_ssl_options(**kwargs): # attempt a few simple commands. insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} """ statement = SimpleStatement(insert_keyspace) statement.consistency_level = 3 @@ -369,7 +369,7 @@ def test_ssl_want_write_errors_are_retried(self): except: pass session.execute( - "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'SimpleStrategy','replication_factor':1};") + "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'NetworkTopologyStrategy','replication_factor':1};") session.execute("CREATE TABLE ssl_error_test.big_text (id uuid PRIMARY KEY, data text);") params = { diff --git a/tests/integration/long/utils.py b/tests/integration/long/utils.py index 93464df8ff..ba9351828e 100644 --- a/tests/integration/long/utils.py +++ b/tests/integration/long/utils.py @@ -63,7 +63,7 @@ def create_schema(cluster, session, keyspace, simple_strategy=True, if simple_strategy: ddl = "CREATE KEYSPACE %s WITH replication" \ - " = {'class': 'SimpleStrategy', 'replication_factor': '%s'}" + " = {'class': 'NetworkTopologyStrategy', 'replication_factor': '%s'}" session.execute(ddl % (keyspace, replication_factor), timeout=10) else: if not replication_strategy: diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py index 2dbf3985ad..daa9f20fa8 100644 --- a/tests/integration/simulacron/test_empty_column.py +++ b/tests/integration/simulacron/test_empty_column.py @@ -140,9 +140,9 @@ def test_empty_columns_in_system_schema(self): 'delay_in_ms': 0, 'rows': [ { - "strategy_class": "SimpleStrategy", # C* 2.2 + "strategy_class": "NetworkTopologyStrategy", # C* 2.2 "strategy_options": '{}', # C* 2.2 - "replication": {'strategy': 'SimpleStrategy', 'replication_factor': 1}, + "replication": {'strategy': 'NetworkTopologyStrategy', 'replication_factor': 1}, "durable_writes": True, "keyspace_name": "testks" } diff --git a/tests/integration/standard/column_encryption/test_policies.py b/tests/integration/standard/column_encryption/test_policies.py index 9a1d186895..4b12fa135a 100644 --- a/tests/integration/standard/column_encryption/test_policies.py +++ b/tests/integration/standard/column_encryption/test_policies.py @@ -30,7 +30,7 @@ class ColumnEncryptionPolicyTest(unittest.TestCase): def _recreate_keyspace(self, session): session.execute("drop keyspace if exists foo") - session.execute("CREATE KEYSPACE foo WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + session.execute("CREATE KEYSPACE foo WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") session.execute("CREATE TABLE foo.bar(encrypted blob, unencrypted int, primary key(unencrypted))") def _create_policy(self, key, iv = None): diff --git a/tests/integration/standard/conftest.py b/tests/integration/standard/conftest.py index 3adaf371b0..9934cfcbbb 100644 --- a/tests/integration/standard/conftest.py +++ b/tests/integration/standard/conftest.py @@ -37,6 +37,7 @@ "test_ip_change": 4, "test_authentication": 4, "test_authentication_misconfiguration": 4, + "test_control_connection_query_fallback": 4, "test_custom_cluster": 4, "test_query": 4, # Group 5: tablets (destructive — decommissions a node) diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..292eabca30 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -741,7 +741,7 @@ def test_queries_succeed_through_proxy(self): session = cluster.connect() session.execute( "CREATE KEYSPACE IF NOT EXISTS test_cr_ks " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}" + "WITH replication = {'class':'NetworkTopologyStrategy', 'replication_factor': 3}" ) session.execute( "CREATE TABLE IF NOT EXISTS test_cr_ks.t (k int PRIMARY KEY, v text)" @@ -1154,7 +1154,7 @@ def tearDownClass(cls): def test_should_survive_full_node_replacement_through_nlb(self): """ 1. Start with 3 nodes behind the NLB - 2. Bootstrap 2 new nodes, add to NLB, update routes + 2. Bootstrap 3 new nodes, add to NLB, update routes 3. Decommission the original 3 nodes one-by-one, updating NLB/routes 4. Verify the session survives with only new nodes """ @@ -1190,7 +1190,7 @@ def test_should_survive_full_node_replacement_through_nlb(self): len(original_node_ids)) # ---- Stage 3: Bootstrap new nodes ---- - new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2] + new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2, max(original_node_ids) + 3] log.info("Stage 3: Adding nodes %s", new_node_ids) ccm_cluster = get_cluster() diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 08b823d716..00ea11ea27 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -180,7 +180,7 @@ def test_basic(self): result = execute_until_pass(session, """ CREATE KEYSPACE clustertests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} """) assert not result @@ -1195,27 +1195,35 @@ def test_replicas_are_queried(self): Then using HostFilterPolicy the replica is excluded from the considered hosts. By checking the trace we verify that there are no more replicas. + Requires tablets feature disabled. + @since 3.5 @jira_ticket PYTHON-653 @expected_result the replicas are queried for HostFilterPolicy @test_category metadata """ + ks_name = 'test_replicas_queried_ks' queried_hosts = set() tap_profile = ExecutionProfile( load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()) ) with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: tap_profile}) as cluster: session = cluster.connect(wait_for_all_pools=True) + session.execute("DROP KEYSPACE IF EXISTS {}".format(ks_name)) + session.execute( + "CREATE KEYSPACE {} WITH replication = {{'class': 'NetworkTopologyStrategy', " + "'replication_factor': '1'}} AND tablets = {{'enabled': false}}".format(ks_name) + ) session.execute(''' - CREATE TABLE test1rf.table_with_big_key ( + CREATE TABLE {}.table_with_big_key ( k1 int, k2 int, k3 int, k4 int, - PRIMARY KEY((k1, k2, k3), k4))''') - prepared = session.prepare("""SELECT * from test1rf.table_with_big_key - WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + PRIMARY KEY((k1, k2, k3), k4))'''.format(ks_name)) + prepared = session.prepare("""SELECT * from {}.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""".format(ks_name)) for i in range(10): result = session.execute(prepared, (i, i, i, i), trace=True) trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) @@ -1234,14 +1242,14 @@ def test_replicas_are_queried(self): execution_profiles={EXEC_PROFILE_DEFAULT: hfp_profile}) as cluster: session = cluster.connect(wait_for_all_pools=True) - prepared = session.prepare("""SELECT * from test1rf.table_with_big_key - WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + prepared = session.prepare("""SELECT * from {}.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""".format(ks_name)) for _ in range(10): result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) self._assert_replica_queried(trace, only_replicas=False) - session.execute('''DROP TABLE test1rf.table_with_big_key''') + session.execute('DROP KEYSPACE {}'.format(ks_name)) @greaterthanorequalcass30 @lessthanorequalcass40 @@ -1506,7 +1514,7 @@ def test_prepare_on_ignored_hosts(self): hosts = cluster.metadata.all_hosts() session.execute("CREATE KEYSPACE clustertests " "WITH replication = " - "{'class': 'SimpleStrategy', 'replication_factor': '1'}") + "{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") session.execute("CREATE TABLE clustertests.tab (a text, PRIMARY KEY (a))") # assign to an unused variable so cluster._prepared_statements retains # reference diff --git a/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py b/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py index 910dcaa9fe..9a9a3d325f 100644 --- a/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py +++ b/tests/integration/standard/test_concurrent_schema_change_and_node_kill.py @@ -27,7 +27,7 @@ def test_schema_change_after_node_kill(self): "DROP KEYSPACE IF EXISTS ks_deadlock;") self.session.execute( "CREATE KEYSPACE IF NOT EXISTS ks_deadlock " - "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2' };") + "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '2' };") self.session.set_keyspace('ks_deadlock') self.session.execute("CREATE TABLE IF NOT EXISTS some_table(k int, c int, v int, PRIMARY KEY (k, v));") self.session.execute("INSERT INTO some_table (k, c, v) VALUES (1, 2, 3);") diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index c4463e17fd..f0c41dde14 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -68,7 +68,7 @@ def test_drop_keyspace(self): self.session = self.cluster.connect() self.session.execute(""" CREATE KEYSPACE keyspacetodrop - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) self.session.set_keyspace("keyspacetodrop") self.session.execute("CREATE TYPE user (age int, name text)") diff --git a/tests/integration/standard/test_control_connection_query_fallback.py b/tests/integration/standard/test_control_connection_query_fallback.py new file mode 100644 index 0000000000..e64763a72c --- /dev/null +++ b/tests/integration/standard/test_control_connection_query_fallback.py @@ -0,0 +1,115 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from cassandra.cluster import ControlConnectionQueryFallback, NoHostAvailable + +from tests.integration import USE_CASS_EXTERNAL, TestCluster, local, remove_cluster, use_cluster + + +_CLUSTER_NAME = "control_connection_query_fallback" +_UNREACHABLE_BROADCAST_RPC_ADDRESS = "127.255.255.1" + + +def setup_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + ccm_cluster = use_cluster(_CLUSTER_NAME, [1], start=False) + ccm_cluster.nodes["node1"].set_configuration_options(values={ + "broadcast_rpc_address": _UNREACHABLE_BROADCAST_RPC_ADDRESS, + }) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + +def teardown_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + +@local +class ControlConnectionQueryFallbackIntegrationTests(unittest.TestCase): + + def setUp(self): + self.cluster = None + + def tearDown(self): + if self.cluster is not None: + self.cluster.shutdown() + + def _assert_unreachable_broadcast_rpc_metadata(self): + hosts = self.cluster.metadata.all_hosts() + assert len(hosts) == 1 + + host = hosts[0] + assert host.broadcast_rpc_address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert host.endpoint.address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + return host + + def test_disabled_raises_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Disabled, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + with pytest.raises(NoHostAvailable): + self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert self.cluster.control_connection._connection is not None + assert self.cluster.get_all_pools() == [] + + def test_fallback_executes_queries_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version + + def test_no_node_pool_fallback_executes_queries_without_creating_pools(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures == set() + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index e123f2050e..e7d336014f 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -42,8 +42,9 @@ class CustomProtocolHandlerTest(unittest.TestCase): def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE custserdes WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.execute("CREATE KEYSPACE custserdes WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}") cls.session.set_keyspace("custserdes") + cls.session.execute("CREATE TABLE IF NOT EXISTS custserdes.test (k int PRIMARY KEY, v int)") @classmethod def tearDownClass(cls): @@ -165,7 +166,7 @@ def test_protocol_divergence_v5_fail_by_flag_uses_int(self): int_flag=False) def _send_query_message(self, session, timeout, **kwargs): - query = "SELECT * FROM test3rf.test" + query = "SELECT * FROM custserdes.test" message = QueryMessage(query=query, **kwargs) future = ResponseFuture(session, message, query=None, timeout=timeout) future.send_request() @@ -175,8 +176,8 @@ def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_fla cluster = TestCluster(protocol_version=version, allow_beta_protocol_version=beta) session = cluster.connect() - query_one = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") - query_two = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)") + query_one = SimpleStatement("INSERT INTO custserdes.test (k, v) VALUES (1, 1)") + query_two = SimpleStatement("INSERT INTO custserdes.test (k, v) VALUES (2, 2)") execute_with_long_wait_retry(session, query_one) execute_with_long_wait_retry(session, query_two) @@ -190,7 +191,7 @@ def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_fla # This means the flag are not handled as they are meant by the server if uses_int=False assert response.has_more_pages == uses_int_query_flag - execute_with_long_wait_retry(session, SimpleStatement("TRUNCATE test3rf.test")) + execute_with_long_wait_retry(session, SimpleStatement("TRUNCATE custserdes.test")) cluster.shutdown() diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 9c94b2ac77..49a13ac23a 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -34,7 +34,7 @@ def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE testspace WITH replication = " - "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + "{ 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}") cls.session.set_keyspace("testspace") cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS) @@ -225,7 +225,7 @@ def setUpClass(cls): cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE IF NOT EXISTS test_wide_table WITH replication = " - "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + "{ 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1'}") cls.session.set_keyspace("test_wide_table") # Create a wide table with many int columns diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 6e64401a75..f5a11dd5fe 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -45,7 +45,7 @@ lessthancass40, TestCluster, requires_java_udf, requires_composite_type, requires_collection_indexes, SCYLLA_VERSION, xfail_scylla, xfail_scylla_version_lt, - requirescompactstorage) + requirescompactstorage, get_tablets_disabled_ddl_suffix, execute_with_long_wait_retry) from tests.util import wait_until, assertRegex, assertDictEqual, assertListEqual, assert_startswith_diff @@ -141,6 +141,12 @@ def test_bad_contact_point(self): class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}{2}".format( + cls.ks_name, rf, get_tablets_disabled_ddl_suffix()) + execute_with_long_wait_retry(cls.session, ddl) + def test_schema_metadata_disable(self): """ Checks to ensure that schema metadata_enabled, and token_metadata_enabled @@ -230,8 +236,8 @@ def test_basic_table_meta_properties(self): assert ksmeta.name == self.keyspace_name assert ksmeta.durable_writes - assert ksmeta.replication_strategy.name == 'SimpleStrategy' - assert ksmeta.replication_strategy.replication_factor == 1 + assert ksmeta.replication_strategy.name == 'NetworkTopologyStrategy' + assert ksmeta.replication_strategy.dc_replication_factors["dc1"] == 1 assert self.function_table_name in ksmeta.tables tablemeta = ksmeta.tables[self.function_table_name] @@ -601,7 +607,7 @@ def test_refresh_schema_metadata(self): assert "new_keyspace" not in cluster2.metadata.keyspaces # Cluster metadata modification - self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") assert "new_keyspace" not in cluster2.metadata.keyspaces cluster2.refresh_schema_metadata() @@ -1077,7 +1083,7 @@ def test_metadata_pagination_keyspaces(self): for ks in keyspaces: self.session.execute( - f"CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 3 }}" + f"CREATE KEYSPACE IF NOT EXISTS {ks} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3 }}" ) self.cluster.schema_metadata_page_size = 2000 @@ -1138,7 +1144,7 @@ def test_export_keyspace_schema_udts(self): session.execute(""" CREATE KEYSPACE export_udts - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} AND durable_writes = true; """) session.execute(""" @@ -1162,7 +1168,7 @@ def test_export_keyspace_schema_udts(self): addresses map>) """) - expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; + expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} AND durable_writes = true; CREATE TYPE export_udts.street ( street_number int, @@ -1196,8 +1202,6 @@ def test_export_keyspace_schema_udts(self): cluster.shutdown() @greaterthancass21 - @xfail_scylla_version_lt(reason='scylladb/scylladb#10707 - Column name in CREATE INDEX is not quoted', - scylla_version="2023.1.1") def test_case_sensitivity(self): """ Test that names that need to be escaped in CREATE statements are @@ -1210,10 +1214,9 @@ def test_case_sensitivity(self): cfname = 'AnInterestingTable' session.execute("DROP KEYSPACE IF EXISTS {0}".format(ksname)) - session.execute(""" - CREATE KEYSPACE "%s" - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """ % (ksname,)) + session.execute( + ("CREATE KEYSPACE \"%s\" WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + + get_tablets_disabled_ddl_suffix()) % (ksname,)) session.execute(""" CREATE TABLE "%s"."%s" ( k int, @@ -1256,7 +1259,7 @@ def test_already_exists_exceptions(self): ddl = ''' CREATE KEYSPACE %s - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'}''' with pytest.raises(AlreadyExists): session.execute(ddl % ksname) @@ -1387,7 +1390,7 @@ def setUp(self): self.session = self.cluster.connect() name = self._testMethodName.lower() crt_ks = ''' - CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} AND durable_writes = true''' % name + CREATE KEYSPACE %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND durable_writes = true''' % name self.session.execute(crt_ks) def tearDown(self): @@ -1434,11 +1437,9 @@ def setup_class(cls): if cls.keyspace_name in cls.cluster.metadata.keyspaces: cls.session.execute("DROP KEYSPACE %s" % cls.keyspace_name) - cls.session.execute( - """ - CREATE KEYSPACE %s - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}; - """ % cls.keyspace_name) + ddl = ("CREATE KEYSPACE %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + + get_tablets_disabled_ddl_suffix()) + cls.session.execute(ddl % cls.keyspace_name) cls.session.set_keyspace(cls.keyspace_name) except Exception: cls.cluster.shutdown() @@ -1540,7 +1541,7 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}" % cls.keyspace_name) cls.session.set_keyspace(cls.keyspace_name) cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates @@ -2007,7 +2008,8 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + ddl = "CREATE KEYSPACE %s WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" + get_tablets_disabled_ddl_suffix() + cls.session.execute(ddl % cls.keyspace_name) cls.session.set_keyspace(cls.keyspace_name) connection = cls.cluster.control_connection._connection @@ -2132,6 +2134,13 @@ def test_dct_alias(self): @greaterthanorequalcass30 class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}{2}".format( + cls.ks_name, rf, get_tablets_disabled_ddl_suffix()) + execute_with_long_wait_retry(cls.session, ddl) + + def setUp(self): self.session.execute("CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format(self.keyspace_name, self.function_table_name)) self.session.execute( @@ -2219,6 +2228,13 @@ def test_materialized_view_metadata_drop(self): @greaterthanorequalcass30 class MaterializedViewMetadataTestComplex(BasicSegregatedKeyspaceUnitTestCase): + + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}{2}".format( + cls.ks_name, rf, get_tablets_disabled_ddl_suffix()) + execute_with_long_wait_retry(cls.session, ddl) + def test_create_view_metadata(self): """ test to ensure that materialized view metadata is properly constructed diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py index 2de12f7b7f..50b431e3c9 100644 --- a/tests/integration/standard/test_policies.py +++ b/tests/integration/standard/test_policies.py @@ -104,5 +104,5 @@ def test_exponential_retries(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} """) diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 3f63b881ef..37f93c94c6 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -62,7 +62,7 @@ def test_basic(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'} """) self.session.set_keyspace("preparedtests") @@ -437,7 +437,7 @@ def test_fail_if_different_query_id_on_reprepare(self): keyspace = "test_fail_if_different_query_id_on_reprepare" self.session.execute( "CREATE KEYSPACE IF NOT EXISTS {} WITH replication = " - "{{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(keyspace) + "{{'class': 'NetworkTopologyStrategy', 'replication_factor': 1}}".format(keyspace) ) self.session.execute("CREATE TABLE IF NOT EXISTS {}.foo(k int PRIMARY KEY)".format(keyspace)) prepared = self.session.prepare("SELECT * FROM {}.foo WHERE k=?".format(keyspace)) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index f9d3dc26bc..210f6dacb1 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -26,7 +26,8 @@ from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \ greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ - USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla + USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla, xfail_scylla_version_lt, \ + get_tablets_disabled_ddl_suffix, execute_with_long_wait_retry from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node from tests.util import assertListEqual, wait_until @@ -804,6 +805,9 @@ def setUp(self): def tearDown(self): self.cluster.shutdown() + @xfail_scylla_version_lt(reason='scylladb/scylladb#18068 - LWT is not yet supported with tablets', + scylla_version='2025.4', + raises=InvalidRequest) def test_conditional_update(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") statement = SimpleStatement( @@ -828,6 +832,9 @@ def test_conditional_update(self): assert result assert result.one().applied + @xfail_scylla_version_lt(reason='scylladb/scylladb#18068 - LWT is not yet supported with tablets', + scylla_version='2025.4', + raises=InvalidRequest) def test_conditional_update_with_prepared_statements(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") statement = self.session.prepare( @@ -850,6 +857,9 @@ def test_conditional_update_with_prepared_statements(self): assert result assert result.one().applied + @xfail_scylla_version_lt(reason='scylladb/scylladb#18068 - LWT is not yet supported with tablets', + scylla_version='2025.4', + raises=InvalidRequest) def test_conditional_update_with_batch_statements(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") statement = BatchStatement(serial_consistency_level=ConsistencyLevel.SERIAL) @@ -915,6 +925,9 @@ def tearDown(self): self.session.execute("DROP TABLE test3rf.lwt_clustering") self.cluster.shutdown() + @xfail_scylla_version_lt(reason='scylladb/scylladb#18068 - LWT is not yet supported with tablets', + scylla_version='2025.4', + raises=AttributeError) def test_no_connection_refused_on_timeout(self): """ Test for PYTHON-91 "Connection closed after LWT timeout" @@ -1156,6 +1169,12 @@ def test_inherit_first_rk_prepared_param(self): @greaterthanorequalcass30 class MaterializedViewQueryTest(BasicSharedKeyspaceUnitTestCase): + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}{2}".format( + cls.ks_name, rf, get_tablets_disabled_ddl_suffix()) + execute_with_long_wait_retry(cls.session, ddl) + def test_mv_filtering(self): """ Test to ensure that cql filtering where clauses are properly supported in the python driver. @@ -1359,12 +1378,12 @@ def setUpClass(cls): cls.table_name = "table_query_keyspace_tests" ddl = """CREATE KEYSPACE {0} WITH replication = - {{'class': 'SimpleStrategy', + {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}""".format(cls.ks_name, 1) cls.session.execute(ddl) ddl = """CREATE KEYSPACE {0} WITH replication = - {{'class': 'SimpleStrategy', + {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{1}'}}""".format(cls.alternative_ks, 1) cls.session.execute(ddl) diff --git a/tests/integration/standard/test_rate_limit_exceeded.py b/tests/integration/standard/test_rate_limit_exceeded.py index ea7dfc7d61..5a7fc5dc74 100644 --- a/tests/integration/standard/test_rate_limit_exceeded.py +++ b/tests/integration/standard/test_rate_limit_exceeded.py @@ -33,7 +33,7 @@ def test_rate_limit_exceeded(self): self.session.execute( """ CREATE KEYSPACE IF NOT EXISTS ratetests - WITH REPLICATION = {'class' : 'SimpleStrategy', 'replication_factor' : 1} + WITH REPLICATION = {'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1} """) self.session.execute("USE ratetests") diff --git a/tests/integration/standard/test_shard_aware.py b/tests/integration/standard/test_shard_aware.py index d1f3e27abd..4a6c7887d8 100644 --- a/tests/integration/standard/test_shard_aware.py +++ b/tests/integration/standard/test_shard_aware.py @@ -89,7 +89,7 @@ def create_ks_and_cf(self): self.session.execute( """ CREATE KEYSPACE preparedtests - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '3'} AND tablets = {'enabled': false} """) self.session.execute("USE preparedtests") @@ -174,6 +174,8 @@ def test_all_tracing_coming_one_shard(self): using the traces to validate that all the action been executed on the the same shard. this test is using prepared SELECT statements for this validation + + Requires tablets to be disabled to ensure shard consistency. """ self.create_ks_and_cf() diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index d969140339..45e8a807ea 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -9,7 +9,7 @@ def setup_module(): - use_cluster('tablets', [3], start=True) + use_cluster('tablets', [3], start=True, set_keyspace=False) class TestTabletsIntegration: diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index e608a9610b..11888adda4 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -94,7 +94,7 @@ def test_can_insert_unprepared_registered_udts(self): # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_unprepared_registered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_unprepared_registered2") s.execute("CREATE TYPE user (state text, is_cool boolean)") @@ -124,14 +124,14 @@ def test_can_register_udt_before_connecting(self): s.execute(""" CREATE KEYSPACE udt_test_register_before_connecting - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.execute("CREATE TYPE udt_test_register_before_connecting.user (age int, name text)") s.execute("CREATE TABLE udt_test_register_before_connecting.mytable (a int PRIMARY KEY, b frozen)") s.execute(""" CREATE KEYSPACE udt_test_register_before_connecting2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.execute("CREATE TYPE udt_test_register_before_connecting2.user (state text, is_cool boolean)") s.execute("CREATE TABLE udt_test_register_before_connecting2.mytable (a int PRIMARY KEY, b frozen)") @@ -147,7 +147,7 @@ def test_can_register_udt_before_connecting(self): c.register_user_type("udt_test_register_before_connecting2", "user", User2) s = c.connect(wait_for_all_pools=True) - c.control_connection.wait_for_schema_agreement() + s.wait_for_schema_agreement() s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0") @@ -193,7 +193,7 @@ def test_can_insert_prepared_unregistered_udts(self): # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_prepared_unregistered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_prepared_unregistered2") s.execute("CREATE TYPE user (state text, is_cool boolean)") @@ -240,7 +240,7 @@ def test_can_insert_prepared_registered_udts(self): # use the same UDT name in a different keyspace s.execute(""" CREATE KEYSPACE udt_test_prepared_registered2 - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + WITH replication = { 'class' : 'NetworkTopologyStrategy', 'replication_factor': '1' } """) s.set_keyspace("udt_test_prepared_registered2") s.execute("CREATE TYPE user (state text, is_cool boolean)") diff --git a/tests/integration/standard/test_use_keyspace.py b/tests/integration/standard/test_use_keyspace.py index 80e7cfe5f3..9eb3f5be36 100644 --- a/tests/integration/standard/test_use_keyspace.py +++ b/tests/integration/standard/test_use_keyspace.py @@ -65,7 +65,7 @@ def patched_set_keyspace_blocking(*args, **kwargs): return original_set_keyspace_blocking(*args, **kwargs) with patch.object(Connection, "set_keyspace_blocking", patched_set_keyspace_blocking): - self.session.execute("CREATE KEYSPACE test_set_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + self.session.execute("CREATE KEYSPACE test_set_keyspace WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}") self.session.execute("CREATE TABLE test_set_keyspace.set_keyspace_slow_connection(pk int, PRIMARY KEY(pk))") session2 = self.cluster.connect() diff --git a/tests/unit/advanced/test_metadata.py b/tests/unit/advanced/test_metadata.py index 5ccfa5e477..d68a87961d 100644 --- a/tests/unit/advanced/test_metadata.py +++ b/tests/unit/advanced/test_metadata.py @@ -34,8 +34,8 @@ def _create_vertex_metadata(self, label_name='label'): def _create_keyspace_metadata(self, graph_engine): return KeyspaceMetadata( - 'keyspace', True, 'org.apache.cassandra.locator.SimpleStrategy', - {'replication_factor': 1}, graph_engine=graph_engine) + 'keyspace', True, 'org.apache.cassandra.locator.NetworkTopologyStrategy', + {'dc1': 1}, graph_engine=graph_engine) def _create_table_metadata(self, with_vertex=False, with_edge=False): tm = TableMetadataDSE68('keyspace', 'table') diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..3d55bc1860 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -13,16 +13,19 @@ # limitations under the License. import unittest +from concurrent.futures import Future import logging import socket +from types import SimpleNamespace from unittest.mock import patch, Mock import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion -from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ +from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, ControlConnectionQueryFallback, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import ConnectionBusy, ConnectionException from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -184,6 +187,52 @@ def test_port_range(self): with pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port=invalid_port) + def test_control_connection_query_fallback_modes(self): + assert Cluster().allow_control_connection_query_fallback is ControlConnectionQueryFallback.Disabled + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=False) + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=True) + assert ( + Cluster(allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback) + .allow_control_connection_query_fallback + is ControlConnectionQueryFallback.Fallback + ) + assert Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation + ).allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation + + def test_control_connection_query_fallback_no_node_pool_mode_skips_pool_creation(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + with patch.object(Session, "add_or_renew_pool") as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_not_called() + assert session._initial_connect_futures == set() + assert session._pools == {} + assert session.update_created_pools() == set() + + def test_control_connection_query_fallback_fallback_tolerates_empty_initial_pools(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + future = Future() + future.set_result(False) + + with patch.object(Session, "add_or_renew_pool", return_value=future) as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + assert session._initial_connect_futures == {future} + assert session._pools == {} + def test_compression_autodisabled_without_libraries(self): with patch.dict('cassandra.cluster.locally_supported_compressions', {}, clear=True): with patch('cassandra.cluster.log') as patched_logger: @@ -247,11 +296,123 @@ def test_event_delay_timing(self, *_): class SessionTest(unittest.TestCase): + class FakeTime(object): + + def __init__(self): + self.clock = 0 + + def time(self): + return self.clock + + def sleep(self, amount): + self.clock += amount + + class MockPool(object): + + def __init__(self, host, connection): + self.host = host + self.host_distance = HostDistance.LOCAL + self.is_shutdown = False + self.connection = connection + + def _get_connection_for_routing_key(self): + return self.connection + + class MockSchemaVersionFuture(object): + + def __init__(self, outcome, auto_complete=True): + self._outcome = outcome + self._auto_complete = auto_complete + self._delivered = False + self._callback_state = None + self._col_names = ("schema_version",) + self._col_types = None + self.has_more_pages = False + self._continuous_paging_session = None + + def _deliver(self): + if self._delivered or self._callback_state is None: + return + + self._delivered = True + callback, errback, callback_args, callback_kwargs, errback_args, errback_kwargs = self._callback_state + if isinstance(self._outcome, Exception): + errback(self._outcome, *errback_args, **errback_kwargs) + else: + row = SimpleNamespace(schema_version=self._outcome) + callback([row], *callback_args, **callback_kwargs) + + def add_callbacks(self, callback, errback, + callback_args=(), callback_kwargs=None, + errback_args=(), errback_kwargs=None): + self._callback_state = ( + callback, + errback, + callback_args, + callback_kwargs or {}, + errback_args, + errback_kwargs or {}, + ) + if self._auto_complete: + self._deliver() + return self + + def complete(self): + self._deliver() + + def result(self): + if isinstance(self._outcome, Exception): + raise self._outcome + return ResultSet(self, [SimpleNamespace(schema_version=self._outcome)]) + def setUp(self): if connection_class is None: raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + def _mock_schema_future(self, outcome): + return self.MockSchemaVersionFuture(outcome) + + def _host_query_count(self, session, target_host): + return sum(1 for call in session.execute_async.call_args_list if call.kwargs.get('host') is target_host) + + def _new_schema_agreement_session(self, schema_versions, distances=None): + hosts = [] + connections = {} + distance_map = {} + if distances is None: + distances = [HostDistance.LOCAL] * len(schema_versions) + + for index, schema_version in enumerate(schema_versions): + host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + hosts.append(host) + distance_map[host] = distances[index] + + cluster = Cluster(protocol_version=4) + for host in hosts: + cluster.metadata.add_or_return_host(host) + + session = Session(cluster, hosts) + session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL)) + session._pools = {} + for host, schema_version in zip(hosts, schema_versions): + connection = Mock(endpoint=host.endpoint) + connection.future_outcomes = [schema_version] + session._pools[host] = self.MockPool(host, connection) + connections[host] = connection + + def execute_async(query, parameters=None, trace=False, + custom_payload=None, execution_profile=None, + paging_state=None, timeout=None, host=None, execute_as=None): + connection = connections[host] + outcome = connection.future_outcomes.pop(0) if len(connection.future_outcomes) > 1 else connection.future_outcomes[0] + return self._mock_schema_future(outcome) + + session.execute_async = Mock(side_effect=execute_async) + + return session, hosts, connections + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -339,6 +500,130 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + @mock_session_pools + def test_wait_for_schema_agreement_default_scope_queries_all_connected_hosts(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1) + + for host in hosts: + assert self._host_query_count(session, host) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_): + session, hosts, connections = self._new_schema_agreement_session(["a", "b"]) + clock = self.FakeTime() + connections[hosts[1]].future_outcomes = ["b", "a"] + + with patch('cassandra.cluster.time', new=clock): + assert session.wait_for_schema_agreement(wait_time=1) + for host in hosts: + assert self._host_query_count(session, host) == 2 + assert clock.clock == 0.2 + + @mock_session_pools + def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_): + session, hosts, connections = self._new_schema_agreement_session(["a", "a"]) + clock = self.FakeTime() + connections[hosts[1]].future_outcomes = [ + ConnectionBusy("connection overloaded"), + "a"] + + with patch('cassandra.cluster.time', new=clock): + assert session.wait_for_schema_agreement(wait_time=1) + for host in hosts: + assert self._host_query_count(session, host) == 2 + assert clock.clock == 0.2 + + @mock_session_pools + def test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_): + session, hosts, _ = self._new_schema_agreement_session(["a"]) + + unconnected_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) + unconnected_host.set_up() + session.cluster.metadata.add_or_return_host(unconnected_host) + + assert session.wait_for_schema_agreement(wait_time=1) + assert self._host_query_count(session, hosts[0]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_queries_hosts_in_order(self, *_): + session, hosts, _ = self._new_schema_agreement_session(["a"] * 11) + + assert session.wait_for_schema_agreement(wait_time=1) + assert [call.kwargs['host'] for call in session.execute_async.call_args_list] == list(hosts) + + @mock_session_pools + def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.RACK) + + assert self._host_query_count(session, hosts[0]) == 1 + assert self._host_query_count(session, hosts[1]) == 0 + assert self._host_query_count(session, hosts[2]) == 0 + + @mock_session_pools + def test_wait_for_schema_agreement_cluster_scope_skips_ignored_hosts(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.IGNORED, HostDistance.LOCAL]) + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.CLUSTER) + + assert self._host_query_count(session, hosts[0]) == 0 + assert self._host_query_count(session, hosts[1]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_cluster_scope_excludes_hosts_with_unknown_status(self, *_): + session, hosts, _ = self._new_schema_agreement_session( + ["a", "a"], + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL]) + + hosts[0].is_up = None + + assert session.wait_for_schema_agreement(wait_time=1, scope=SchemaAgreementScope.CLUSTER) + + assert self._host_query_count(session, hosts[0]) == 0 + assert self._host_query_count(session, hosts[1]) == 1 + + @mock_session_pools + def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): + session, _, _ = self._new_schema_agreement_session(["a"]) + + with pytest.raises(ValueError): + session.wait_for_schema_agreement(wait_time=1, scope='planet') + + @mock_session_pools + def test_set_keyspace_for_all_pools_reports_all_errors(self, *_): + cluster = Cluster() + session = Session( + cluster, + [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())], + ) + + pool1 = Mock(host='host1') + pool2 = Mock(host='host2') + keyspace_error = ConnectionException("boom") + + pool1._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool1, [keyspace_error]) + ) + pool2._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool2, []) + ) + session._pools = {'host1': pool1, 'host2': pool2} + + callback = Mock() + session._set_keyspace_for_all_pools('ks', callback) + + callback.assert_called_once() + assert callback.call_args.args[0] == {'host1': [keyspace_error]} + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2fa7c71196..cf4607fbed 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -21,7 +21,7 @@ from cassandra import OperationTimedOut from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, - locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + locally_supported_compressions, ConnectionHeartbeat, HeartbeatFuture, _Frame, Timer, TimerManager, ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, @@ -463,6 +463,31 @@ def test_no_req_ids(self, *args): holder.return_connection.assert_has_calls( [call(max_connection)] * get_holders.call_count) + def test_heartbeat_future_releases_request_id_when_send_fails(self, *args): + connection = Connection(DefaultEndPoint('1.2.3.4')) + connection.push = Mock(side_effect=ConnectionException("write failed")) + owner = Mock() + initial_in_flight = connection.in_flight + initial_request_ids = len(connection.request_ids) + + # HostConnection.return_connection releases the heartbeat's in-flight slot. + def return_connection(conn): + with conn.lock: + conn.in_flight -= 1 + + owner.return_connection.side_effect = return_connection + + future = HeartbeatFuture(connection, owner) + + with pytest.raises(ConnectionException): + future.wait(0) + + owner.return_connection(connection) + + assert connection.in_flight == initial_in_flight + assert len(connection.request_ids) == initial_request_ids + assert not connection._requests + def test_unexpected_response(self, *args): request_id = 999 diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..fd62323f33 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,7 +15,7 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call, patch from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS @@ -210,16 +210,27 @@ def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing """ - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() # the control connection should not have slept at all assert self.time.clock == 0 + @patch('cassandra.cluster.warn') + def test_wait_for_schema_agreement_warns_about_deprecation(self, mocked_warn): + assert self.control_connection.wait_for_schema_agreement() + + mocked_warn.assert_called_once() + warning_args, warning_kwargs = mocked_warn.call_args + assert 'ControlConnection.wait_for_schema_agreement is deprecated' in str(warning_args[0]) + assert 'Use Session.wait_for_schema_agreement instead.' in str(warning_args[0]) + assert warning_args[1] is DeprecationWarning + assert warning_kwargs['stacklevel'] == 2 + def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): """ wait_for_schema_agreement uses preloaded results if given for shared table queries """ preloaded_results = self._matching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 # the connection should not have made any queries if given preloaded results @@ -230,7 +241,7 @@ def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_ wait_for_schema_agreement requery if schema does not match using preloaded results """ preloaded_results = self._nonmatching_schema_preloaded_results - assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) + assert self.control_connection._wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all assert self.time.clock == 0 assert self.connection.wait_for_responses.call_count == 1 @@ -241,7 +252,7 @@ def test_wait_for_schema_agreement_fails(self): """ # change the schema version on one node self.connection.peer_results[1][1][2] = 'b' - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() # the control connection should have slept until it hit the limit assert self.time.clock >= self.cluster.max_schema_agreement_wait @@ -262,7 +273,7 @@ def test_wait_for_schema_agreement_skipping(self): self.connection.peer_results[1][1][3] = 'c' self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 def test_wait_for_schema_agreement_rpc_lookup(self): @@ -279,12 +290,12 @@ def test_wait_for_schema_agreement_rpc_lookup(self): # even though the new host has a different schema version, it's # marked as down, so the control connection shouldn't care - assert self.control_connection.wait_for_schema_agreement() + assert self.control_connection._wait_for_schema_agreement() assert self.time.clock == 0 # but once we mark it up, the control connection will care host.is_up = True - assert not self.control_connection.wait_for_schema_agreement() + assert not self.control_connection._wait_for_schema_agreement() assert self.time.clock >= self.cluster.max_schema_agreement_wait @@ -299,7 +310,7 @@ def test_wait_for_schema_agreement_none_timeout(self): status_event_refresh_window=0) cc._connection = self.connection cc._time = self.time - assert cc.wait_for_schema_agreement() + assert cc._wait_for_schema_agreement() def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() @@ -441,7 +452,8 @@ def bad_wait_for_responses(*args, **kwargs): self.control_connection.refresh_node_list_and_token_map() self.cluster.executor.submit.assert_called_with(self.control_connection._reconnect) - def test_refresh_schema_timeout(self): + @patch('cassandra.cluster.warn') + def test_refresh_schema_timeout(self, mocked_warn): def bad_wait_for_responses(*args, **kwargs): self.time.sleep(kwargs['timeout']) @@ -451,6 +463,7 @@ def bad_wait_for_responses(*args, **kwargs): self.control_connection.refresh_schema() assert self.connection.wait_for_responses.call_count == self.cluster.max_schema_agreement_wait / self.control_connection._timeout assert self.connection.wait_for_responses.call_args[1]['timeout'] == self.control_connection._timeout + mocked_warn.assert_not_called() def test_handle_topology_change(self): event = { diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index dcbb840447..15cf283777 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -25,7 +25,7 @@ from cassandra.marshal import uint16_unpack, uint16_pack from cassandra.metadata import (Murmur3Token, MD5Token, BytesToken, ReplicationStrategy, - NetworkTopologyStrategy, SimpleStrategy, + NetworkTopologyStrategy, LocalStrategy, protect_name, protect_names, protect_value, is_valid_name, UserType, KeyspaceMetadata, get_schema_parser, @@ -96,14 +96,14 @@ def test_replication_strategy(self): assert rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors == NetworkTopologyStrategy(fake_options_map).dc_replication_factors fake_options_map = {'options': 'map'} - assert rs.create('SimpleStrategy', fake_options_map) is None + assert rs.create('NetworkTopologyStrategy', fake_options_map) is None fake_options_map = {'options': 'map'} assert isinstance(rs.create('LocalStrategy', fake_options_map), LocalStrategy) - fake_options_map = {'options': 'map', 'replication_factor': 3} - assert isinstance(rs.create('SimpleStrategy', fake_options_map), SimpleStrategy) - assert rs.create('SimpleStrategy', fake_options_map).replication_factor == SimpleStrategy(fake_options_map).replication_factor + fake_options_map = {'dc1': 3} + assert isinstance(rs.create('NetworkTopologyStrategy', fake_options_map), NetworkTopologyStrategy) + assert rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors == NetworkTopologyStrategy(fake_options_map).dc_replication_factors assert rs.create('xxxxxxxx', fake_options_map) == _UnknownStrategy('xxxxxxxx', fake_options_map) @@ -113,38 +113,38 @@ def test_replication_strategy(self): rs.export_for_schema() def test_simple_replication_type_parsing(self): - """ Test equality between passing numeric and string replication factor for simple strategy """ + """ Test equality between passing numeric and string replication factor for NTS """ rs = ReplicationStrategy() - simple_int = rs.create('SimpleStrategy', {'replication_factor': 3}) - simple_str = rs.create('SimpleStrategy', {'replication_factor': '3'}) + nts_int = rs.create('NetworkTopologyStrategy', {'dc1': 3}) + nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3'}) - assert simple_int.export_for_schema() == simple_str.export_for_schema() - assert simple_int == simple_str + assert nts_int.export_for_schema() == nts_str.export_for_schema() + assert nts_int == nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, datacenter='dc1', rack='rack1', host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - assert simple_int.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) + assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_transient_replication_parsing(self): - """ Test that we can PARSE a transient replication factor for SimpleStrategy """ + """ Test that we can PARSE a transient replication factor for NetworkTopologyStrategy """ rs = ReplicationStrategy() - simple_transient = rs.create('SimpleStrategy', {'replication_factor': '3/1'}) - assert simple_transient.replication_factor_info == ReplicationFactor(3, 1) - assert simple_transient.replication_factor == 2 - assert "'replication_factor': '3/1'" in simple_transient.export_for_schema() + nts_transient = rs.create('NetworkTopologyStrategy', {'dc1': '3/1'}) + assert nts_transient.dc_replication_factors_info['dc1'] == ReplicationFactor(3, 1) + assert nts_transient.dc_replication_factors['dc1'] == 2 + assert "'dc1': '3/1'" in nts_transient.export_for_schema() - simple_str = rs.create('SimpleStrategy', {'replication_factor': '2'}) - assert simple_transient != simple_str + nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '2'}) + assert nts_transient != nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, datacenter='dc1', rack='rack1', host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) + assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_replication_parsing(self): """ Test equality between passing numeric and string replication factor for NTS """ @@ -318,9 +318,9 @@ def test_nts_export_for_schema(self): assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() def test_simple_strategy_make_token_replica_map(self): - host1 = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host2 = Host('2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host3 = Host('3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host1 = Host('1', SimpleConvictionPolicy, datacenter='dc1', rack='rack1', host_id=uuid.uuid4()) + host2 = Host('2', SimpleConvictionPolicy, datacenter='dc1', rack='rack1', host_id=uuid.uuid4()) + host3 = Host('3', SimpleConvictionPolicy, datacenter='dc1', rack='rack1', host_id=uuid.uuid4()) token_to_host_owner = { MD5Token(0): host1, MD5Token(100): host2, @@ -328,23 +328,23 @@ def test_simple_strategy_make_token_replica_map(self): } ring = [MD5Token(0), MD5Token(100), MD5Token(200)] - rf1_replicas = SimpleStrategy({'replication_factor': '1'}).make_token_replica_map(token_to_host_owner, ring) + rf1_replicas = NetworkTopologyStrategy({'dc1': '1'}).make_token_replica_map(token_to_host_owner, ring) assertCountEqual(rf1_replicas[MD5Token(0)], [host1]) assertCountEqual(rf1_replicas[MD5Token(100)], [host2]) assertCountEqual(rf1_replicas[MD5Token(200)], [host3]) - rf2_replicas = SimpleStrategy({'replication_factor': '2'}).make_token_replica_map(token_to_host_owner, ring) + rf2_replicas = NetworkTopologyStrategy({'dc1': '2'}).make_token_replica_map(token_to_host_owner, ring) assertCountEqual(rf2_replicas[MD5Token(0)], [host1, host2]) assertCountEqual(rf2_replicas[MD5Token(100)], [host2, host3]) assertCountEqual(rf2_replicas[MD5Token(200)], [host3, host1]) - rf3_replicas = SimpleStrategy({'replication_factor': '3'}).make_token_replica_map(token_to_host_owner, ring) + rf3_replicas = NetworkTopologyStrategy({'dc1': '3'}).make_token_replica_map(token_to_host_owner, ring) assertCountEqual(rf3_replicas[MD5Token(0)], [host1, host2, host3]) assertCountEqual(rf3_replicas[MD5Token(100)], [host2, host3, host1]) assertCountEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) def test_ss_equals(self): - assert SimpleStrategy({'replication_factor': '1'}) != NetworkTopologyStrategy({'dc1': 2}) + assert NetworkTopologyStrategy({'dc1': '1'}) != NetworkTopologyStrategy({'dc1': 2}) class NameEscapingTest(unittest.TestCase): @@ -409,9 +409,9 @@ def test_is_valid_name(self): class GetReplicasTest(unittest.TestCase): def _get_replicas(self, token_klass): tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(len(tokens))] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy, datacenter="dc1", rack="rack1", host_id=uuid.uuid4()) for i in range(len(tokens))] token_to_primary_replica = dict(zip(tokens, hosts)) - keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) + keyspace = KeyspaceMetadata("ks", True, "NetworkTopologyStrategy", {"dc1": "1"}) metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) token_map = TokenMap(token_klass, token_to_primary_replica, tokens, metadata) @@ -524,13 +524,13 @@ class KeyspaceMetadataTest(unittest.TestCase): def test_export_as_string_user_types(self): keyspace_name = 'test' - keyspace = KeyspaceMetadata(keyspace_name, True, 'SimpleStrategy', dict(replication_factor=3)) + keyspace = KeyspaceMetadata(keyspace_name, True, 'NetworkTopologyStrategy', dict(dc1=3)) keyspace.user_types['a'] = UserType(keyspace_name, 'a', ['one', 'two'], ['c', 'int']) keyspace.user_types['b'] = UserType(keyspace_name, 'b', ['one', 'two', 'three'], ['d', 'int', 'a']) keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'], ['int']) keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'], ['c']) - assert """CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true; + assert """CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': '3'} AND durable_writes = true; CREATE TYPE test.c ( one int @@ -662,7 +662,7 @@ class UnicodeIdentifiersTests(unittest.TestCase): name = b'\'_-()"\xc2\xac'.decode('utf-8') def test_keyspace_name(self): - km = KeyspaceMetadata(self.name, False, 'SimpleStrategy', {'replication_factor': 1}) + km = KeyspaceMetadata(self.name, False, 'NetworkTopologyStrategy', {'dc1': 1}) km.export_as_string() def test_table_name(self): diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..9673b0d634 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -19,7 +19,7 @@ from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion, ControlConnectionQueryFallback from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, @@ -41,6 +41,7 @@ def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] s.cluster.control_connection._tablets_routing_v1 = False + s.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Disabled return s def make_pool(self): @@ -49,6 +50,22 @@ def make_pool(self): pool.borrow_connection.return_value = [Mock(), Mock()] return pool + def make_control_connection(self): + connection = Mock(spec=Connection) + connection.endpoint = 'control-host' + connection.lock = RLock() + connection.in_flight = 0 + connection.max_request_id = 100 + connection.request_ids = deque() + connection._requests = {} + connection.orphaned_request_ids = set() + connection.orphaned_threshold = 75 + connection.orphaned_threshold_reached = False + connection.is_control_connection = True + connection.get_request_id.return_value = 7 + connection.send_msg.return_value = 128 + return connection + def make_session(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] @@ -391,6 +408,268 @@ def test_all_pools_shutdown(self): with pytest.raises(NoHostAvailable): rf.result() + def test_control_connection_fallback_disabled_by_default(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + rf = self.make_response_future(session) + rf.send_request() + + connection.send_msg.assert_not_called() + with pytest.raises(NoHostAvailable): + rf.result() + + def test_control_connection_fallback_updates_connection_keyspace(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + + def set_keyspace_for_all_pools(keyspace, callback): + session.keyspace = keyspace + callback({}) + + session._set_keyspace_for_all_pools.side_effect = set_keyspace_for_all_pools + + connection = self.make_control_connection() + connection.keyspace = 'oldks' + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, new_keyspace='newks') + connection.send_msg.call_args[1]['cb'](result) + + assert connection.keyspace == 'newks' + assert session.keyspace == 'newks' + assert rf.result().current_rows == [] + + def test_control_connection_fallback_when_no_usable_pools(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.SkipPoolCreation + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + connection.send_msg.assert_called_once_with( + rf.message, 7, cb=ANY, encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, result_metadata=[]) + assert connection.in_flight == 1 + assert rf.attempted_hosts == [control_host] + + cb = connection.send_msg.call_args[1]['cb'] + expected_result = (object(), object()) + cb(self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_retries_after_server_error(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + connection.get_request_id.side_effect = [7, 8] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + first_response = Mock(spec=ServerError, info={}) + first_response.summary = 'boom' + first_response.to_exception.return_value = first_response + connection.send_msg.call_args[1]['cb'](first_response) + + rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, control_host) + + # The retry decision must come from the future state, not the live connection reference. + rf._connection = Mock(is_control_connection=False) + + rf._retry_task(False, control_host) + + assert connection.send_msg.call_count == 2 + assert connection.send_msg.call_args_list[1][0][0] is rf.message + assert connection.send_msg.call_args_list[1][0][1] == 8 + assert rf.attempted_hosts == [control_host, control_host] + + expected_result = (object(), object()) + connection.send_msg.call_args_list[1][1]['cb']( + self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_fetches_next_page(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + connection.get_request_id.side_effect = [7, 8] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + first_response = self.make_mock_response(['col'], [(1,)]) + first_response.paging_state = b'next-page' + connection.send_msg.call_args[1]['cb'](first_response) + + assert rf.result().current_rows == [(['col'], [(1,)])] + assert rf.has_more_pages + + rf.start_fetching_next_page() + + assert connection.send_msg.call_count == 2 + assert connection.send_msg.call_args_list[1][0][0] is rf.message + assert connection.send_msg.call_args_list[1][0][1] == 8 + assert rf.message.paging_state == b'next-page' + + second_response = self.make_mock_response(['col'], [(2,)]) + connection.send_msg.call_args_list[1][1]['cb'](second_response) + + assert connection.in_flight == 0 + assert rf.result().current_rows == [(['col'], [(2,)])] + + def test_control_connection_fallback_reprepares_prepared_statement(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster.protocol_version = ProtocolVersion.V4 + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + session.submit.side_effect = lambda fn, *args, **kwargs: fn(*args, **kwargs) + + query_id = b'a' * 16 + prepared_statement = Mock( + query_id=query_id, + query_string="SELECT * FROM foobar", + keyspace="FooKeyspace", + result_metadata=[], + result_metadata_id=None) + session.cluster._prepared_statements = {query_id: prepared_statement} + + connection = self.make_control_connection() + connection.keyspace = "FooKeyspace" + connection.get_request_id.side_effect = [7, 8, 9] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + rf.prepared_statement = prepared_statement + assert rf.send_request() + + missing = Mock(spec=PreparedQueryNotFound, info=query_id) + connection.send_msg.call_args_list[0][1]['cb'](missing) + + assert connection.send_msg.call_count == 2 + prepare_message = connection.send_msg.call_args_list[1][0][0] + assert isinstance(prepare_message, PrepareMessage) + assert prepare_message.query == "SELECT * FROM foobar" + assert connection.send_msg.call_args_list[1][0][1] == 8 + + prepared_response = Mock( + spec=ResultMessage, + kind=RESULT_KIND_PREPARED, + query_id=query_id, + column_metadata=[], + result_metadata_id=None) + connection.send_msg.call_args_list[1][1]['cb'](prepared_response) + + assert connection.send_msg.call_count == 3 + assert connection.send_msg.call_args_list[2][0][0] is rf.message + assert connection.send_msg.call_args_list[2][0][1] == 9 + + expected_result = (['col'], [(1,)]) + connection.send_msg.call_args_list[2][1]['cb']( + self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_not_used_when_pool_can_serve(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + pool = Mock(is_shutdown=False) + pool.borrow_connection.side_effect = NoConnectionsAvailable() + session._pools = {'ip1': pool} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + rf = self.make_response_future(session) + rf.send_request() + + connection.send_msg.assert_not_called() + with pytest.raises(NoHostAvailable): + rf.result() + + def test_control_connection_fallback_orphans_stream_on_timeout(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + def send_msg(message, request_id, cb, **kwargs): + connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata')) + return 128 + + connection.send_msg.side_effect = send_msg + + rf = self.make_response_future(session) + rf.send_request() + rf._on_timeout() + + assert 7 in connection.orphaned_request_ids + assert connection.in_flight == 1 + with pytest.raises(OperationTimedOut): + rf.result() + + def test_control_connection_fallback_timeout_without_metadata_host_uses_connection_endpoint(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [] + session._pools = {} + session.cluster.get_control_connection_host.return_value = None + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + def send_msg(message, request_id, cb, **kwargs): + connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata')) + return 128 + + connection.send_msg.side_effect = send_msg + + rf = self.make_response_future(session) + assert rf.send_request() + rf._on_timeout() + + with pytest.raises(OperationTimedOut) as exc_info: + rf.result() + + assert exc_info.value.errors == { + 'control-host': 'Client request timeout. See Session.execute[_async](timeout)' + } + def test_first_pool_shutdown(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] diff --git a/tests/unit/test_session_schema_agreement.py b/tests/unit/test_session_schema_agreement.py new file mode 100644 index 0000000000..ffad687fcc --- /dev/null +++ b/tests/unit/test_session_schema_agreement.py @@ -0,0 +1,204 @@ +from datetime import timedelta +from types import SimpleNamespace +from unittest.mock import Mock +import uuid + +import pytest + +import cassandra.cluster as cluster_module +from cassandra.connection import ConnectionBusy +from cassandra.cluster import ControlConnection, Session, ResultSet +from cassandra.policies import HostDistance, SimpleConvictionPolicy +from cassandra.pool import Host +from cassandra.util import maybe_add_timeout_to_query + + +class FakeTime: + def __init__(self): + self.clock = 0 + + def time(self): + return self.clock + + def sleep(self, amount): + self.clock += amount + + +class MockPool: + def __init__(self, host): + self.host = host + self.is_shutdown = False + + +class MockSchemaVersionFuture: + def __init__(self, outcome, auto_complete=True): + self._outcome = outcome + self._auto_complete = auto_complete + self._delivered = False + self._callback_state = None + self._col_names = ("schema_version",) + self._col_types = None + self.has_more_pages = False + self._continuous_paging_session = None + + def _deliver(self): + if self._delivered or self._callback_state is None: + return + + self._delivered = True + callback, errback, callback_args, callback_kwargs, errback_args, errback_kwargs = self._callback_state + if isinstance(self._outcome, Exception): + errback(self._outcome, *errback_args, **errback_kwargs) + else: + row = SimpleNamespace(schema_version=self._outcome) + callback([row], *callback_args, **callback_kwargs) + + def add_callbacks(self, callback, errback, + callback_args=(), callback_kwargs=None, + errback_args=(), errback_kwargs=None): + self._callback_state = ( + callback, + errback, + callback_args, + callback_kwargs or {}, + errback_args, + errback_kwargs or {}, + ) + if self._auto_complete: + self._deliver() + return self + + def complete(self): + self._deliver() + + def result(self): + if isinstance(self._outcome, Exception): + raise self._outcome + return ResultSet(self, [SimpleNamespace(schema_version=self._outcome)]) + + +def _host_query_count(session, target_host): + return sum(1 for call in session.execute_async.call_args_list if call.kwargs.get("host") is target_host) + + +def _new_session(schema_versions, distances=None, metadata_request_timeout=timedelta(seconds=2), timeout=2.0): + hosts = [] + connections = {} + distance_map = {} + + if distances is None: + distances = [HostDistance.LOCAL] * len(schema_versions) + + for index, schema_version in enumerate(schema_versions): + host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + hosts.append(host) + distance_map[host] = distances[index] + + cluster = SimpleNamespace( + max_schema_agreement_wait=10, + control_connection=SimpleNamespace( + _timeout=timeout, + _metadata_request_timeout=metadata_request_timeout, + ), + ) + + session = Session.__new__(Session) + session.cluster = cluster + session._profile_manager = SimpleNamespace(distance=lambda host: distance_map.get(host, HostDistance.LOCAL)) + session._pools = {} + session.is_shutdown = False + + for host, schema_version in zip(hosts, schema_versions): + connection = Mock(endpoint=host.endpoint) + connection.future_outcomes = [schema_version] + session._pools[host] = MockPool(host) + connections[host] = connection + + def execute_async(query, parameters=None, trace=False, + custom_payload=None, execution_profile=None, + paging_state=None, timeout=None, host=None, execute_as=None): + connection = connections[host] + outcome = connection.future_outcomes.pop(0) if len(connection.future_outcomes) > 1 else connection.future_outcomes[0] + return MockSchemaVersionFuture(outcome) + + session.execute_async = Mock(side_effect=execute_async) + + return session, hosts, connections + + +def test_wait_for_schema_agreement_retries_with_module_time(monkeypatch): + session, hosts, connections = _new_session(["a", "b"]) + clock = FakeTime() + monkeypatch.setattr(cluster_module, "time", clock) + connections[hosts[1]].future_outcomes = ["b", "a"] + + assert session.wait_for_schema_agreement(wait_time=1) + assert clock.clock == pytest.approx(0.2) + for host in hosts: + assert _host_query_count(session, host) == 2 + + +@pytest.mark.parametrize("wait_time", [0, -1]) +def test_wait_for_schema_agreement_rejects_non_positive_wait_time(wait_time): + session, _, _ = _new_session(["a"]) + + with pytest.raises(ValueError, match="wait_time must be greater than 0"): + session.wait_for_schema_agreement(wait_time=wait_time) + + assert session.execute_async.call_count == 0 + + +def test_wait_for_schema_agreement_returns_false_when_no_hosts_match_scope(monkeypatch): + session, _, _ = _new_session(["a"], distances=[HostDistance.IGNORED]) + clock = FakeTime() + monkeypatch.setattr(cluster_module, "time", clock) + + assert session.wait_for_schema_agreement(wait_time=1) is False + assert session.execute_async.call_count == 0 + assert clock.clock == pytest.approx(1.0) + + +def test_wait_for_schema_agreement_uses_host_targeted_session_queries(): + session, hosts, _ = _new_session(["a", "a"]) + + assert session.wait_for_schema_agreement(wait_time=0.1) + + expected_query = maybe_add_timeout_to_query( + ControlConnection._SELECT_SCHEMA_LOCAL, + timedelta(seconds=2), + ) + assert session.execute_async.call_count == 2 + assert [call.args[0] for call in session.execute_async.call_args_list] == [expected_query, expected_query] + assert [call.kwargs["host"] for call in session.execute_async.call_args_list] == hosts + for call in session.execute_async.call_args_list: + assert 0 < call.kwargs["timeout"] <= 0.1 + + +def test_wait_for_schema_agreement_retries_after_host_targeted_query_error(monkeypatch): + session, hosts, connections = _new_session(["a", "a"]) + clock = FakeTime() + monkeypatch.setattr(cluster_module, "time", clock) + connections[hosts[1]].future_outcomes = [ConnectionBusy("connection overloaded"), "a"] + + assert session.wait_for_schema_agreement(wait_time=1) + assert clock.clock == pytest.approx(0.2) + for host in hosts: + assert _host_query_count(session, host) == 2 + + +def test_wait_for_schema_agreement_queries_hosts_in_order_under_one_deadline(monkeypatch): + session, hosts, _ = _new_session(["a", "a", "a"]) + clock = FakeTime() + monkeypatch.setattr(cluster_module, "time", clock) + + def execute_async(query, parameters=None, trace=False, + custom_payload=None, execution_profile=None, + paging_state=None, timeout=None, host=None, execute_as=None): + clock.sleep(0.01) + return MockSchemaVersionFuture("a") + + session.execute_async = Mock(side_effect=execute_async) + + assert session.wait_for_schema_agreement(wait_time=1) + assert [call.kwargs["host"] for call in session.execute_async.call_args_list] == hosts