diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 28784749c4..ac3943ef57 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -1,6 +1,6 @@
version: 2
updates:
- - package-ecosystem: "pip"
+ - package-ecosystem: "uv"
directory: "/docs"
schedule:
interval: "daily"
diff --git a/.github/workflows/build-pre-release.yml b/.github/workflows/build-pre-release.yml
index e1326b6aa5..f6473c1cc3 100644
--- a/.github/workflows/build-pre-release.yml
+++ b/.github/workflows/build-pre-release.yml
@@ -15,7 +15,7 @@ on:
jobs:
build-and-publish:
- uses: ./.github/workflows/lib-build-and-push.yml
+ uses: ./.github/workflows/lib-build.yml
with:
python-version: ${{ inputs.python-version }}
target: ${{ inputs.target }}
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index 4e73e811fa..a1a6c854c7 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -10,11 +10,12 @@ on:
jobs:
build-and-publish:
name: "Build wheels"
- uses: ./.github/workflows/lib-build-and-push.yml
- with:
- upload: false
+ uses: ./.github/workflows/lib-build.yml
- # TODO: Remove when https://github.com/pypa/gh-action-pypi-publish/issues/166 is fixed and update build-and-publish.with.upload to ${{ endsWith(github.event.ref, 'scylla') }}
+ # Publishing is a separate job (not inside the reusable workflow) because PyPI Trusted Publishing
+ # requires the *caller* workflow path in the OIDC token. A reusable workflow would embed its own
+ # path instead, causing an `invalid-publisher` error on the PyPI side.
+ # See: https://github.com/pypa/gh-action-pypi-publish/issues/166
publish:
name: "Publish wheels to PyPi"
if: ${{ endsWith(github.event.ref, 'scylla') }}
@@ -23,11 +24,11 @@ jobs:
permissions:
id-token: write
steps:
- - uses: actions/download-artifact@v4
+ - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: dist
merge-multiple: true
- - uses: pypa/gh-action-pypi-publish@release/v1
+ - uses: pypa/gh-action-pypi-publish@cef2210092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
with:
skip-existing: true
diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 3e1f1067d7..ebfe383047 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -18,6 +18,4 @@ jobs:
test-wheels-build:
name: "Test wheels building"
if: "!contains(github.event.pull_request.labels.*.name, 'disable-test-build')"
- uses: ./.github/workflows/lib-build-and-push.yml
- with:
- upload: false
\ No newline at end of file
+ uses: ./.github/workflows/lib-build.yml
\ No newline at end of file
diff --git a/.github/workflows/call_jira_sync.yml b/.github/workflows/call_jira_sync.yml
new file mode 100644
index 0000000000..0855246f48
--- /dev/null
+++ b/.github/workflows/call_jira_sync.yml
@@ -0,0 +1,18 @@
+name: Sync Jira Based on PR Events
+
+on:
+ pull_request_target:
+ types: [opened, edited, ready_for_review, review_requested, labeled, unlabeled, closed]
+
+permissions:
+ contents: read
+ pull-requests: write
+ issues: write
+
+jobs:
+ jira-sync:
+ uses: scylladb/github-automation/.github/workflows/main_pr_events_jira_sync.yml@83115dc2553dbf968e73271e97fc7aac16b8145a # main 2026-05-20
+ with:
+ caller_action: ${{ github.event.action }}
+ secrets:
+ caller_jira_auth: ${{ secrets.USER_AND_KEY_FOR_JIRA_AUTOMATION }}
diff --git a/.github/workflows/docs-pages.yaml b/.github/workflows/docs-pages.yml
similarity index 61%
rename from .github/workflows/docs-pages.yaml
rename to .github/workflows/docs-pages.yml
index 31f8dc74c5..a413e3317e 100644
--- a/.github/workflows/docs-pages.yaml
+++ b/.github/workflows/docs-pages.yml
@@ -2,6 +2,9 @@ name: "Docs / Publish"
# For more information,
# see https://sphinx-theme.scylladb.com/stable/deployment/production.html#available-workflows
+permissions:
+ contents: write
+
on:
push:
branches:
@@ -9,28 +12,36 @@ on:
- 'branch-**'
paths:
- "docs/**"
+ - ".github/workflows/docs-pages.yml"
+ - "cassandra/**"
+ - "pyproject.toml"
+ - "setup.py"
+ - "CHANGELOG.rst"
workflow_dispatch:
jobs:
release:
- runs-on: ubuntu-24.04
+ runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ github.event.repository.default_branch }}
persist-credentials: false
fetch-depth: 0
- - name: Set up Python
- uses: actions/setup-python@v5
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
- python-version: '3.10'
- - name: Set up env
- run: make -C docs setupenv
+ working-directory: docs
+ enable-cache: true
+
- name: Build docs
run: make -C docs multiversion
+
- name: Build redirects
run: make -C docs redirects
+
- name: Deploy docs to GitHub Pages
run: ./docs/_utils/deploy.sh
env:
diff --git a/.github/workflows/docs-pr.yaml b/.github/workflows/docs-pr.yaml
deleted file mode 100644
index 28a74f2e58..0000000000
--- a/.github/workflows/docs-pr.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-name: "Docs / Build PR"
-# For more information,
-# see https://sphinx-theme.scylladb.com/stable/deployment/production.html#available-workflows
-
-on:
- pull_request:
- branches:
- - master
- - 'branch-**'
- paths:
- - "docs/**"
- workflow_dispatch:
-
-jobs:
- build:
- runs-on: ubuntu-24.04
- steps:
- - name: Checkout
- uses: actions/checkout@v4
- with:
- persist-credentials: false
- fetch-depth: 0
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: '3.10'
- - name: Set up env
- run: make -C docs setupenv
- - name: Build docs
- run: make -C docs test
diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml
new file mode 100644
index 0000000000..1881c227ed
--- /dev/null
+++ b/.github/workflows/docs-pr.yml
@@ -0,0 +1,46 @@
+name: "Docs / Build PR"
+# For more information,
+# see https://sphinx-theme.scylladb.com/stable/deployment/production.html#available-workflows
+
+permissions:
+ contents: read
+
+on:
+ push:
+ branches:
+ - master
+ paths:
+ - "docs/**"
+ - ".github/workflows/docs-pr.yml"
+ - "cassandra/**"
+ - "pyproject.toml"
+ - "setup.py"
+ - "CHANGELOG.rst"
+ pull_request:
+ paths:
+ - "docs/**"
+ - ".github/workflows/docs-pr.yml"
+ - "cassandra/**"
+ - "pyproject.toml"
+ - "setup.py"
+ - "CHANGELOG.rst"
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ with:
+ persist-credentials: false
+ fetch-depth: 0
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+ with:
+ working-directory: docs
+ enable-cache: true
+
+ - name: Build docs
+ run: make -C docs test
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index bccbdc63cc..5e76d6bbb4 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -5,6 +5,18 @@ on:
branches:
- master
- 'branch-**'
+ paths-ignore:
+ - docs/*
+ - examples/*
+ - scripts/*
+ - .gitignore
+ - '*.rst'
+ - '*.ini'
+ - LICENSE
+ - .github/dependabot.yml
+ - .github/pull_request_template.md
+ - "*.md"
+ - .github/workflows/docs-*
pull_request:
paths-ignore:
- docs/*
@@ -16,6 +28,8 @@ on:
- LICENSE
- .github/dependabot.yml
- .github/pull_request_template.md
+ - "*.md"
+ - .github/workflows/docs-*
workflow_dispatch:
jobs:
@@ -23,23 +37,29 @@ jobs:
name: test ${{ matrix.event_loop_manager }} (${{ matrix.python-version }})
if: "!contains(github.event.pull_request.labels.*.name, 'disable-integration-tests')"
runs-on: ubuntu-24.04
+ env:
+ SCYLLA_VERSION: release:2026.1
strategy:
fail-fast: false
matrix:
java-version: [8]
- python-version: ["3.9", "3.11", "3.12", "3.13"]
+ python-version: ["3.11", "3.12", "3.13", "3.14", "3.14t"]
event_loop_manager: ["libev", "asyncio", "asyncore"]
exclude:
- python-version: "3.12"
event_loop_manager: "asyncore"
- python-version: "3.13"
event_loop_manager: "asyncore"
+ - python-version: "3.14"
+ event_loop_manager: "asyncore"
+ - python-version: "3.14t"
+ event_loop_manager: "asyncore"
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Set up JDK ${{ matrix.java-version }}
- uses: actions/setup-java@v4
+ uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5.2.0
with:
java-version: ${{ matrix.java-version }}
distribution: 'adopt'
@@ -48,7 +68,7 @@ jobs:
run: sudo apt-get install libev4 libev-dev
- name: Install uv
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
python-version: ${{ matrix.python-version }}
@@ -57,17 +77,25 @@ jobs:
- name: Build driver
run: uv sync
+ - name: Cache Scylla download
+ uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
+ with:
+ path: ~/.ccm/repository
+ key: scylla-${{ env.SCYLLA_VERSION }}-${{ runner.os }}
+
# This is to get honest accounting of test time vs download time vs build time.
# Not strictly necessary for running tests.
- name: Download Scylla
run: |
- export SCYLLA_VERSION='release:6.2'
uv run ccm create scylla-driver-temp -n 1 --scylla --version ${SCYLLA_VERSION}
uv run ccm remove
- name: Test with pytest
+ env:
+ EVENT_LOOP_MANAGER: ${{ matrix.event_loop_manager }}
+ PROTOCOL_VERSION: 4
run: |
- export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }}
- export SCYLLA_VERSION='release:6.2'
- export PROTOCOL_VERSION=4
+ if [[ "${{ matrix.python-version }}" =~ t$ ]]; then
+ export PYTHON_GIL=0
+ fi
uv run pytest tests/integration/standard/ tests/integration/cqlengine/
diff --git a/.github/workflows/lib-build-and-push.yml b/.github/workflows/lib-build.yml
similarity index 73%
rename from .github/workflows/lib-build-and-push.yml
rename to .github/workflows/lib-build.yml
index b68ef4eba5..f6959ddfec 100644
--- a/.github/workflows/lib-build-and-push.yml
+++ b/.github/workflows/lib-build.yml
@@ -1,14 +1,8 @@
-name: Build and upload to PyPi
+name: Build wheels
on:
workflow_call:
inputs:
- upload:
- description: 'Upload to PyPI'
- type: boolean
- required: false
- default: false
-
python-version:
description: 'Python version to run on'
type: string
@@ -61,7 +55,7 @@ jobs:
was_added=1
elif [[ "${target}" == "macos-x86" ]]; then
[ -n "$was_added" ] && echo -n "," >> /tmp/matrix.json
- echo -n '{"os":"macos-13", "target": "macos-x86"}' >> /tmp/matrix.json
+ echo -n '{"os":"macos-15-intel", "target": "macos-x86"}' >> /tmp/matrix.json
was_added=1
elif [[ "${target}" == "macos-arm" ]]; then
[ -n "$was_added" ] && echo -n "," >> /tmp/matrix.json
@@ -83,11 +77,11 @@ jobs:
include: ${{ fromJson(needs.prepare-matrix.outputs.matrix) }}
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Checkout tag ${{ inputs.target_tag }}
if: inputs.target_tag != ''
- uses: actions/checkout@v4
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ inputs.target_tag }}
@@ -98,28 +92,26 @@ jobs:
echo "CIBW_TEST_COMMAND=true" >> $GITHUB_ENV;
echo "CIBW_TEST_COMMAND_WINDOWS=(exit 0)" >> $GITHUB_ENV;
echo "CIBW_TEST_SKIP=*" >> $GITHUB_ENV;
- echo "CIBW_SKIP=cp2* cp36* pp36* cp37* pp37* cp38* pp38* *i686 *musllinux*" >> $GITHUB_ENV;
- echo "CIBW_BUILD=cp3* pp3*" >> $GITHUB_ENV;
echo "CIBW_BEFORE_TEST=true" >> $GITHUB_ENV;
echo "CIBW_BEFORE_TEST_WINDOWS=(exit 0)" >> $GITHUB_ENV;
- name: Install uv
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
python-version: ${{ inputs.python-version }}
- name: Install cibuildwheel
run: |
- uv tool install 'cibuildwheel==2.22.0'
+ uv tool install 'cibuildwheel==3.2.1'
- name: Install OpenSSL for Windows
if: runner.os == 'Windows'
run: |
- choco install openssl --version=3.5.1 -f -y --no-progress
+ choco install openssl.light --no-progress -y
- name: Install Conan
if: runner.os == 'Windows'
- uses: turtlebrowser/get-conan@main
+ uses: turtlebrowser/get-conan@c171f295f3f507360ee018736a6608731aa2109d # v1.2
- name: Configure libev for Windows
if: runner.os == 'Windows'
@@ -127,7 +119,7 @@ jobs:
conan profile detect
conan install conanfile.py
- - name: Install OpenSSL for MacOS
+ - name: Install libev for MacOS
if: runner.os == 'MacOs'
run: |
brew install libev
@@ -136,9 +128,9 @@ jobs:
if: runner.os == 'MacOS'
run: |
##### Set MACOSX_DEPLOYMENT_TARGET
- if [ "${{ matrix.os }}" == "macos-13" ]; then
- echo "MACOSX_DEPLOYMENT_TARGET=13.0" >> $GITHUB_ENV;
- echo "Enforcing target deployment for 13.0"
+ if [ "${{ matrix.os }}" == "macos-15-intel" ]; then
+ echo "MACOSX_DEPLOYMENT_TARGET=15.0" >> $GITHUB_ENV;
+ echo "Enforcing target deployment for 15.0"
elif [ "${{ matrix.os }}" == "macos-14" ]; then
echo "MACOSX_DEPLOYMENT_TARGET=14.0" >> $GITHUB_ENV;
echo "Enforcing target deployment for 14.0"
@@ -148,14 +140,14 @@ jobs:
if: matrix.target != 'linux-aarch64'
shell: bash
run: |
- GITHUB_WORKFLOW_REF="scylladb/python-driver/.github/workflows/lib-build-and-push.yml@refs/heads/master" cibuildwheel --output-dir wheelhouse
+ cibuildwheel --output-dir wheelhouse
- name: Build wheels for linux aarch64
if: matrix.target == 'linux-aarch64'
run: |
- GITHUB_WORKFLOW_REF="scylladb/python-driver/.github/workflows/lib-build-and-push.yml@refs/heads/master" CIBW_BUILD="cp3*" cibuildwheel --archs aarch64 --output-dir wheelhouse
+ CIBW_BUILD="cp3*" cibuildwheel --archs aarch64 --output-dir wheelhouse
- - uses: actions/upload-artifact@v4
+ - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: wheels-${{ matrix.target }}-${{ matrix.os }}
path: ./wheelhouse/*.whl
@@ -164,34 +156,17 @@ jobs:
name: Build source distribution
runs-on: ubuntu-24.04
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install uv
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
python-version: ${{ inputs.python-version }}
- name: Build sdist
run: uv build --sdist
- - uses: actions/upload-artifact@v4
+ - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: source-dist
path: dist/*.tar.gz
-
- upload_pypi:
- if: inputs.upload
- needs: [build-wheels, build-sdist]
- runs-on: ubuntu-24.04
- permissions:
- id-token: write
-
- steps:
- - uses: actions/download-artifact@v4
- with:
- path: dist
- merge-multiple: true
-
- - uses: pypa/gh-action-pypi-publish@release/v1
- with:
- skip-existing: true
diff --git a/.github/workflows/publish-manually.yml b/.github/workflows/publish-manually.yml
index d2dda897ed..5b9298fb7f 100644
--- a/.github/workflows/publish-manually.yml
+++ b/.github/workflows/publish-manually.yml
@@ -1,5 +1,8 @@
name: Build and upload to PyPi manually
+permissions:
+ contents: read
+
on:
workflow_dispatch:
inputs:
@@ -36,15 +39,17 @@ on:
jobs:
build-and-publish:
name: "Build wheels"
- uses: ./.github/workflows/lib-build-and-push.yml
+ uses: ./.github/workflows/lib-build.yml
with:
- upload: false
python-version: ${{ inputs.python-version }}
ignore_tests: ${{ inputs.ignore_tests }}
target_tag: ${{ inputs.target_tag }}
target: ${{ inputs.target }}
- # TODO: Remove when https://github.com/pypa/gh-action-pypi-publish/issues/166 is fixed and update build-and-publish.with.upload to ${{ inputs.upload }}
+ # Publishing is a separate job (not inside the reusable workflow) because PyPI Trusted Publishing
+ # requires the *caller* workflow path in the OIDC token. A reusable workflow would embed its own
+ # path instead, causing an `invalid-publisher` error on the PyPI side.
+ # See: https://github.com/pypa/gh-action-pypi-publish/issues/166
publish:
name: "Publish wheels to PyPi"
needs: build-and-publish
@@ -53,11 +58,11 @@ jobs:
permissions:
id-token: write
steps:
- - uses: actions/download-artifact@v4
+ - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: dist
merge-multiple: true
- - uses: pypa/gh-action-pypi-publish@release/v1
+ - uses: pypa/gh-action-pypi-publish@cef2210092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
with:
skip-existing: true
diff --git a/.gitignore b/.gitignore
index b96d8702d6..881012f340 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,6 +43,11 @@ tests/unit/cython/bytesio_testhelper.c
#iPython
*.ipynb
+uv.lock
+.venv/
+
+
+
# Files from upstream that we don't need
Jenkinsfile
Jenkinsfile.bak
@@ -64,3 +69,22 @@ docs/core_graph.rst
docs/geo_types.rst
docs/graph.rst
docs/graph_fluent.rst
+
+# Personal list of items to do
+TODO.md
+
+# Codex - AI assistant metadata
+.codex/
+.codex-cache/
+.codex-config.json
+.codex-settings.json
+codex.log
+AGENTS.md
+
+# Claude - AI assistant metadata
+.anthropic/
+.claude/
+claude.log
+claude_history.json
+claude_config.json
+CLAUDE.md
\ No newline at end of file
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 80f8d52d7a..39a8aca069 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,151 @@
+3.29.10
+=======
+May 10, 2026
+
+Features
+--------
+* Fast-path ``lookup_casstype()`` for simple type names
+* Add ``Session.wait_for_schema_agreement``
+
+Bug Fixes
+---------
+* Fix CQL injection in ``Connection.set_keyspace_blocking`` and ``Connection.set_keyspace_async``
+* Fix libev shutdown crashes by correcting atexit registration
+* Handle ``None`` ``control_connection_timeout`` in ``wait_for_schema_agreement``
+* Clean up failed heartbeat sends
+* Fix ``ExponentialBackoffRetryPolicy.__init__`` super() call
+* Correct ``clustering_key`` to ``clustering`` in column kind filter
+* Fix inverted cooldown check in ``_get_shard_aware_endpoint``
+
+Others
+------
+* Deprecate ``ControlConnection.wait_for_schema_agreement``
+* Add timeout and in-flight observability to ``OperationTimedOut``
+* Drop per-query connection log
+
+3.29.9
+======
+March 18, 2026
+
+Features
+--------
+* Add Private Link support via client routes handler
+* Add optional query_params parameter to QueryMessage
+
+Bug Fixes
+---------
+* Fix segmentation fault in libev prepare_callback during shutdown
+* Add null checks to io_callback and timer_callback in libev wrapper
+* Fix RecursionError in execute_concurrent on synchronous errbacks
+* Fix floating-point precision loss for timestamps far from epoch
+
+Others
+------
+* Cache parsed tablet routing type in ResponseFuture
+* Remove deprecated setup_requires in favor of PEP 517 build-system.requires
+* Update dependency hatchling to v1.29.0
+
+3.29.8
+======
+February 09, 2026
+
+Features
+--------
+* Add frozen parameter to collection columns with FULL index support
+* Include original error in ConnectionShutdown messages
+
+Bug Fixes
+---------
+* Fix IntStat comparison operators and metrics cleanup on shutdown
+* Fix NumPy 2.0 compatibility in numpy_parser
+* Fix race condition during host IP address update
+* Fix infinite retry when single host fails with server error
+* Don't mark node down when control connection fails to connect
+* Call on_add before distance to properly initialize lbp
+* Don't check if host is in initial contact points when setting default local_dc
+* Pull version information from system.local when version info is not present
+* Fix missing call to superclass ``__init__`` during object initialization
+
+Others
+------
+* Remove scales dependency with self-contained metrics implementation
+* Migrate from pytz to zoneinfo
+* Remove Python 2 compatibility code
+* Optimize write path in protocol.py to reduce copies
+* TokenAwarePolicy: remove redundant check if a table is using tablets
+* Don't create Host instances with random host_id
+* Use endpoint instead of Host in _try_connect
+* Remove support for protocols <3 from cython files
+* Return empty query plan if there are no live hosts
+* Replace asynctest with stdlib mock
+
+3.29.7
+======
+December 08, 2025
+
+Bug Fixes
+---------
+* Make compression=None a valid case (#610)
+
+3.29.6
+======
+November 27, 2025
+
+* Rename connection_metadata to client_routes (#608)
+* TokenAwarePolicy: enable shuffling by default (#478)
+* Add support of LWT flag for BatchStatement (#606)
+* Add support of CONNECTION_METADATA_CHANGE event (#601)
+* Add LWT support (#584)
+* Add support for Python 3.14 (#566)
+* Fix dict handling in pool and metrics (#595)
+* Remove serverless code (#590)
+* tests: drop `sure` package (#592)
+* compression: better handle configuration problems (#585)
+
+3.29.5
+======
+November 5, 2025
+
+Bug Fixes
+---------
+* Update TokenAwarePolicy.make_query_plan to schedule to replicas first (#548)
+* Drop _tablets_routing_v1 flag from token-aware policy (#547)
+* Fix dc aware and rack aware policies initialization (#546)
+* Fix Cluster.metadata_request_timeout and default it from control_connection_timeout (#539)
+
+Others
+------
+* Drop python 3.9 support (#564)
+
+3.29.4
+======
+August 16, 2025
+
+Features
+--------
+* Add Cluster.application_info to report application information to server (#486)
+* Move to uv package manager (#496)
+
+Bug Fixes
+---------
+* Fix deadlocks on evicting connection in HostConnectionPool and ConnectionPool (#499)
+* Fix libevreactor crashing when connection added and closed right away (#508)
+
+Others
+------
+* Remove outdated protocols support (v1 and v2) (#493, #525)
+* Remove DSE integration tests (#502)
+* Optimise shard port allocator (#506)
+* Remove self.assert (#507)
+* Minor performance improvement for make_token_replica_map (#513)
+* Remove in-memory Scylla tables support (#518)
+* Add optional dependencies for SNAPPY and LZ4 compressors (#517)
+* Remove support for protocol versions not supported by Scylla (#492)
+* Set monitor_reporting_enabled False by default (#523)
+
3.29.3
======
-Mart 11, 2025
+March 11, 2025
Bug Fixes
---------
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index c635fd8c1b..82bf21e52f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -18,7 +18,7 @@ good bug reports. They will not be repeated in detail here, but in general, the
Pull Requests
-------------
If you're able to fix a bug yourself, you can `fork the repository `_ and submit a `Pull Request `_ with the fix.
-Please include tests demonstrating the issue and fix. For examples of how to run the tests, consult the `dev README `_.
+Please include tests demonstrating the issue and fix. For examples of how to run the tests, consult the further parts of this document.
Design and Implementation Guidelines
------------------------------------
@@ -26,3 +26,110 @@ Design and Implementation Guidelines
- This project follows `semantic versioning `_, so breaking API changes will only be introduced in major versions.
- Legacy ``cqlengine`` has varying degrees of overreaching client-side validation. Going forward, we will avoid client validation where server feedback is adequate and not overly expensive.
- When writing tests, try to achieve maximal coverage in unit tests (where it is faster to run across many runtimes). Integration tests are good for things where we need to test server interaction, or where it is important to test across different server versions (emulating in unit tests would not be effective).
+
+Dev setup
+=========
+
+We recommend using `uv` tool for running tests, linters and basically everything else,
+since it makes Python tooling ecosystem mostly usable.
+To install it, see instructions at https://docs.astral.sh/uv/getting-started/installation/
+The rest of this document assumes you have `uv` installed.
+
+It is also strongly recommended to use C/C++-caching tool like ccache or sccache.
+When modifying driver files, rebuilding Cython modules is often necessary.
+Without caching, each such rebuild may take over a minute. Caching usually brings it
+down to about 2-3 seconds.
+
+**Important:** After modifying any ``.py`` file under ``cassandra/`` that is
+Cython-compiled (such as ``query.py``, ``protocol.py``, ``cluster.py``, etc.),
+extensions must be rebuilt before running tests. If you always use ``uv run``
+(e.g. ``uv run pytest``), this is handled automatically via the ``cache-keys``
+configuration in ``pyproject.toml``. If you invoke ``pytest`` directly, you can
+rebuild with::
+
+ uv sync --reinstall-package scylla-driver
+
+Without rebuilding, Python will load the stale compiled extension (``.so`` / ``.pyd``)
+instead of your modified ``.py`` source, and your changes will not actually be tested.
+The test suite will emit a warning if it detects this situation.
+
+Building the Docs
+=================
+
+To build and preview the documentation for the ScyllaDB Python driver locally, you must first manually install `python-driver`.
+This is necessary for autogenerating the reference documentation of the driver.
+You can find detailed instructions on how to install the driver in the `Installation guide `_.
+
+After installing the driver, you can build the documentation:
+- Make sure you have Python version compatible with docs. You can see supported version in ``docs/pyproject.toml`` - look for ``python`` in ``tool.poetry.dependencies`` section.
+- Install poetry: ``pip install poetry``
+- To preview docs in your browser: ``make -C docs preview``
+
+Tests
+=====
+
+Running Unit Tests
+------------------
+Unit tests can be run like so::
+
+ uv run pytest tests/unit
+ EVENT_LOOP_MANAGER=gevent uv run pytest tests/unit/io/test_geventreactor.py
+ EVENT_LOOP_MANAGER=eventlet uv run pytest tests/unit/io/test_eventletreactor.py
+
+You can run a specific test method like so::
+
+ uv run pytest tests/unit/test_connection.py::ConnectionTest::test_bad_protocol_version
+
+Running Integration Tests
+-------------------------
+In order to run integration tests, you must specify a version to run using either of:
+* ``SCYLLA_VERSION`` e.g. ``release:2025.2``
+* ``CASSANDRA_VERSION``
+environment variable::
+
+ SCYLLA_VERSION="release:2025.2" uv run pytest tests/integration/standard tests/integration/cqlengine/
+
+Or you can specify a scylla/cassandra directory (to test unreleased versions)::
+
+ SCYLLA_VERSION=/path/to/scylla uv run pytest tests/integration/standard/
+
+Specifying the usage of an already running Scylla cluster
+------------------------------------------------------------
+The test will start the appropriate Scylla clusters when necessary but if you don't want this to happen because a Scylla cluster is already running the flag ``USE_CASS_EXTERNAL`` can be used, for example::
+
+ USE_CASS_EXTERNAL=1 SCYLLA_VERSION='release:5.1' uv run pytest tests/integration/standard
+
+Specify a Protocol Version for Tests
+------------------------------------
+The protocol version defaults to:
+- 4 for Scylla >= 3.0 and Scylla Enterprise > 2019.
+- 3 for older versions of Scylla
+- 5 for Cassandra >= 4.0, 4 for Cassandra >= 2.2, 3 for Cassandra >= 2.1, 2 for Cassandra >= 2.0
+You can overwrite it with the ``PROTOCOL_VERSION`` environment variable::
+
+ PROTOCOL_VERSION=3 SCYLLA_VERSION="release:5.1" uv run pytest tests/integration/standard tests/integration/cqlengine/
+
+Seeing Test Logs in Real Time
+-----------------------------
+Sometimes it's useful to output logs for the tests as they run::
+
+ uv run pytest -s tests/unit/
+
+Use tee to capture logs and see them on your terminal::
+
+ uv run pytest -s tests/unit/ 2>&1 | tee test.log
+
+
+Running the Benchmarks
+======================
+There needs to be a version of Scyll running locally so before running the benchmarks, if ccm is installed:
+
+ uv run ccm create benchmark_cluster --scylla -v release:2025.2 -n 1 -s
+
+To run the benchmarks, pick one of the files under the ``benchmarks/`` dir and run it::
+
+ uv run benchmarks/future_batches.py
+
+There are a few options. Use ``--help`` to see them all::
+
+ uv run benchmarks/future_batches.py --help
diff --git a/MAINTENANCE.md b/MAINTENANCE.md
new file mode 100644
index 0000000000..8fc860ac4b
--- /dev/null
+++ b/MAINTENANCE.md
@@ -0,0 +1,13 @@
+Releasing
+=========
+* Run the tests and ensure they all pass
+* Update the version in ``cassandra/__init__.py``
+* Add the new version in ``docs/conf.py`` (variables: ``TAGS``, ``LATEST_VERSION``, ``DEPRECATED_VERSIONS``).
+ * For patch version releases (like ``3.26.8-scylla -> 3.26.9-scylla``) replace the old version with new one in ``TAGS`` and update ``LATEST_VERSION``.
+ * For minor version releases (like ``3.26.9-scylla -> 3.27.0-scylla``) add new version to ``TAGS``, update ``LATEST_VERSION`` and add previous minor version to ``DEPRECATED_VERSIONS``.
+* Commit the version changes, e.g. ``git commit -m 'Release 3.26.9'``
+* Tag the release. For example: ``git tag -a 3.26.9-scylla -m 'Release 3.26.9'``
+* Push the tag and new ``master`` SIMULTANEOUSLY: ``git push --atomic origin master v6.0.21-scylla``
+* Now new version and its docs should be automatically published. Check `PyPI `_ and `docs `_ to make sure its there.
+* If you didn't push branch and tag simultaneously (or doc publishing failed for other reason) then restart the relevant job from GitHub Actions UI.
+* Publish a GitHub Release and a post on community forum.
diff --git a/README-dev.rst b/README-dev.rst
deleted file mode 100644
index f158226de0..0000000000
--- a/README-dev.rst
+++ /dev/null
@@ -1,102 +0,0 @@
-Releasing
-=========
-* Run the tests and ensure they all pass
-* Update the version in ``cassandra/__init__.py``
-* Add the new version in ``docs/conf.py`` (variables: ``TAGS``, ``LATEST_VERSION``, ``DEPRECATED_VERSIONS``).
- * For patch version releases (like ``3.26.8-scylla -> 3.26.9-scylla``) replace the old version with new one in ``TAGS`` and update ``LATEST_VERSION``.
- * For minor version releases (like ``3.26.9-scylla -> 3.27.0-scylla``) add new version to ``TAGS``, update ``LATEST_VERSION`` and add previous minor version to ``DEPRECATED_VERSIONS``.
-* Commit the version changes, e.g. ``git commit -m 'Release 3.26.9'``
-* Tag the release. For example: ``git tag -a 3.26.9-scylla -m 'Release 3.26.9'``
-* Push the tag and new ``master`` SIMULTANEOUSLY: ``git push --atomic origin master v6.0.21-scylla``
-* Now new version and its docs should be automatically published. Check `PyPI `_ and `docs `_ to make sure its there.
-* If you didn't push branch and tag simultaneously (or doc publishing failed for other reason) then restart the relevant job from GitHub Actions UI.
-* Publish a GitHub Release and a post on community forum.
-
-Building the Docs
-=================
-
-To build and preview the documentation for the ScyllaDB Python driver locally, you must first manually install `python-driver`.
-This is necessary for autogenerating the reference documentation of the driver.
-You can find detailed instructions on how to install the driver in the `Installation guide `_.
-
-After installing the driver, you can build the documentation:
-- Make sure you have Python version compatible with docs. You can see supported version in ``docs/pyproject.toml`` - look for ``python`` in ``tool.poetry.dependencies`` section.
-- Install poetry: ``pip install poetry``
-- To preview docs in your browser: ``make -C docs preview``
-
-Tooling
-=======
-
-We recommend using `uv` tool for running tests, linters and basically everything else,
-since it makes Python tooling ecosystem mostly usable.
-To install it, see instructions at https://docs.astral.sh/uv/getting-started/installation/
-The rest of this document assumes you have `uv` installed.
-
-Tests
-=====
-
-Running Unit Tests
-------------------
-Unit tests can be run like so::
-
- uv run pytest tests/unit
- EVENT_LOOP_MANAGER=gevent uv run pytest tests/unit/io/test_geventreactor.py
- EVENT_LOOP_MANAGER=eventlet uv run pytest tests/unit/io/test_eventletreactor.py
-
-You can run a specific test method like so::
-
- uv run pytest tests/unit/test_connection.py::ConnectionTest::test_bad_protocol_version
-
-Running Integration Tests
--------------------------
-In order to run integration tests, you must specify a version to run using either of:
-* ``SCYLLA_VERSION`` e.g. ``release:5.1``
-* ``CASSANDRA_VERSION``
-environment variable::
-
- SCYLLA_VERSION="release:5.1" uv run pytest tests/integration/standard tests/integration/cqlengine/
-
-Or you can specify a scylla/cassandra directory (to test unreleased versions)::
-
- SCYLLA_VERSION=/path/to/scylla uv run pytest tests/integration/standard/
-
-Specifying the usage of an already running Scylla cluster
-------------------------------------------------------------
-The test will start the appropriate Scylla clusters when necessary but if you don't want this to happen because a Scylla cluster is already running the flag ``USE_CASS_EXTERNAL`` can be used, for example::
-
- USE_CASS_EXTERNAL=1 SCYLLA_VERSION='release:5.1' uv run pytest tests/integration/standard
-
-Specify a Protocol Version for Tests
-------------------------------------
-The protocol version defaults to:
-- 4 for Scylla >= 3.0 and Scylla Enterprise > 2019.
-- 3 for older versions of Scylla
-- 5 for Cassandra >= 4.0, 4 for Cassandra >= 2.2, 3 for Cassandra >= 2.1, 2 for Cassandra >= 2.0
-You can overwrite it with the ``PROTOCOL_VERSION`` environment variable::
-
- PROTOCOL_VERSION=3 SCYLLA_VERSION="release:5.1" uv run pytest tests/integration/standard tests/integration/cqlengine/
-
-Seeing Test Logs in Real Time
------------------------------
-Sometimes it's useful to output logs for the tests as they run::
-
- uv run pytest -s tests/unit/
-
-Use tee to capture logs and see them on your terminal::
-
- uv run pytest -s tests/unit/ 2>&1 | tee test.log
-
-
-Running the Benchmarks
-======================
-There needs to be a version of cassandra running locally so before running the benchmarks, if ccm is installed:
-
- uv run ccm create benchmark_cluster -v 3.0.1 -n 1 -s
-
-To run the benchmarks, pick one of the files under the ``benchmarks/`` dir and run it::
-
- uv run benchmarks/future_batches.py
-
-There are a few options. Use ``--help`` to see them all::
-
- uv run benchmarks/future_batches.py --help
diff --git a/README.rst b/README.rst
index f6a983a5b2..84ceb443a3 100644
--- a/README.rst
+++ b/README.rst
@@ -20,7 +20,7 @@ Scylla Enterprise (2018.1.x+) using exclusively Cassandra's binary protocol and
.. image:: https://github.com/scylladb/python-driver/actions/workflows/integration-tests.yml/badge.svg?branch=master
:target: https://github.com/scylladb/python-driver/actions/workflows/integration-tests.yml?query=event%3Apush+branch%3Amaster
-The driver supports Python versions 3.9-3.13.
+The driver supports Python versions 3.10-3.14.
.. **Note:** This driver does not support big-endian systems.
diff --git a/benchmarks/base.py b/benchmarks/base.py
index 2000b4069f..3922eefad5 100644
--- a/benchmarks/base.py
+++ b/benchmarks/base.py
@@ -21,7 +21,7 @@
from optparse import OptionParser
import uuid
-from greplin import scales
+from cassandra.metrics import getStats
dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(dirname)
@@ -97,7 +97,7 @@ def setup(options):
try:
session.execute("""
CREATE KEYSPACE %s
- WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
+ WITH replication = { 'class': 'NetworkTopologyStrategy', 'replication_factor': '2' }
""" % options.keyspace)
log.debug("Setting keyspace...")
@@ -192,7 +192,7 @@ def benchmark(thread_class):
log.info("Total time: %0.2fs" % total)
log.info("Average throughput: %0.2f/sec" % (options.num_ops / total))
if options.enable_metrics:
- stats = scales.getStats()['cassandra']
+ stats = getStats()['cassandra']
log.info("Connection errors: %d", stats['connection_errors'])
log.info("Write timeouts: %d", stats['write_timeouts'])
log.info("Read timeouts: %d", stats['read_timeouts'])
diff --git a/benchmarks/callback_full_pipeline.py b/benchmarks/callback_full_pipeline.py
index a4a4c33315..87eb999cfe 100644
--- a/benchmarks/callback_full_pipeline.py
+++ b/benchmarks/callback_full_pipeline.py
@@ -49,10 +49,7 @@ def insert_next(self, previous_result=sentinel):
def run(self):
self.start_profile()
- if self.protocol_version >= 3:
- concurrency = 1000
- else:
- concurrency = 100
+ concurrency = 1000
for _ in range(min(concurrency, self.num_queries)):
self.insert_next()
diff --git a/cassandra/__init__.py b/cassandra/__init__.py
index dfded7d1a6..1286f20e9b 100644
--- a/cassandra/__init__.py
+++ b/cassandra/__init__.py
@@ -23,7 +23,7 @@ def emit(self, record):
logging.getLogger('cassandra').addHandler(NullHandler())
-__version_info__ = (3, 29, 3)
+__version_info__ = (3, 29, 10)
__version__ = '.'.join(map(str, __version_info__))
@@ -135,16 +135,6 @@ class ProtocolVersion(object):
"""
Defines native protocol versions supported by this driver.
"""
- V1 = 1
- """
- v1, supported in Cassandra 1.2-->2.2
- """
-
- V2 = 2
- """
- v2, supported in Cassandra 2.0-->2.2;
- added support for lightweight transactions, batch operations, and automatic query paging.
- """
V3 = 3
"""
@@ -180,9 +170,9 @@ class ProtocolVersion(object):
DSE private protocol v2, supported in DSE 6.0+
"""
- SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1)
+ SUPPORTED_VERSIONS = (V5, V4, V3)
"""
- A tuple of all supported protocol versions
+ A tuple of all supported protocol versions for ScyllaDB, including future v5 version.
"""
BETA_VERSIONS = (V6,)
@@ -233,14 +223,6 @@ def uses_error_code_map(cls, version):
def uses_keyspace_flag(cls, version):
return version >= cls.V5 and version != cls.DSE_V1
- @classmethod
- def has_continuous_paging_support(cls, version):
- return version >= cls.DSE_V1
-
- @classmethod
- def has_continuous_paging_next_pages(cls, version):
- return version >= cls.DSE_V2
-
@classmethod
def has_checksumming_support(cls, version):
return cls.V5 <= version < cls.DSE_V1
@@ -705,10 +687,29 @@ class OperationTimedOut(DriverException):
The last :class:`~.Host` this operation was attempted against.
"""
- def __init__(self, errors=None, last_host=None):
+ timeout = None
+ """
+ The timeout value (in seconds) that was in effect when the operation
+ timed out, or ``None`` if not applicable.
+ """
+
+ in_flight = None
+ """
+ The number of in-flight requests on the connection at the time of
+ the timeout (includes orphaned requests), or ``None`` if not applicable.
+ """
+
+ def __init__(self, errors=None, last_host=None, timeout=None, in_flight=None):
self.errors = errors
self.last_host = last_host
+ self.timeout = timeout
+ self.in_flight = in_flight
message = "errors=%s, last_host=%s" % (self.errors, self.last_host)
+ if self.timeout is not None:
+ message += " (timeout=%ss" % self.timeout
+ if self.in_flight is not None:
+ message += ", in_flight=%d" % self.in_flight
+ message += ")"
Exception.__init__(self, message)
diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py
new file mode 100644
index 0000000000..80b2477a6d
--- /dev/null
+++ b/cassandra/client_routes.py
@@ -0,0 +1,451 @@
+# Copyright 2026 ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Client Routes support for Private Link and similar network configurations.
+
+This module implements support for dynamic address translation via the
+system.client_routes table and CLIENT_ROUTES_CHANGE events.
+"""
+
+from __future__ import absolute_import
+
+from dataclasses import dataclass
+import enum
+import logging
+import socket
+import threading
+import uuid
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
+
+from cassandra import ConsistencyLevel
+from cassandra.protocol import QueryMessage
+from cassandra.query import dict_factory
+
+if TYPE_CHECKING:
+ from cassandra.connection import Connection
+
+log = logging.getLogger(__name__)
+
+
+class ClientRoutesChangeType(enum.Enum):
+ """
+ Types of CLIENT_ROUTES_CHANGE events.
+
+ Currently the protocol defines only UPDATE_NODES.
+ New variants will be added here if the protocol is extended.
+ """
+ UPDATE_NODES = "UPDATE_NODES"
+
+
+@dataclass
+class ClientRouteProxy:
+ """
+ :param connection_id: String identifying the connection (required)
+ :param connection_addr_override:: Optional string address for initial connection
+ """
+
+ connection_id: str
+ connection_addr_override: Optional[str] = None
+
+ def __post_init__(self):
+ if self.connection_id is None:
+ raise ValueError("connection_id is required")
+
+class ClientRoutesConfig:
+ """
+ Configuration for client routes (Private Link support).
+
+ :param proxies: List of :class:`ClientRouteProxy` objects
+ (REQUIRED, at least one)
+ :param advanced_shard_awareness: Whether to enable advanced shard awareness
+ (default: ``False``)
+ """
+
+ proxies: List[ClientRouteProxy]
+ advanced_shard_awareness: bool
+
+ def __init__(self, proxies: List[ClientRouteProxy], advanced_shard_awareness: bool = False):
+ """
+ :param proxies: List of ClientRouteProxy objects
+ :param advanced_shard_awareness: Enable advanced shard awareness (default False)
+ """
+ if not proxies:
+ raise ValueError("At least one proxy must be specified")
+
+ if not isinstance(proxies, (list, tuple)):
+ raise TypeError("proxies must be a list or tuple")
+
+ for proxy in proxies:
+ if not isinstance(proxy, ClientRouteProxy):
+ raise TypeError("All proxies must be ClientRouteProxy instances")
+
+ self.proxies = proxies
+ self.advanced_shard_awareness = advanced_shard_awareness
+
+ def __repr__(self) -> str:
+ return (f"ClientRoutesConfig(proxies={self.proxies}, "
+ f"advanced_shard_awareness={self.advanced_shard_awareness})")
+
+
+@dataclass(frozen=True)
+class _Route:
+ connection_id: str
+ host_id: uuid.UUID
+ address: str # ipv4, ipv6 or DNS hostname from system.client_routes
+ port: int
+
+class _RouteStore:
+ """
+ Thread-safe storage for routes. Reads are safe under CPython's GIL;
+ writes are serialized with a lock.
+
+ This uses atomic pointer swaps for updates, allowing lock-free reads
+ while serializing writes.
+ """
+
+ _routes_by_host_id: Dict[uuid.UUID, _Route]
+ _lock: threading.Lock
+
+ def __init__(self) -> None:
+ self._routes_by_host_id = {}
+ self._lock = threading.Lock()
+
+ def get_by_host_id(self, host_id: uuid.UUID) -> Optional[_Route]:
+ """
+ Get route for a host ID (lock-free read).
+
+ :param host_id: UUID of the host
+ :return: _Route or None
+ """
+ return self._routes_by_host_id.get(host_id)
+
+ def get_all(self) -> List[_Route]:
+ """
+ Get all routes as a list (lock-free read).
+
+ :return: List of _Route
+ """
+ return list(self._routes_by_host_id.values())
+
+ def _select_preferred_routes(self, new_routes: List[_Route]) -> List[_Route]:
+ """
+ When multiple routes exist for the same host_id (different connection_ids),
+ prefer the connection_id already in use. Only migrate to a different
+ connection_id when the previously used one is no longer available.
+
+ Must be called under self._lock.
+ """
+ by_host: Dict[uuid.UUID, List[_Route]] = {}
+ for route in new_routes:
+ by_host.setdefault(route.host_id, []).append(route)
+
+ selected = []
+ for host_id, candidates in by_host.items():
+ if len(candidates) == 1:
+ selected.append(candidates[0])
+ continue
+
+ existing = self._routes_by_host_id.get(host_id)
+ if existing:
+ preferred = [c for c in candidates if c.connection_id == existing.connection_id]
+ if preferred:
+ selected.append(preferred[0])
+ continue
+
+ selected.append(candidates[0])
+
+ return selected
+
+ def update(self, routes: List[_Route]) -> None:
+ """
+ Replace all routes atomically.
+
+ :param routes: List of _Route objects
+ """
+ with self._lock:
+ preferred = self._select_preferred_routes(routes)
+ self._routes_by_host_id = {route.host_id: route for route in preferred}
+
+ def merge(self, new_routes: List[_Route], affected_host_ids: Set[uuid.UUID]) -> None:
+ """
+ Merge new routes with existing ones atomically.
+
+ Routes for affected_host_ids are replaced entirely: existing routes
+ for those hosts are dropped and replaced with whatever is in new_routes.
+ This handles deletions from system.client_routes (affected host present
+ but no new route for it).
+
+ :param new_routes: List of _Route objects to merge
+ :param affected_host_ids: Set of host IDs affected by the change.
+ """
+ with self._lock:
+ preferred = self._select_preferred_routes(new_routes)
+ new_by_host = {r.host_id: r for r in preferred}
+
+ updated = {hid: r for hid, r in self._routes_by_host_id.items()
+ if hid not in affected_host_ids}
+ updated.update(new_by_host)
+ self._routes_by_host_id = updated
+
+
+class _ClientRoutesHandler:
+ """
+ Handles dynamic address translation for Private Link via system.client_routes.
+
+ Lifecycle:
+ 1. Construction: Create with configuration
+ 2. Initialization: Read system.client_routes after control connection established
+ 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes
+ 4. Translation: Translate addresses using Host ID lookup
+ """
+
+ config: 'ClientRoutesConfig'
+ ssl_enabled: bool
+ _routes: _RouteStore
+ _connection_ids: Set[str]
+ _proxy_addresses_override: Dict[str, str]
+
+ def __init__(self, config: 'ClientRoutesConfig', ssl_enabled: bool = False):
+ """
+ :param config: ClientRoutesConfig instance
+ :param ssl_enabled: Whether TLS is enabled (determines port selection)
+ """
+ if not isinstance(config, ClientRoutesConfig):
+ raise TypeError("config must be a ClientRoutesConfig instance")
+
+ self.config = config
+ self.ssl_enabled = ssl_enabled
+ self._routes = _RouteStore()
+ self._connection_ids = {dep.connection_id for dep in config.proxies}
+ # Precalculate proxy address mappings for efficient lookup
+ self._proxy_addresses_override = {
+ proxy.connection_id: proxy.connection_addr_override
+ for proxy in config.proxies
+ if proxy.connection_addr_override
+ }
+
+ def initialize(self, connection: 'Connection', timeout: float) -> None:
+ """
+ Load all routes from system.client_routes.
+
+ Called once at startup and again whenever the control connection
+ is re-established. Reads all configured connection IDs and
+ replaces the in-memory route store atomically.
+
+ Raises on failure so the caller can decide how to react (e.g.
+ abort startup or schedule a reconnect).
+
+ :param connection: The Connection instance to execute queries on
+ :param timeout: Query timeout in seconds
+ """
+ log.info("[client routes] Loading routes for %d proxies", len(self.config.proxies))
+
+ routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids)
+ self._routes.update(routes)
+
+ def handle_client_routes_change(self, connection: 'Connection', timeout: float,
+ change_type: 'ClientRoutesChangeType',
+ connection_ids: Sequence[str], host_ids: Sequence[str]) -> None:
+ """
+ Handle CLIENT_ROUTES_CHANGE event.
+
+ Currently the protocol defines only :attr:`ClientRoutesChangeType.UPDATE_NODES`.
+ New variants will be added to the enum if the protocol is extended.
+
+ :param connection: The Connection instance to execute queries on
+ :param timeout: Query timeout in seconds
+ :param change_type: A :class:`ClientRoutesChangeType` value
+ :param connection_ids: Affected connection ID strings; empty means all.
+ :param host_ids: Affected host ID strings; empty means all.
+ """
+
+ full_refresh = False
+ if not connection_ids or not host_ids:
+ log.warning(
+ "[client routes] CLIENT_ROUTES_CHANGE has no connection_ids or host_ids, doing full refresh")
+ full_refresh = True
+ elif len(connection_ids) != len(host_ids):
+ log.warning("[client routes] CLIENT_ROUTES_CHANGE has mismatched lengths (conn: %d, host: %d), doing full refresh",
+ len(connection_ids), len(host_ids))
+ full_refresh = True
+
+ if full_refresh:
+ routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids)
+ self._routes.update(routes)
+ return
+
+ host_uuids = [uuid.UUID(hid) for hid in host_ids]
+ pairs = [(cid, hid) for cid, hid in zip(connection_ids, host_uuids)
+ if cid in self._connection_ids]
+
+ if not pairs:
+ return
+
+ routes = self._query_routes_for_change_event(connection, timeout, pairs)
+ self._routes.merge(routes, affected_host_ids=set(host_uuids))
+
+ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float,
+ connection_ids: Set[str]) -> List[_Route]:
+ """
+ Query all routes for the given connection IDs (complete refresh).
+
+ Used when control connection reconnects or as a fallback when
+ CLIENT_ROUTES_CHANGE event has malformed data.
+
+ :param connection: Connection to execute query on
+ :param timeout: Query timeout in seconds
+ :param connection_ids: Set of connection ID strings
+ :return: List of _Route
+ """
+ if not connection_ids:
+ return []
+
+ placeholders = ', '.join('?' for _ in connection_ids)
+ query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE connection_id IN ({placeholders})"
+ params = [cid.encode('utf-8') for cid in connection_ids]
+
+ log.debug("[client routes] Querying all routes for connection_ids=%s", connection_ids)
+ return self._execute_routes_query(connection, timeout, query, params)
+
+ def _query_routes_for_change_event(self, connection: 'Connection', timeout: float,
+ route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]:
+ """
+ Query specific routes affected by a CLIENT_ROUTES_CHANGE event.
+
+ Takes a list of (connection_id, host_id) pairs that represent the exact
+ routes affected by an operation. This provides precise updates without
+ fetching unrelated routes.
+
+ If the pairs list is empty or None, falls back to a complete refresh
+ of all routes for safety.
+
+ :param connection: Connection to execute query on
+ :param timeout: Query timeout in seconds
+ :param route_pairs: List of (connection_id, host_id) tuples
+ :return: List of _Route
+ """
+ unique_pairs = list(dict.fromkeys(route_pairs))
+
+ conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs))
+ host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs))
+
+ log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE "
+ "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5])
+
+ conn_ph = ', '.join('?' for _ in conn_ids)
+ host_ph = ', '.join('?' for _ in host_ids)
+ query = (
+ "SELECT connection_id, host_id, address, port, tls_port "
+ "FROM system.client_routes "
+ f"WHERE connection_id IN ({conn_ph}) AND host_id IN ({host_ph})"
+ )
+ params: List = [cid.encode('utf-8') for cid in conn_ids]
+ params.extend(hid.bytes for hid in host_ids)
+
+ return self._execute_routes_query(connection, timeout, query, params)
+
+ def _execute_routes_query(self, connection: 'Connection', timeout: float,
+ query: str, params: List) -> List[_Route]:
+ """
+ Execute a routes query and parse results.
+
+ Common helper for both complete refresh and change event queries.
+
+ :param connection: Connection to execute query on
+ :param timeout: Query timeout in seconds
+ :param query: CQL query string
+ :param params: Query parameters
+ :return: List of _Route
+ """
+ log.debug("[client routes] Executing query: %s with %d parameters", query, len(params))
+
+ query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE,
+ query_params=params if params else None)
+ result = connection.wait_for_response(
+ query_msg, timeout=timeout
+ )
+
+ routes = []
+ broken = 0
+ rows = dict_factory(result.column_names, result.parsed_rows)
+ for row in rows:
+ try:
+ absent = []
+ port = row['tls_port'] if self.ssl_enabled else row['port']
+ connection_id = row['connection_id']
+ host_id = row['host_id']
+ address = row['address']
+
+ if not port:
+ absent.append("tls_port" if self.ssl_enabled else "port")
+ if not connection_id:
+ absent.append("connection_id")
+ if not host_id:
+ absent.append("host_id")
+ if not address:
+ absent.append("address")
+
+ if absent:
+ log.error("[client routes] read a route %s, that has no values for the following fields: %s", row, ",".join(absent))
+ broken += 1
+ continue
+
+ final_address = self._proxy_addresses_override.get(connection_id, address)
+
+ routes.append(_Route(
+ connection_id=connection_id,
+ host_id=host_id,
+ address=final_address,
+ port=port,
+ ))
+ except Exception as e:
+ log.warning("[client routes] Failed to parse route row: %s", e)
+ broken += 1
+
+ if broken and not routes:
+ raise RuntimeError(
+ "[client routes] All %d route rows failed validation; "
+ "refusing to return empty result that would wipe the route store" % broken
+ )
+
+ return routes
+
+ def resolve_host(self, host_id: uuid.UUID) -> Optional[Tuple[str, int]]:
+ """
+ Resolve a host_id to an (address, port) pair.
+
+ Looks up the current route and selects the appropriate port.
+
+ :param host_id: Host UUID to resolve
+ :return: Tuple of (address, port) or None if no route mapping exists
+ """
+ route = self._routes.get_by_host_id(host_id)
+ if route is None:
+ return None
+
+ if not route.port:
+ raise ValueError("Mapping for host %s has no port" % host_id)
+
+ try:
+ result = socket.getaddrinfo(route.address, route.port,
+ socket.AF_UNSPEC, socket.SOCK_STREAM)
+ if not result:
+ raise socket.gaierror("No addresses found for %s" % route.address)
+ resolved_ip = result[0][4][0]
+ return resolved_ip, route.port
+ except socket.gaierror as e:
+ log.warning('[client routes] Could not resolve hostname "%s" (host_id=%s): %s',
+ route.address, host_id, e)
+ raise
diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 679293a52d..1181c6f686 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -20,16 +20,18 @@
import atexit
import datetime
+from enum import Enum
from binascii import hexlify
from collections import defaultdict
from collections.abc import Mapping
-from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures
+from concurrent.futures import Future, ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures
from copy import copy
from functools import partial, reduce, wraps
from itertools import groupby, count, chain
+import enum
import json
import logging
-from typing import Optional
+from typing import Any, Dict, Optional, Union, Tuple
from warnings import warn
from random import random
import re
@@ -48,10 +50,11 @@
SchemaTargetType, DriverException, ProtocolVersion,
UnresolvableContactPoints, DependencyException)
from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider
-from cassandra.connection import (ConnectionException, ConnectionShutdown,
+from cassandra.client_routes import ClientRoutesChangeType, ClientRoutesConfig, _ClientRoutesHandler
+from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown,
ConnectionHeartbeat, ProtocolVersionUnsupported,
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
- ContinuousPagingState, SniEndPointFactory, ConnectionBusy)
+ SniEndPointFactory, ConnectionBusy, locally_supported_compressions)
from cassandra.cqltypes import UserType
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
@@ -75,7 +78,7 @@
NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
NeverRetryPolicy)
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
- HostConnectionPool, HostConnection,
+ HostConnection,
NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, TraceUnavailable,
@@ -95,7 +98,6 @@
GraphSON3Serializer)
from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
from cassandra.datastax import cloud as dscloud
-from cassandra.scylla.cloud import CloudConfiguration
from cassandra.application_info import ApplicationInfoBase
try:
@@ -191,16 +193,6 @@ def _connection_reduce_fn(val,import_fn):
log = logging.getLogger(__name__)
-
-DEFAULT_MIN_REQUESTS = 5
-DEFAULT_MAX_REQUESTS = 100
-
-DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
-DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8
-
-DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
-DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2
-
_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0')
_NOT_SET = object()
@@ -224,6 +216,14 @@ def __init__(self, message, errors):
self.errors = errors
+class SchemaAgreementScope(str, Enum):
+ """Scope selectors for :meth:`.Session.wait_for_schema_agreement`."""
+
+ RACK = 'rack'
+ DC = 'dc'
+ CLUSTER = 'cluster'
+
+
def _future_completed(future):
""" Helper for run_in_executor() """
exc = future.exception()
@@ -515,8 +515,9 @@ def __init__(self, load_balancing_policy=None, retry_policy=None,
class ProfileManager(object):
- def __init__(self):
+ def __init__(self, pools_allowed: bool=True):
self.profiles = dict()
+ self.pools_allowed = pools_allowed
def _profiles_without_explicit_lbps(self):
names = (profile_name for
@@ -528,6 +529,8 @@ def _profiles_without_explicit_lbps(self):
)
def distance(self, host):
+ if not self.pools_allowed:
+ return HostDistance.IGNORED
distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
HostDistance.LOCAL if HostDistance.LOCAL in distances else \
@@ -543,10 +546,14 @@ def check_supported(self):
p.load_balancing_policy.check_supported()
def on_up(self, host):
+ if not self.pools_allowed:
+ return
for p in self.profiles.values():
p.load_balancing_policy.on_up(host)
def on_down(self, host):
+ if not self.pools_allowed:
+ return
for p in self.profiles.values():
p.load_balancing_policy.on_down(host)
@@ -620,6 +627,31 @@ class _ConfigMode(object):
PROFILES = 2
+class ControlConnectionQueryFallback(enum.Enum):
+ """
+ Controls how application queries use the control connection when node pools
+ are unavailable.
+
+ ``Disabled`` requires a usable node pool for application queries. If the
+ driver cannot establish one during session startup, it raises
+ :class:`NoHostAvailable`.
+
+ ``Fallback`` still attempts to create node pools, but allows application
+ queries to fall back to the control connection when no usable node pool is
+ available. Session startup is allowed to proceed even if the initial pool
+ attempts all fail.
+
+ ``SkipPoolCreation`` disables node-pool creation for the session and uses
+ the control-connection fallback path for application queries.
+
+ The fallback path is not used for requests targeted to an explicit host.
+ """
+
+ Disabled = "Disabled"
+ Fallback = "Fallback"
+ SkipPoolCreation = "SkipPoolCreation"
+
+
class Cluster(object):
"""
The main class to use when interacting with a Cassandra cluster.
@@ -672,7 +704,7 @@ class Cluster(object):
server will be automatically used.
"""
- protocol_version = ProtocolVersion.DSE_V2
+ protocol_version = ProtocolVersion.V5
"""
The maximum version of the native protocol to use.
@@ -680,7 +712,7 @@ class Cluster(object):
If not set in the constructor, the driver will automatically downgrade
version based on a negotiation with the server, but it is most efficient
- to set this to the maximum supported by your version of Cassandra.
+ to set this to the maximum supported by your version of ScyllaDB.
Setting this will also prevent conflicting versions negotiated if your
cluster is upgraded.
@@ -695,7 +727,7 @@ class Cluster(object):
Used for testing new protocol features incrementally before the new version is complete.
"""
- compression = True
+ compression: Union[bool, str, None] = True
"""
Controls compression for communications between the driver and Cassandra.
If left as the default of :const:`True`, either lz4 or snappy compression
@@ -705,7 +737,7 @@ class Cluster(object):
You may also set this to 'snappy' or 'lz4' to request that specific
compression type.
- Setting this to :const:`False` disables compression.
+ Setting this to :const:`False` or :const:`None` disables compression.
"""
_application_info: Optional[ApplicationInfoBase] = None
@@ -731,9 +763,6 @@ def auth_provider(self):
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
such as :class:`~.PlainTextAuthProvider`.
- When :attr:`~.Cluster.protocol_version` is 1, this should be
- a function that accepts one argument, the IP address of a node,
- and returns a dict of credentials for that node.
When not using authentication, this should be left as :const:`None`.
"""
@@ -851,8 +880,8 @@ def default_retry_policy(self, policy):
Using ssl_options without ssl_context is deprecated and will be removed in the
next major release.
- An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket``
- when new sockets are created. This should be used when client encryption is enabled
+ An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket``
+ when new sockets are created. This should be used when client encryption is enabled
in Cassandra.
The following documentation only applies when ssl_options is used without ssl_context.
@@ -943,6 +972,16 @@ def default_retry_policy(self, policy):
If set to :const:`None`, there will be no timeout for these queries.
"""
+ allow_control_connection_query_fallback: ControlConnectionQueryFallback = ControlConnectionQueryFallback.Disabled
+ """
+ Controls whether application queries may fall back to the control connection.
+
+ ``Disabled`` keeps the old behavior.
+ ``Fallback`` enables control-connection fallback when no usable node pools exist.
+ ``SkipPoolCreation`` skips node-pool creation and uses the control connection fallback path.
+ This fallback is still not used for requests targeted to an explicit host.
+ """
+
idle_heartbeat_interval = 30
"""
Interval, in seconds, on which to heartbeat idle connections. This helps
@@ -1039,7 +1078,7 @@ def default_retry_policy(self, policy):
documentation for :meth:`Session.timestamp_generator`.
"""
- monitor_reporting_enabled = True
+ monitor_reporting_enabled = False
"""
A boolean indicating if monitor reporting, which sends gathered data to
Insights when running against DSE 6.8 and higher.
@@ -1095,10 +1134,19 @@ def default_retry_policy(self, policy):
used for columns in this cluster.
"""
- metadata_request_timeout = datetime.timedelta(seconds=2)
+ metadata_request_timeout: Optional[float] = None
"""
- Timeout for all queries used by driver it self.
- Supported only by Scylla clusters.
+ Specifies a server-side timeout (in seconds) for all internal driver queries,
+ such as schema metadata lookups and cluster topology requests.
+
+ The timeout is enforced by appending `USING TIMEOUT ` to queries
+ executed by the driver.
+
+ - A value of `0` disables explicit timeout enforcement. In this case,
+ the driver does not add `USING TIMEOUT`, and the timeout is determined
+ by the server's defaults.
+ - Only supported when connected to Scylla clusters.
+ - If not explicitly set, defaults to the value of `control_connection_timeout`.
"""
@property
@@ -1176,7 +1224,7 @@ def token_metadata_enabled(self, enabled):
def __init__(self,
contact_points=_NOT_SET,
port=9042,
- compression=True,
+ compression: Union[bool, str, None] = True,
auth_provider=None,
load_balancing_policy=None,
reconnection_policy=None,
@@ -1217,9 +1265,11 @@ def __init__(self,
cloud=None,
scylla_cloud=None,
shard_aware_options=None,
- metadata_request_timeout=None,
+ metadata_request_timeout: Optional[float] = None,
column_encryption_policy=None,
- application_info:Optional[ApplicationInfoBase]=None
+ application_info:Optional[ApplicationInfoBase]=None,
+ client_routes_config:Optional[ClientRoutesConfig]=None,
+ allow_control_connection_query_fallback:Optional[ControlConnectionQueryFallback]=ControlConnectionQueryFallback.Disabled
):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1237,27 +1287,15 @@ def __init__(self,
if port < 1 or port > 65535:
raise ValueError("Invalid port number (%s) (1-65535)" % port)
+ if not isinstance(allow_control_connection_query_fallback, ControlConnectionQueryFallback):
+ raise TypeError(
+ "allow_control_connection_query_fallback must be a ControlConnectionQueryFallback value")
+
if connection_class is not None:
self.connection_class = connection_class
if scylla_cloud is not None:
- if contact_points is not _NOT_SET or ssl_context or ssl_options:
- raise ValueError("contact_points, ssl_context, and ssl_options "
- "cannot be specified with a scylla cloud configuration")
- if shard_aware_options and not shard_aware_options.disable_shardaware_port:
- raise ValueError("shard_aware_options.disable_shardaware_port=False "
- "cannot be specified with a scylla cloud configuration")
- uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
- uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)
-
- scylla_cloud_config = CloudConfiguration.create(scylla_cloud, pyopenssl=uses_twisted or uses_eventlet,
- endpoint_factory=endpoint_factory)
- ssl_context = scylla_cloud_config.ssl_context
- endpoint_factory = scylla_cloud_config.endpoint_factory
- contact_points = scylla_cloud_config.contact_points
- ssl_options = scylla_cloud_config.ssl_options
- auth_provider = scylla_cloud_config.auth_provider
- shard_aware_options = ShardAwareOptions(shard_aware_options, disable_shardaware_port=True)
+ raise NotImplementedError("scylla_cloud was removed and not supported anymore")
if cloud is not None:
self.cloud = cloud
@@ -1300,11 +1338,69 @@ def __init__(self,
if column_encryption_policy is not None:
self.column_encryption_policy = column_encryption_policy
+ if client_routes_config is not None and endpoint_factory is not None:
+ raise ValueError("client_routes_config and endpoint_factory are mutually exclusive")
+
+ self._client_routes_handler = None
+ if client_routes_config is not None:
+ if not isinstance(client_routes_config, ClientRoutesConfig):
+ raise TypeError("client_routes_config must be a ClientRoutesConfig instance")
+
+ # SSL hostname verification is incompatible with client routes:
+ # connections go through NLB proxies whose addresses won't match
+ # server certificates.
+ _check_hostname_enabled = False
+ if ssl_context is not None and ssl_context.check_hostname:
+ _check_hostname_enabled = True
+ if ssl_options is not None and ssl_options.get('check_hostname', False):
+ _check_hostname_enabled = True
+ if _check_hostname_enabled:
+ raise ValueError(
+ "SSL hostname verification (check_hostname=True) is currently incompatible "
+ "with client_routes_config. When using client routes, connections "
+ "go through NLB proxies whose addresses won't match server "
+ "certificates. Disable hostname verification by setting "
+ "ssl_context.check_hostname = False."
+ )
+
+ ssl_enabled = ssl_context is not None or ssl_options is not None
+ self._client_routes_handler = _ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled)
+
+ if contact_points is _NOT_SET or not self._contact_points_explicit:
+ seed_addrs = [dep.connection_addr_override for dep in client_routes_config.proxies
+ if dep.connection_addr_override]
+ if seed_addrs:
+ self.contact_points = seed_addrs
+ self._contact_points_explicit = True
+ log.info("[client routes] Using %d deployment connection addresses as contact points",
+ len(seed_addrs))
+
+ if self._client_routes_handler is not None:
+ endpoint_factory = ClientRoutesEndPointFactory(self._client_routes_handler, self.port)
self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port)
self.endpoint_factory.configure(self)
self._resolve_hostnames()
+ if isinstance(compression, bool) or compression is None:
+ compression = bool(compression)
+ if compression and not locally_supported_compressions:
+ log.error(
+ "Compression is enabled, but no compression libraries are available. "
+ "Disabling compression, consider installing one of the Python packages: lz4 and/or python-snappy."
+ )
+ compression = False
+ elif isinstance(compression, str):
+ if not locally_supported_compressions.get(compression):
+ raise ValueError(
+ "Compression '%s' was requested, but it is not available. "
+ "Consider installing the corresponding Python package." % compression
+ )
+ else:
+ raise TypeError(
+ "The 'compression' option must be either a string (e.g., 'lz4' or 'snappy') "
+ "or a boolean (True to enable any available compression, False to disable it)."
+ )
self.compression = compression
if protocol_version is not _NOT_SET:
@@ -1315,8 +1411,6 @@ def __init__(self,
self.no_compact = no_compact
self.auth_provider = auth_provider
- if metadata_request_timeout is not None:
- self.metadata_request_timeout = metadata_request_timeout
if load_balancing_policy is not None:
if isinstance(load_balancing_policy, type):
@@ -1358,7 +1452,8 @@ def __init__(self,
else:
self.timestamp_generator = MonotonicTimestampGenerator()
- self.profile_manager = ProfileManager()
+ self.profile_manager = ProfileManager(
+ pools_allowed=allow_control_connection_query_fallback != ControlConnectionQueryFallback.SkipPoolCreation)
self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(
self.load_balancing_policy,
self.default_retry_policy,
@@ -1427,6 +1522,8 @@ def __init__(self,
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
self.control_connection_timeout = control_connection_timeout
+ self.allow_control_connection_query_fallback = allow_control_connection_query_fallback
+ self.metadata_request_timeout = self.control_connection_timeout if metadata_request_timeout is None else metadata_request_timeout
self.idle_heartbeat_interval = idle_heartbeat_interval
self.idle_heartbeat_timeout = idle_heartbeat_timeout
self.schema_event_refresh_window = schema_event_refresh_window
@@ -1439,6 +1536,10 @@ def __init__(self,
self.monitor_reporting_interval = monitor_reporting_interval
self.shard_aware_options = ShardAwareOptions(opts=shard_aware_options)
+ if (client_routes_config is not None
+ and not client_routes_config.advanced_shard_awareness):
+ self.shard_aware_options.disable_shardaware_port = True
+
self._listeners = set()
self._listener_lock = Lock()
@@ -1452,30 +1553,6 @@ def __init__(self,
self._user_types = defaultdict(dict)
- self._min_requests_per_connection = {
- HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
- HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
- HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
- }
-
- self._max_requests_per_connection = {
- HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
- HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
- HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
- }
-
- self._core_connections_per_host = {
- HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
- HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
- HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
- }
-
- self._max_connections_per_host = {
- HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
- HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
- HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
- }
-
self.executor = self._create_thread_pool_executor(max_workers=executor_threads)
self.scheduler = _Scheduler(self.executor)
@@ -1664,116 +1741,8 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5):
futures.update(session.update_created_pools())
_, not_done = wait_futures(futures, pool_wait_timeout)
if not_done:
- raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout.")
-
- def get_min_requests_per_connection(self, host_distance):
- return self._min_requests_per_connection[host_distance]
-
- def set_min_requests_per_connection(self, host_distance, min_requests):
- """
- Sets a threshold for concurrent requests per connection, below which
- connections will be considered for disposal (down to core connections;
- see :meth:`~Cluster.set_core_connections_per_host`).
-
- Pertains to connection pool management in protocol versions {1,2}.
- """
- if self.protocol_version >= 3:
- raise UnsupportedOperation(
- "Cluster.set_min_requests_per_connection() only has an effect "
- "when using protocol_version 1 or 2.")
- if min_requests < 0 or min_requests > 126 or \
- min_requests >= self._max_requests_per_connection[host_distance]:
- raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" %
- (self._min_requests_per_connection[host_distance],))
- self._min_requests_per_connection[host_distance] = min_requests
-
- def get_max_requests_per_connection(self, host_distance):
- return self._max_requests_per_connection[host_distance]
-
- def set_max_requests_per_connection(self, host_distance, max_requests):
- """
- Sets a threshold for concurrent requests per connection, above which new
- connections will be created to a host (up to max connections;
- see :meth:`~Cluster.set_max_connections_per_host`).
-
- Pertains to connection pool management in protocol versions {1,2}.
- """
- if self.protocol_version >= 3:
- raise UnsupportedOperation(
- "Cluster.set_max_requests_per_connection() only has an effect "
- "when using protocol_version 1 or 2.")
- if max_requests < 1 or max_requests > 127 or \
- max_requests <= self._min_requests_per_connection[host_distance]:
- raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" %
- (self._min_requests_per_connection[host_distance],))
- self._max_requests_per_connection[host_distance] = max_requests
-
- def get_core_connections_per_host(self, host_distance):
- """
- Gets the minimum number of connections per Session that will be opened
- for each host with :class:`~.HostDistance` equal to `host_distance`.
- The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
- :attr:`~HostDistance.REMOTE`.
-
- This property is ignored if :attr:`~.Cluster.protocol_version` is
- 3 or higher.
- """
- return self._core_connections_per_host[host_distance]
-
- def set_core_connections_per_host(self, host_distance, core_connections):
- """
- Sets the minimum number of connections per Session that will be opened
- for each host with :class:`~.HostDistance` equal to `host_distance`.
- The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
- :attr:`~HostDistance.REMOTE`.
-
- Protocol version 1 and 2 are limited in the number of concurrent
- requests they can send per connection. The driver implements connection
- pooling to support higher levels of concurrency.
-
- If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
- is not supported (there is always one connection per host, unless
- the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupportedOperation`.
- """
- if self.protocol_version >= 3:
- raise UnsupportedOperation(
- "Cluster.set_core_connections_per_host() only has an effect "
- "when using protocol_version 1 or 2.")
- old = self._core_connections_per_host[host_distance]
- self._core_connections_per_host[host_distance] = core_connections
- if old < core_connections:
- self._ensure_core_connections()
-
- def get_max_connections_per_host(self, host_distance):
- """
- Gets the maximum number of connections per Session that will be opened
- for each host with :class:`~.HostDistance` equal to `host_distance`.
- The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for
- :attr:`~HostDistance.REMOTE`.
-
- This property is ignored if :attr:`~.Cluster.protocol_version` is
- 3 or higher.
- """
- return self._max_connections_per_host[host_distance]
-
- def set_max_connections_per_host(self, host_distance, max_connections):
- """
- Sets the maximum number of connections per Session that will be opened
- for each host with :class:`~.HostDistance` equal to `host_distance`.
- The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
- :attr:`~HostDistance.REMOTE`.
-
- If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
- is not supported (there is always one connection per host, unless
- the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupportedOperation`.
- """
- if self.protocol_version >= 3:
- raise UnsupportedOperation(
- "Cluster.set_max_connections_per_host() only has an effect "
- "when using protocol_version 1 or 2.")
- self._max_connections_per_host[host_distance] = max_connections
+ raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout." % pool_wait_timeout,
+ timeout=pool_wait_timeout)
def connection_factory(self, endpoint, host_conn = None, *args, **kwargs):
"""
@@ -1818,14 +1787,7 @@ def protocol_downgrade(self, host_endpoint, previous_version):
"http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint)
self.protocol_version = new_version
- def _add_resolved_hosts(self):
- for endpoint in self.endpoints_resolved:
- host, new = self.add_host(endpoint, signal=False)
- if new:
- host.set_up()
- for listener in self.listeners:
- listener.on_add(host)
-
+ def _populate_hosts(self):
self.profile_manager.populate(
weakref.proxy(self), self.metadata.all_hosts())
self.load_balancing_policy.populate(
@@ -1852,17 +1814,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
_register_cluster_shutdown(self)
-
- self._add_resolved_hosts()
try:
self.control_connection.connect()
-
- # we set all contact points up for connecting, but we won't infer state after this
- for endpoint in self.endpoints_resolved:
- h = self.metadata.get_host(endpoint)
- if h and self.profile_manager.distance(h) == HostDistance.IGNORED:
- h.is_up = None
+ self._populate_hosts()
log.debug("Control connection created")
except Exception:
@@ -1871,14 +1826,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.shutdown()
raise
- # Update the information about tablet support after connection handshake.
- self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
- child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None
- while child_policy is not None:
- if hasattr(child_policy, '_tablet_routing_v1'):
- child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1
- child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None
-
self.profile_manager.check_supported() # todo: rename this method
if self.idle_heartbeat_interval:
@@ -1918,7 +1865,8 @@ def get_all_pools(self):
return pools
def is_shard_aware(self):
- return bool(self.get_all_pools()[0].host.sharding_info)
+ pools = self.get_all_pools()
+ return bool(pools and pools[0].host.sharding_info)
def shard_aware_stats(self):
if self.is_shard_aware():
@@ -1952,6 +1900,9 @@ def shutdown(self):
self.executor.shutdown()
+ if self.metrics_enabled and self.metrics:
+ self.metrics.shutdown()
+
_discard_cluster_shutdown(self)
def __enter__(self):
@@ -2020,7 +1971,7 @@ def on_up(self, host):
"""
Intended for internal use only.
"""
- if self.is_shutdown:
+ if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation:
return
log.debug("Waiting to acquire lock for handling up status of node %s", host)
@@ -2128,7 +2079,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
"""
Intended for internal use only.
"""
- if self.is_shutdown:
+ if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation:
return
with host.lock:
@@ -2159,14 +2110,14 @@ def on_add(self, host, refresh_nodes=True):
log.debug("Handling new host %r and notifying listeners", host)
+ self.profile_manager.on_add(host)
+ self.control_connection.on_add(host, refresh_nodes)
+
distance = self.profile_manager.distance(host)
if distance != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing queries for new host %r", host)
- self.profile_manager.on_add(host)
- self.control_connection.on_add(host, refresh_nodes)
-
if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
"load balancing policy has marked it as IGNORED", host)
@@ -2733,23 +2684,26 @@ def __init__(self, cluster, hosts, keyspace=None):
# create connection pools in parallel
self._initial_connect_futures = set()
- for host in hosts:
- future = self.add_or_renew_pool(host, is_host_addition=False)
- if future:
- self._initial_connect_futures.add(future)
-
- futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED)
- while futures.not_done and not any(f.result() for f in futures.done):
- futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED)
-
- if not any(f.result() for f in self._initial_connect_futures):
- msg = "Unable to connect to any servers"
- if self.keyspace:
- msg += " using keyspace '%s'" % self.keyspace
- raise NoHostAvailable(msg, [h.address for h in hosts])
+ fallback_mode = self.cluster.allow_control_connection_query_fallback
+ if fallback_mode is not ControlConnectionQueryFallback.SkipPoolCreation:
+ for host in hosts:
+ future = self.add_or_renew_pool(host, is_host_addition=False)
+ if future:
+ self._initial_connect_futures.add(future)
+
+ futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED)
+ while futures.not_done and not any(f.result() for f in futures.done):
+ futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED)
+
+ # Only Disabled requires an initial pool to come up.
+ if not any(f.result() for f in self._initial_connect_futures) and \
+ fallback_mode is ControlConnectionQueryFallback.Disabled:
+ msg = "Unable to connect to any servers"
+ if self.keyspace:
+ msg += " using keyspace '%s'" % self.keyspace
+ raise NoHostAvailable(msg, [h.address for h in hosts])
self.session_id = uuid.uuid4()
- self._graph_paging_available = self._check_graph_paging_available()
if self.cluster.column_encryption_policy is not None:
try:
@@ -2946,26 +2900,10 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
def _maybe_set_graph_paging(self, execution_profile):
graph_paging = execution_profile.continuous_paging_options
if execution_profile.continuous_paging_options is _NOT_SET:
- graph_paging = ContinuousPagingOptions() if self._graph_paging_available else None
+ graph_paging = None
execution_profile.continuous_paging_options = graph_paging
- def _check_graph_paging_available(self):
- """Verify if we can enable graph paging. This executed only once when the session is created."""
-
- if not ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version):
- return False
-
- for host in self.cluster.metadata.all_hosts():
- if host.dse_version is None:
- return False
-
- version = Version(host.dse_version)
- if version < _GRAPH_PAGING_MIN_DSE_VERSION:
- return False
-
- return True
-
def _resolve_execution_profile_options(self, execution_profile):
"""
Determine the GraphSON protocol and row factory for a graph query. This is useful
@@ -3101,25 +3039,15 @@ def _create_response_future(self, query, parameters, trace, custom_payload,
spec_exec_policy = execution_profile.speculative_execution_policy
fetch_size = query.fetch_size
- if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
+ if fetch_size is FETCH_SIZE_UNSET:
fetch_size = self.default_fetch_size
- elif self._protocol_version == 1:
- fetch_size = None
start_time = time.time()
- if self._protocol_version >= 3 and self.use_client_timestamp:
+ if self.use_client_timestamp:
timestamp = self.cluster.timestamp_generator()
else:
timestamp = None
- supports_continuous_paging_state = (
- ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version)
- )
- if continuous_paging_options and supports_continuous_paging_state:
- continuous_paging_state = ContinuousPagingState(continuous_paging_options.max_queue_size)
- else:
- continuous_paging_state = None
-
if isinstance(query, SimpleStatement):
query_string = query.query_string
statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None
@@ -3163,7 +3091,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload,
self, message, query, timeout, metrics=self._metrics,
prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory,
load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan,
- continuous_paging_state=continuous_paging_state, host=host)
+ continuous_paging_state=None, host=host)
def get_execution_profile(self, name):
"""
@@ -3281,7 +3209,7 @@ def prepare(self, query, custom_payload=None, keyspace=None):
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
- self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
+ self._protocol_version, response.column_metadata, response.result_metadata_id, response.is_lwt, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload
self.cluster.add_prepared(response.query_id, prepared_statement)
@@ -3372,17 +3300,16 @@ def add_or_renew_pool(self, host, is_host_addition):
"""
For internal use only.
"""
+ if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation:
+ return None
+
distance = self._profile_manager.distance(host)
if distance == HostDistance.IGNORED:
return None
def run_add_or_renew_pool():
try:
- if self._protocol_version >= 3:
- new_pool = HostConnection(host, distance, self)
- else:
- # TODO remove host pool again ???
- new_pool = HostConnectionPool(host, distance, self)
+ new_pool = HostConnection(host, distance, self)
except AuthenticationFailed as auth_exc:
conn_exc = ConnectionException(str(auth_exc), endpoint=host)
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
@@ -3446,6 +3373,9 @@ def update_created_pools(self):
For internal use only.
"""
+ if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation:
+ return set()
+
futures = set()
for host in self.cluster.metadata.all_hosts():
distance = self._profile_manager.distance(host)
@@ -3514,6 +3444,185 @@ def pool_finished_setting_keyspace(pool, host_errors):
for pool in tuple(self._pools.values()):
pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace)
+ def wait_for_schema_agreement(self, wait_time: Optional[float] = None,
+ scope: SchemaAgreementScope = SchemaAgreementScope.CLUSTER) -> bool:
+ """
+ Wait for connected hosts in the selected scope to report the same
+ schema version from ``system.local``.
+
+ By default, the timeout for this operation is governed by
+ :attr:`~.Cluster.max_schema_agreement_wait` and
+ :attr:`~.Cluster.control_connection_timeout`.
+
+ Passing ``wait_time`` here overrides
+ :attr:`~.Cluster.max_schema_agreement_wait`. If provided, ``wait_time``
+ must be greater than 0.
+
+ ``scope`` determines which connected hosts participate in the check.
+ Pass :attr:`SchemaAgreementScope.RACK`, :attr:`SchemaAgreementScope.DC`,
+ or :attr:`SchemaAgreementScope.CLUSTER`.
+ The default is :attr:`SchemaAgreementScope.CLUSTER`. ``RACK`` narrows
+ the check to connected hosts in the local rack only. ``DC`` checks
+ connected hosts in the local datacenter. ``CLUSTER`` queries every
+ connected host across all datacenters.
+
+ :param wait_time: Override for
+ :attr:`~.Cluster.max_schema_agreement_wait`, should be positive
+ number.
+ :param scope: Restricts the check to connected hosts in the local rack,
+ local datacenter, or whole connected cluster.
+ :returns: ``True`` when the selected connected hosts agree on schema,
+ otherwise ``False``.
+ :raises ValueError: If ``wait_time`` is provided and is not greater
+ than 0.
+ :raises ValueError: If ``scope`` is not one of the schema agreement
+ scope values.
+ """
+
+ if wait_time is not None and wait_time <= 0:
+ raise ValueError("wait_time must be greater than 0")
+
+ total_timeout = wait_time if wait_time is not None else self.cluster.max_schema_agreement_wait
+ if total_timeout <= 0:
+ raise ValueError("total_timeout must be greater than 0")
+
+ deadline = time.time() + total_timeout
+ schema_mismatches = None
+ scope_label = 'local rack' if scope is SchemaAgreementScope.RACK else (
+ 'local datacenter' if scope is SchemaAgreementScope.DC else 'cluster')
+
+ while time.time() < deadline:
+ schema_mismatches = self._get_schema_mismatches_for_scope(deadline, scope)
+ if schema_mismatches is None:
+ return True
+
+ log.debug("[session] Connected hosts in the %s still disagree on schema, trying again", scope_label)
+ remaining = deadline - time.time()
+ if remaining > 0:
+ time.sleep(min(0.2, remaining))
+
+ log.warning("[session] Connected hosts in the %s are reporting a schema disagreement: %s",
+ scope_label, schema_mismatches)
+ return False
+
+ def _get_schema_mismatches_for_scope(self, deadline: float,
+ scope: SchemaAgreementScope) -> Optional[Dict[Any, Any]]:
+ hosts = self._get_schema_agreement_hosts(scope)
+ mismatches = defaultdict(list)
+ errors = {}
+ scope_label = 'local rack' if scope is SchemaAgreementScope.RACK else (
+ 'local datacenter' if scope is SchemaAgreementScope.DC else 'cluster')
+
+ if not hosts:
+ errors[scope.value] = ConnectionException(
+ "No connected hosts available in the %s" % (scope_label,)
+ )
+ return {'unavailable': errors}
+
+ metadata_request_timeout = self.cluster.control_connection._metadata_request_timeout
+ query = maybe_add_timeout_to_query(ControlConnection._SELECT_SCHEMA_LOCAL, metadata_request_timeout)
+
+ schema_version_futures = []
+ for host in hosts:
+ try:
+ schema_version_future = self._query_local_schema_version(host, query, deadline)
+ except Exception as exc:
+ errors[host.endpoint] = exc
+ continue
+
+ schema_version_futures.append((host, schema_version_future))
+
+ if schema_version_futures:
+ # Start all host queries first, then wait for the whole batch.
+ remaining = max(0.0, deadline - time.time())
+ if remaining > 0:
+ wait_futures([future for _, future in schema_version_futures], timeout=remaining)
+
+ for host, future in schema_version_futures:
+ if future.done():
+ try:
+ rows = future.result()
+ except Exception as exc:
+ errors[host.endpoint] = exc
+ continue
+
+ row = rows.one()
+ schema_version = getattr(row, "schema_version", None) if row is not None else None
+ mismatches[schema_version].append(host.endpoint)
+ else:
+ errors[host.endpoint] = OperationTimedOut(last_host=host, timeout=max(0.0, deadline - time.time()))
+
+ if len(mismatches) == 1 and None not in mismatches and not errors:
+ log.debug("[session] Connected hosts in the %s agree on schema", scope_label)
+ return None
+
+ if errors:
+ mismatches['unavailable'] = errors
+ return dict(mismatches)
+
+ def _get_schema_agreement_hosts(self, scope: SchemaAgreementScope) -> Tuple[Host, ...]:
+ if scope is SchemaAgreementScope.RACK:
+ allowed_distances = (HostDistance.LOCAL_RACK,)
+ elif scope is SchemaAgreementScope.DC:
+ allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL)
+ else:
+ allowed_distances = (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE)
+
+ return tuple(
+ host for host, pool in tuple(self._pools.items())
+ if host.is_up
+ and not pool.is_shutdown
+ and self._profile_manager.distance(host) in allowed_distances)
+
+ def _query_local_schema_version(self, host: Host, query: str, deadline: float) -> Future:
+ remaining = max(0.0, deadline - time.time())
+ try:
+ response_future = self.execute_async(
+ query,
+ timeout=self._schema_agreement_query_timeout(remaining),
+ host=host,
+ )
+ except OperationTimedOut as timeout:
+ log.debug("[session] Timed out waiting for schema version from %s: %s", host, timeout)
+ raise
+ except Exception as exc:
+ log.debug("[session] Error querying schema version from %s: %s", host, exc)
+ raise
+
+ # execute_async returns cassandra.cluster.ResponseFuture, which does not have bulk waiting logic for it.
+ # That is why _query_local_schema_version returns concurrent.futures.Future
+ # so that schema agreement logic could use concurrent.futures.wait_futures to wait on them.
+ # schema_version_future is an adapter between cassandra.cluster.ResponseFuture and concurrent.futures.Future
+ # to make things work
+ schema_version_future = Future()
+
+ def _set_result(result, result_future=schema_version_future, response_future=response_future):
+ if result_future.done():
+ return
+ try:
+ result_future.set_result(ResultSet(response_future, result))
+ except Exception as exc:
+ result_future.set_exception(exc)
+
+ def _set_exception(exc, result_future=schema_version_future):
+ if result_future.done():
+ return
+ result_future.set_exception(exc)
+
+ try:
+ response_future.add_callbacks(_set_result, _set_exception)
+ except Exception as exc:
+ log.debug("[session] Error registering schema version callback from %s: %s", host, exc)
+ raise
+
+ return schema_version_future
+
+ def _schema_agreement_query_timeout(self, remaining: float) -> float:
+ control_timeout = self.cluster.control_connection._timeout
+ if control_timeout is None:
+ return max(0.0, remaining)
+ return max(0.0, min(control_timeout, remaining))
+
def user_type_registered(self, keyspace, user_type, klass):
"""
Called by the parent Cluster instance when the user registers a new
@@ -3620,9 +3729,9 @@ class ControlConnection(object):
Internal
"""
- _SELECT_PEERS = "SELECT * FROM system.peers"
+ _SELECT_PEERS = "SELECT peer, data_center, host_id, rack, release_version, rpc_address, schema_version, tokens FROM system.peers"
_SELECT_PEERS_NO_TOKENS_TEMPLATE = "SELECT host_id, peer, data_center, rack, rpc_address, {nt_col_name}, release_version, schema_version FROM system.peers"
- _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'"
+ _SELECT_LOCAL = "SELECT broadcast_address, cluster_name, data_center, host_id, listen_address, partitioner, rack, release_version, rpc_address, schema_version, tokens FROM system.local WHERE key='local'"
_SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version, rpc_address FROM system.local WHERE key='local'"
# Used only when token_metadata_enabled is set to False
_SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'"
@@ -3708,28 +3817,22 @@ def _set_new_connection(self, conn):
if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
old.close()
-
- def _connect_host_in_lbp(self):
+
+ def _try_connect_to_hosts(self):
errors = {}
- lbp = (
- self._cluster.load_balancing_policy
- if self._cluster._config_mode == _ConfigMode.LEGACY else
- self._cluster._default_load_balancing_policy
- )
- for host in lbp.make_query_plan():
+ lbp = self._cluster.load_balancing_policy \
+ if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy
+
+ for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved):
try:
- return (self._try_connect(host), None)
- except ConnectionException as exc:
- errors[str(host.endpoint)] = exc
- log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
- self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
+ return (self._try_connect(endpoint), None)
except Exception as exc:
- errors[str(host.endpoint)] = exc
- log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
+ errors[str(endpoint)] = exc
+ log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")
-
+
return (None, errors)
def _reconnect_internal(self):
@@ -3741,43 +3844,43 @@ def _reconnect_internal(self):
to the exception that was raised when an attempt was made to open
a connection to that host.
"""
- (conn, _) = self._connect_host_in_lbp()
+ (conn, _) = self._try_connect_to_hosts()
if conn is not None:
return conn
# Try to re-resolve hostnames as a fallback when all hosts are unreachable
self._cluster._resolve_hostnames()
- self._cluster._add_resolved_hosts()
+ self._cluster._populate_hosts()
- (conn, errors) = self._connect_host_in_lbp()
+ (conn, errors) = self._try_connect_to_hosts()
if conn is not None:
return conn
-
+
raise NoHostAvailable("Unable to connect to any servers", errors)
- def _try_connect(self, host):
+ def _try_connect(self, endpoint):
"""
Creates a new Connection, registers for pushed events, and refreshes
node/token and schema metadata.
"""
- log.debug("[control connection] Opening new connection to %s", host)
+ log.debug("[control connection] Opening new connection to %s", endpoint)
while True:
try:
- connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
+ connection = self._cluster.connection_factory(endpoint, is_control_connection=True)
if self._is_shutdown:
connection.close()
raise DriverException("Reconnecting during shutdown")
break
except ProtocolVersionUnsupported as e:
- self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
+ self._cluster.protocol_downgrade(endpoint, e.startup_version)
except ProtocolException as e:
# protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
# protocol version. If the protocol version was not explicitly specified,
# and that the server raises a beta protocol error, we should downgrade.
if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
- self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
+ self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version)
else:
raise
@@ -3790,9 +3893,11 @@ def _try_connect(self, host):
if connection.features.sharding_info is not None:
self._uses_peers_v2 = False
- # Cassandra does not support "USING TIMEOUT"
- self._metadata_request_timeout = None if connection.features.sharding_info is None \
- else datetime.timedelta(seconds=self._cluster.control_connection_timeout)
+ # Only ScyllaDB supports "USING TIMEOUT"
+ # Sharding information signals it is ScyllaDB
+ self._metadata_request_timeout = None if connection.features.sharding_info is None or not self._cluster.metadata_request_timeout \
+ else datetime.timedelta(seconds=self._cluster.metadata_request_timeout)
+
self._tablets_routing_v1 = connection.features.tablets_routing_v1
# use weak references in both directions
@@ -3801,11 +3906,21 @@ def _try_connect(self, host):
# this object (after a dereferencing a weakref)
self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection)))
try:
- connection.register_watchers({
+ watchers = {
"TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'),
"STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'),
"SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change')
- }, register_timeout=self._timeout)
+ }
+
+ if self._cluster._client_routes_handler is not None:
+ watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change')
+
+ connection.register_watchers(watchers, register_timeout=self._timeout)
+
+ if self._cluster._client_routes_handler is not None:
+ self._cluster._client_routes_handler.initialize(
+ connection,
+ self._timeout)
sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
@@ -3920,7 +4035,7 @@ def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_w
if self._cluster.is_shutdown:
return False
- agreed = self.wait_for_schema_agreement(connection,
+ agreed = self._wait_for_schema_agreement(connection=connection,
preloaded_results=preloaded_results,
wait_time=schema_agreement_wait)
@@ -3990,67 +4105,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
self._cluster.metadata.cluster_name = cluster_name
partitioner = local_row.get("partitioner")
- tokens = local_row.get("tokens")
-
- host = self._cluster.metadata.get_host(connection.original_endpoint)
- if host:
- datacenter = local_row.get("data_center")
- rack = local_row.get("rack")
- self._update_location_info(host, datacenter, rack)
-
- # support the use case of connecting only with public address
- if isinstance(self._cluster.endpoint_factory, SniEndPointFactory):
- new_endpoint = self._cluster.endpoint_factory.create(local_row)
-
- if new_endpoint.address:
- host.endpoint = new_endpoint
-
- host.host_id = local_row.get("host_id")
-
- found_host_ids.add(host.host_id)
- found_endpoints.add(host.endpoint)
-
- host.listen_address = local_row.get("listen_address")
- host.listen_port = local_row.get("listen_port")
- host.broadcast_address = _NodeInfo.get_broadcast_address(local_row)
- host.broadcast_port = _NodeInfo.get_broadcast_port(local_row)
-
- host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row)
- host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row)
- if host.broadcast_rpc_address is None:
- if self._token_meta_enabled:
- # local rpc_address is not available, use the connection endpoint
- host.broadcast_rpc_address = connection.endpoint.address
- host.broadcast_rpc_port = connection.endpoint.port
- else:
- # local rpc_address has not been queried yet, try to fetch it
- # separately, which might fail because C* < 2.1.6 doesn't have rpc_address
- # in system.local. See CASSANDRA-9436.
- local_rpc_address_query = QueryMessage(
- query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout),
- consistency_level=ConsistencyLevel.ONE)
- success, local_rpc_address_result = connection.wait_for_response(
- local_rpc_address_query, timeout=self._timeout, fail_on_error=False)
- if success:
- row = dict_factory(
- local_rpc_address_result.column_names,
- local_rpc_address_result.parsed_rows)
- host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0])
- host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0])
- else:
- host.broadcast_rpc_address = connection.endpoint.address
- host.broadcast_rpc_port = connection.endpoint.port
+ tokens = local_row.get("tokens", None)
- host.release_version = local_row.get("release_version")
- host.dse_version = local_row.get("dse_version")
- host.dse_workload = local_row.get("workload")
- host.dse_workloads = local_row.get("workloads")
+ peers_result.insert(0, local_row)
- if partitioner and tokens:
- token_map[host] = tokens
-
- self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint)
- connection.original_endpoint = connection.endpoint = host.endpoint
# Check metadata.partitioner to see if we haven't built anything yet. If
# every node in the cluster was in the contact points, we won't discover
# any new nodes, so we need this additional check. (See PYTHON-90)
@@ -4080,14 +4138,16 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
host = self._cluster.metadata.get_host_by_host_id(host_id)
if host and host.endpoint != endpoint:
log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id)
- old_endpoint = host.endpoint
- host.endpoint = endpoint
- self._cluster.metadata.update_host(host, old_endpoint)
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
reconnector.cancel()
self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True)
+ old_endpoint = host.endpoint
+ host.endpoint = endpoint
+ self._cluster.metadata.update_host(host, old_endpoint)
+ self._cluster.on_up(host)
+
if host is None:
log.debug("[control connection] Found new host to connect to: %s", endpoint)
host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id)
@@ -4223,6 +4283,44 @@ def _handle_status_change(self, event):
# this will be run by the scheduler
self._cluster.on_down(host, is_host_addition=False)
+ def _handle_client_routes_change(self, event: Dict[str, Any]) -> None:
+ """
+ Handle CLIENT_ROUTES_CHANGE event from the server.
+
+ This event indicates that the system.client_routes table has been updated
+ and we need to refresh our route mappings.
+ """
+ if self._cluster._client_routes_handler is None:
+ log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured")
+ return
+
+ raw_change_type = event.get("change_type")
+ try:
+ change_type = ClientRoutesChangeType(raw_change_type)
+ except ValueError:
+ log.warning("[control connection] Unknown CLIENT_ROUTES_CHANGE type: %s", raw_change_type)
+ return
+
+ connection_ids = tuple(event.get("connection_ids", []))
+ host_ids = tuple(event.get("host_ids", []))
+
+ self._cluster.scheduler.schedule_unique(
+ 0,
+ self._handle_client_routes_refresh,
+ self._connection, self._timeout, change_type, connection_ids, host_ids
+ )
+
+ def _handle_client_routes_refresh(self, connection, timeout,
+ change_type, connection_ids, host_ids):
+ try:
+ self._cluster._client_routes_handler.handle_client_routes_change(
+ connection, timeout, change_type, connection_ids, host_ids)
+ except ReferenceError:
+ pass # our weak reference to the Cluster is no good
+ except Exception:
+ log.debug("[control connection] Error handling CLIENT_ROUTES_CHANGE", exc_info=True)
+ self._signal_error()
+
def _handle_schema_change(self, event):
if self._schema_event_refresh_window < 0:
return
@@ -4230,7 +4328,30 @@ def _handle_schema_change(self, event):
self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event)
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
+ """
+ Wait for schema agreement from the control connection's metadata view.
+
+ This method is intended for internal metadata refresh flows. External
+ callers should use :meth:`.Session.wait_for_schema_agreement` instead.
+
+ The control connection observes schema agreement from its own
+ perspective, which may include hosts the session is not using, and it
+ may fail when the control connection itself is transiently unhealthy.
+ That can produce false positives or failures that do not reflect
+ whether a session can safely proceed.
+ .. deprecated:: 3.30.0
+ Use :meth:`.Session.wait_for_schema_agreement` instead.
+ """
+ warn("ControlConnection.wait_for_schema_agreement is deprecated and will be removed in 4.0. "
+ "Use Session.wait_for_schema_agreement instead. "
+ "This method is for internal metadata refresh use only.",
+ DeprecationWarning, stacklevel=2)
+ return self._wait_for_schema_agreement(connection=connection,
+ preloaded_results=preloaded_results,
+ wait_time=wait_time)
+
+ def _wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait
if total_timeout <= 0:
return True
@@ -4268,7 +4389,8 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai
local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout),
consistency_level=cl)
try:
- timeout = min(self._timeout, total_timeout - elapsed)
+ remaining = total_timeout - elapsed
+ timeout = min(self._timeout, remaining) if self._timeout is not None else remaining
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=timeout)
except OperationTimedOut as timeout:
@@ -4349,8 +4471,9 @@ def _get_peers_query(self, peers_query_type, connection=None):
query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
- host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version
- host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version
+ original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint)
+ host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version
+ host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version
uses_native_address_query = (
host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)
@@ -4586,9 +4709,10 @@ class ResponseFuture(object):
_timer = None
_protocol_handler = ProtocolHandler
_spec_execution_plan = NoSpeculativeExecutionPlan()
- _continuous_paging_options = None
_continuous_paging_session = None
_host = None
+ _control_connection_query_attempted = False
+ _TABLET_ROUTING_CTYPE = None
_warned_timeout = False
@@ -4608,6 +4732,7 @@ def __init__(self, session, message, query, timeout, metrics=None, prepared_stat
self._callback_lock = Lock()
self._start_time = start_time or time.time()
self._host = host
+ self._control_connection_query_attempted = False
self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan
self._make_query_plan()
self._event = Event()
@@ -4654,6 +4779,7 @@ def _on_timeout(self, _attempts=0):
)
return
+ conn_in_flight = None
if self._connection is not None:
try:
self._connection._requests.pop(self._req_id)
@@ -4664,9 +4790,14 @@ def _on_timeout(self, _attempts=0):
except KeyError:
key = "Connection defunct by heartbeat"
errors = {key: "Client request timeout. See Session.execute[_async](timeout)"}
- self._set_final_exception(OperationTimedOut(errors, self._current_host))
+ self._set_final_exception(OperationTimedOut(errors, self._current_host,
+ timeout=self.timeout,
+ in_flight=self._connection.in_flight))
return
+ # Capture connection stats before pool.return_connection() can alter state
+ conn_in_flight = self._connection.in_flight
+
pool = self.session._pools.get(self._current_host)
if pool and not pool.is_shutdown:
# Do not return the stream ID to the pool yet. We cannot reuse it
@@ -4680,18 +4811,31 @@ def _on_timeout(self, _attempts=0):
self._connection.orphaned_threshold_reached = True
pool.return_connection(self._connection, stream_was_orphaned=True)
+ elif self._connection.is_control_connection:
+ with self._connection.lock:
+ self._connection.orphaned_request_ids.add(self._req_id)
+ if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold:
+ self._connection.orphaned_threshold_reached = True
errors = self._errors
if not errors:
if self.is_schema_agreed:
- key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout'
+ if self._current_host is None:
+ key = 'no host queried before timeout'
+ elif self._connection is not None and self._connection.is_control_connection:
+ control_host = self.session.cluster.get_control_connection_host()
+ key = str(control_host.endpoint) if control_host is not None else str(self._connection.endpoint)
+ else:
+ key = str(self._current_host.endpoint)
errors = {key: "Client request timeout. See Session.execute[_async](timeout)"}
else:
connection = self.session.cluster.control_connection._connection
host = str(connection.endpoint) if connection else 'unknown'
errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."}
- self._set_final_exception(OperationTimedOut(errors, self._current_host))
+ self._set_final_exception(OperationTimedOut(errors, self._current_host,
+ timeout=self.timeout,
+ in_flight=conn_in_flight))
def _on_speculative_execute(self):
self._timer = None
@@ -4720,7 +4864,7 @@ def _make_query_plan(self):
# or to the explicit host target if set
if self._host:
# returning a single value effectively disables retries
- self.query_plan = [self._host]
+ self.query_plan = iter([self._host])
else:
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
@@ -4740,14 +4884,110 @@ def send_request(self, error_no_hosts=True):
self._on_timeout()
return True
if error_no_hosts:
+ if self._fallback_to_control_connection():
+ req_id = self._query_control_connection()
+ if req_id is not None:
+ self._req_id = req_id
+ return True
+
self._set_final_exception(NoHostAvailable(
"Unable to complete the operation against any hosts", self._errors))
return False
+ def _has_usable_node_pool(self):
+ try:
+ pools = tuple(self.session._pools.values())
+ except (AttributeError, TypeError):
+ return False
+
+ return any(pool and not pool.is_shutdown for pool in pools)
+
+ def _fallback_to_control_connection(self):
+ fallback_mode = self.session.cluster.allow_control_connection_query_fallback
+ if fallback_mode is ControlConnectionQueryFallback.Disabled:
+ return False
+ if self._host or self._control_connection_query_attempted:
+ return False
+ if fallback_mode is ControlConnectionQueryFallback.SkipPoolCreation:
+ return True
+ return not self._has_usable_node_pool()
+
+ def _borrow_control_connection(self, connection):
+ with connection.lock:
+ if connection.in_flight >= connection.max_request_id:
+ raise NoConnectionsAvailable("All request IDs are currently in use")
+ connection.in_flight += 1
+ return connection.get_request_id()
+
+ def _release_control_connection_request(self, connection, request_id):
+ with connection.lock:
+ connection.in_flight -= 1
+ connection.request_ids.append(request_id)
+ connection._requests.pop(request_id, None)
+
+ def _handle_control_connection_response(self, connection, cb, response):
+ with connection.lock:
+ connection.in_flight -= 1
+ cb(response)
+
+ def _query_control_connection(self, message=None, cb=None, connection=None, host=None):
+ self._control_connection_query_attempted = True
+
+ if message is None:
+ message = self.message
+
+ if connection is None:
+ control_connection = self.session.cluster.control_connection
+ connection = control_connection._connection if control_connection else None
+ if not connection:
+ self._errors['control connection'] = ConnectionException("Control connection is not connected")
+ return None
+
+ if host is None:
+ host = self.session.cluster.get_control_connection_host() or connection.endpoint
+ self._current_host = host
+
+ request_id = None
+ request_sent = False
+ try:
+ request_id = self._borrow_control_connection(connection)
+ self._connection = connection
+ result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []
+ if cb is None:
+ cb = partial(self._set_result, host, connection, None)
+ cb = partial(self._handle_control_connection_response, connection, cb)
+
+ log.debug("No usable node pools; falling back to control connection for host %s", host)
+ self.request_encoded_size = connection.send_msg(message, request_id, cb=cb,
+ encoder=self._protocol_handler.encode_message,
+ decoder=self._protocol_handler.decode_message,
+ result_metadata=result_meta)
+ request_sent = True
+ self.attempted_hosts.append(host)
+ return request_id
+ except NoConnectionsAvailable as exc:
+ log.debug("Control connection is at capacity")
+ self._errors[host] = exc
+ except ConnectionBusy as exc:
+ log.debug("Control connection is busy")
+ self._errors[host] = exc
+ except Exception as exc:
+ log.debug("Error querying control connection", exc_info=True)
+ self._errors[host] = exc
+ if self._metrics is not None:
+ self._metrics.on_connection_error()
+ finally:
+ if request_id is not None and not request_sent:
+ self._release_control_connection_request(connection, request_id)
+
+ return None
+
def _query(self, host, message=None, cb=None):
if message is None:
message = self.message
+ self._control_connection_query_attempted = False
+
pool = self.session._pools.get(host)
if not pool:
self._errors[host] = ConnectionException("Host has been marked down or removed")
@@ -4858,12 +5098,17 @@ def start_fetching_next_page(self):
self._event.clear()
self._final_result = _NOT_SET
self._final_exception = None
+ self._control_connection_query_attempted = False
self._start_timer()
self.send_request()
def _reprepare(self, prepare_message, host, connection, pool):
cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool)
- request_id = self._query(host, prepare_message, cb=cb)
+ if pool is None and connection is not None and connection.is_control_connection:
+ request_id = self._query_control_connection(prepare_message, cb=cb,
+ connection=connection, host=host)
+ else:
+ request_id = self._query(host, prepare_message, cb=cb)
if request_id is None:
# try to submit the original prepared statement on some other host
self.send_request()
@@ -4886,7 +5131,10 @@ def _set_result(self, host, connection, pool, response):
if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload:
protocol = self.session.cluster.protocol_version
info = self._custom_payload.get('tablets-routing-v1')
- ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))')
+ ctype = ResponseFuture._TABLET_ROUTING_CTYPE
+ if ctype is None:
+ ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))')
+ ResponseFuture._TABLET_ROUTING_CTYPE = ctype
tablet_routing_info = ctype.from_binary(info, protocol)
first_token = tablet_routing_info[0]
last_token = tablet_routing_info[1]
@@ -4899,6 +5147,8 @@ def _set_result(self, host, connection, pool, response):
if isinstance(response, ResultMessage):
if response.kind == RESULT_KIND_SET_KEYSPACE:
session = getattr(self, 'session', None)
+ if connection is not None:
+ connection.keyspace = response.new_keyspace
# since we're running on the event loop thread, we need to
# use a non-blocking method for setting the keyspace on
# all connections in this session, otherwise the event
@@ -5075,10 +5325,13 @@ def _execute_after_prepare(self, host, connection, pool, response):
new_metadata_id = response.result_metadata_id
if new_metadata_id is not None:
self.prepared_statement.result_metadata_id = new_metadata_id
-
+
# use self._query to re-use the same host and
# at the same time properly borrow the connection
- request_id = self._query(host)
+ if pool is None and connection is not None and connection.is_control_connection:
+ request_id = self._query_control_connection(connection=connection, host=host)
+ else:
+ request_id = self._query(host)
if request_id is None:
# this host errored out, move on to the next
self.send_request()
@@ -5191,6 +5444,11 @@ def _retry_task(self, reuse_connection, host):
# to retry the operation
return
+ if self._control_connection_query_attempted:
+ self._control_connection_query_attempted = False
+ self.send_request()
+ return
+
if reuse_connection and self._query(host) is not None:
return
diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py
index fb8f26e1cc..b96d0b12d4 100644
--- a/cassandra/concurrent.py
+++ b/cassandra/concurrent.py
@@ -33,13 +33,7 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
``parameters`` item must be a sequence or :const:`None`.
The `concurrency` parameter controls how many statements will be executed
- concurrently. When :attr:`.Cluster.protocol_version` is set to 1 or 2,
- it is recommended that this be kept below 100 times the number of
- core connections per host times the number of connected hosts (see
- :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded,
- the event loop thread may attempt to block on new connection creation,
- substantially impacting throughput. If :attr:`~.Cluster.protocol_version`
- is 3 or higher, you can safely experiment with higher levels of concurrency.
+ concurrently.
If `raise_on_first_error` is left as :const:`True`, execution will stop
after the first failed statement and the corresponding exception will be
@@ -98,8 +92,6 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
class _ConcurrentExecutor(object):
- max_error_recursion = 100
-
def __init__(self, session, statements_and_params, execution_profile):
self.session = session
self._enum_statements = enumerate(iter(statements_and_params))
@@ -109,7 +101,7 @@ def __init__(self, session, statements_and_params, execution_profile):
self._results_queue = []
self._current = 0
self._exec_count = 0
- self._exec_depth = 0
+ self._executing = False
def execute(self, concurrency, fail_fast):
self._fail_fast = fail_fast
@@ -133,22 +125,34 @@ def _execute_next(self):
pass
def _execute(self, idx, statement, params):
- self._exec_depth += 1
+ # When execute_async completes synchronously (e.g. immediate timeout),
+ # the errback fires inline: _on_error -> _put_result -> _execute_next
+ # -> _execute. Without protection this recurses once per remaining
+ # statement and blows the stack.
+ #
+ # ``_executing`` marks that we are already inside this method higher up
+ # the call stack. When a synchronous callback re-enters, we just stash
+ # the pending work in ``_pending_executions`` and let the outermost
+ # invocation drain it in a loop -- no recursion.
+ if self._executing:
+ self._pending_executions.append((idx, statement, params))
+ return
+
+ self._executing = True
+ self._pending_executions = [(idx, statement, params)]
try:
- future = self.session.execute_async(statement, params, timeout=None, execution_profile=self._execution_profile)
- args = (future, idx)
- future.add_callbacks(
- callback=self._on_success, callback_args=args,
- errback=self._on_error, errback_args=args)
- except Exception as exc:
- # If we're not failing fast and all executions are raising, there is a chance of recursing
- # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry
- # and let the event loop thread return.
- if self._exec_depth < self.max_error_recursion:
- self._put_result(exc, idx, False)
- else:
- self.session.submit(self._put_result, exc, idx, False)
- self._exec_depth -= 1
+ while self._pending_executions:
+ p_idx, p_statement, p_params = self._pending_executions.pop(0)
+ try:
+ future = self.session.execute_async(p_statement, p_params, timeout=None, execution_profile=self._execution_profile)
+ args = (future, p_idx)
+ future.add_callbacks(
+ callback=self._on_success, callback_args=args,
+ errback=self._on_error, errback_args=args)
+ except Exception as exc:
+ self._put_result(exc, p_idx, False)
+ finally:
+ self._executing = False
def _on_success(self, result, future, idx):
future.clear_callbacks()
diff --git a/cassandra/connection.py b/cassandra/connection.py
index e1646cafc1..f07160e385 100644
--- a/cassandra/connection.py
+++ b/cassandra/connection.py
@@ -25,12 +25,14 @@
from threading import Thread, Event, RLock, Condition
import time
import ssl
+import uuid
import weakref
import random
import itertools
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple, Union
from cassandra.application_info import ApplicationInfoBase
+from cassandra.client_routes import _ClientRoutesHandler
from cassandra.protocol_features import ProtocolFeatures
if 'gevent.monkey' in sys.modules:
@@ -64,6 +66,7 @@
try:
import lz4
except ImportError:
+ log.debug("lz4 package could not be imported. LZ4 Compression will not be available")
pass
else:
# The compress and decompress functions we need were moved from the lz4 to
@@ -102,6 +105,7 @@ def lz4_decompress(byts):
try:
import snappy
except ImportError:
+ log.debug("snappy package could not be imported. Snappy Compression will not be available")
pass
else:
# work around apparently buggy snappy decompress
@@ -123,7 +127,6 @@ def decompress(byts):
DEFAULT_LOCAL_PORT_LOW = 49152
DEFAULT_LOCAL_PORT_HIGH = 65535
-frame_header_v1_v2 = struct.Struct('>BbBi')
frame_header_v3 = struct.Struct('>BhBi')
@@ -229,7 +232,7 @@ class DefaultEndPointFactory(EndPointFactory):
port = None
"""
If no port is discovered in the row, this is the default port
- used for endpoint creation.
+ used for endpoint creation.
"""
def __init__(self, port=None):
@@ -327,6 +330,50 @@ def create_from_sni(self, sni):
return SniEndPoint(self._proxy_address, sni, self._port)
+class ClientRoutesEndPointFactory(EndPointFactory):
+ """
+ EndPointFactory for Client Routes (Private Link) support.
+
+ Creates ClientRoutesEndPoint instances that defer both address translation
+ (host_id -> hostname lookup) and DNS resolution until connection time.
+ This ensures immediate reaction to infrastructure changes.
+ """
+
+ client_routes_handler: _ClientRoutesHandler
+ default_port: int
+
+ def __init__(self, client_routes_handler: _ClientRoutesHandler, default_port: int = None) -> None:
+ """
+ :param client_routes_handler: _ClientRoutesHandler instance to lookup routes
+ :param default_port: Default port if none found in row
+ """
+ self.client_routes_handler = client_routes_handler
+ self.default_port = default_port
+
+ def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint':
+ """
+ Create a ClientRoutesEndPoint from a system.peers row.
+
+ Stores only the host_id and handler reference. Both translation
+ (route lookup) and DNS resolution happen later in resolve().
+ """
+ from cassandra.metadata import _NodeInfo
+ host_id = row.get("host_id")
+
+ if host_id is None:
+ raise ValueError("No host_id to create ClientRoutesEndPoint")
+
+ addr = _NodeInfo.get_broadcast_rpc_address(row)
+ port = _NodeInfo.get_broadcast_rpc_port(row) or _NodeInfo.get_broadcast_port(row) or self.default_port
+
+ return ClientRoutesEndPoint(
+ host_id=host_id,
+ handler=self.client_routes_handler,
+ original_address=addr,
+ original_port=port,
+ )
+
+
@total_ordering
class UnixSocketEndPoint(EndPoint):
"""
@@ -368,6 +415,76 @@ def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path)
+@total_ordering
+class ClientRoutesEndPoint(EndPoint):
+ """
+ Client Routes (Private Link) EndPoint implementation.
+
+ Defers both address translation (route lookup) and DNS resolution
+ until resolve() is called at connection time. This ensures immediate
+ reaction to infrastructure changes and CLIENT_ROUTES_CHANGE events.
+ """
+
+ _host_id: uuid.UUID
+ _handler: _ClientRoutesHandler
+ _original_address: str
+ _original_port: int
+
+ def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None:
+ """
+ :param host_id: Host UUID for route lookup
+ :param handler: _ClientRoutesHandler instance
+ :param original_address: Original address from system.peers (for identification)
+ :param original_port: Original port if route doesn't specify one
+ """
+ self._host_id = host_id
+ self._handler = handler
+ self._original_address = original_address
+ self._original_port = original_port
+
+ @property
+ def address(self) -> str:
+ """Returns the original address (updated by resolve())."""
+ return self._original_address
+
+ @property
+ def port(self) -> Optional[int]:
+ return self._original_port
+
+ @property
+ def host_id(self) -> uuid.UUID:
+ return self._host_id
+
+ def resolve(self) -> Tuple[str, int]:
+ """
+ Resolve endpoint by delegating to the handler.
+ Falls back to original address/port if no route mapping is available.
+ """
+ result = self._handler.resolve_host(self._host_id)
+ if result is None:
+ return self._original_address, self._original_port
+ return result
+
+ def __eq__(self, other):
+ return (isinstance(other, ClientRoutesEndPoint) and
+ self._host_id == other._host_id and
+ self._original_address == other._original_address)
+
+ def __hash__(self):
+ return hash((self._host_id, self._original_address))
+
+ def __lt__(self, other):
+ return ((self._host_id, self._original_address) <
+ (other._host_id, other._original_address))
+
+ def __str__(self):
+ return str("%s (host_id=%s)" % (self._original_address, self._host_id))
+
+ def __repr__(self):
+ return "<%s: host_id=%s, original_addr=%s>" % (
+ self.__class__.__name__, self._host_id, self._original_address)
+
+
class _Frame(object):
def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
self.version = version
@@ -444,33 +561,6 @@ class ProtocolError(Exception):
class CrcMismatchException(ConnectionException):
pass
-
-class ContinuousPagingState(object):
- """
- A class for specifying continuous paging state, only supported starting with DSE_V2.
- """
-
- num_pages_requested = None
- """
- How many pages we have already requested
- """
-
- num_pages_received = None
- """
- How many pages we have already received
- """
-
- max_queue_size = None
- """
- The max queue size chosen by the user via the options
- """
-
- def __init__(self, max_queue_size):
- self.num_pages_requested = max_queue_size # the initial query requests max_queue_size
- self.num_pages_received = 0
- self.max_queue_size = max_queue_size
-
-
class ContinuousPagingSession(object):
def __init__(self, stream_id, decoder, row_factory, connection, state):
self.stream_id = stream_id
@@ -668,15 +758,29 @@ def reset_cql_frame_buffer(self):
self.reset_io_buffer()
-class ShardawarePortGenerator:
- @classmethod
- def generate(cls, shard_id, total_shards):
- start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
- available_ports = itertools.chain(range(start, DEFAULT_LOCAL_PORT_HIGH), range(DEFAULT_LOCAL_PORT_LOW, start))
+class ShardAwarePortGenerator:
+ def __init__(self, start_port: int, end_port: int):
+ self.start_port = start_port
+ self.end_port = end_port
+
+ @staticmethod
+ def _align(value: int, total_shards: int):
+ shift = value % total_shards
+ if shift == 0:
+ return value
+ return value + total_shards - shift
+
+ def generate(self, shard_id: int, total_shards: int):
+ start = self._align(random.randrange(self.start_port, self.end_port), total_shards) + shard_id
+ beginning = self._align(self.start_port, total_shards) + shard_id
+ available_ports = itertools.chain(range(start, self.end_port, total_shards),
+ range(beginning, start, total_shards))
for port in available_ports:
- if port % total_shards == shard_id:
- yield port
+ yield port
+
+
+DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
class Connection(object):
@@ -691,7 +795,7 @@ class Connection(object):
protocol_version = ProtocolVersion.MAX_SUPPORTED
keyspace = None
- compression = True
+ compression: Union[bool, str] = True
_compression_type = None
compressor = None
decompressor = None
@@ -772,7 +876,7 @@ def _iobuf(self):
return self._io_buffer.io_buffer
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
- ssl_options=None, sockopts=None, compression=True,
+ ssl_options=None, sockopts=None, compression: Union[bool, str] = True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
@@ -817,17 +921,12 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
if not self.ssl_context and self.ssl_options:
self.ssl_context = self._build_ssl_context_from_options()
- if protocol_version >= 3:
- self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
- # Don't fill the deque with 2**15 items right away. Start with some and add
- # more if needed.
- initial_size = min(300, self.max_in_flight)
- self.request_ids = deque(range(initial_size))
- self.highest_request_id = initial_size - 1
- else:
- self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1)
- self.request_ids = deque(range(self.max_request_id + 1))
- self.highest_request_id = self.max_request_id
+ self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
+ # Don't fill the deque with 2**15 items right away. Start with some and add
+ # more if needed.
+ initial_size = min(300, self.max_in_flight)
+ self.request_ids = deque(range(initial_size))
+ self.highest_request_id = initial_size - 1
self.lock = RLock()
self.connected_event = Event()
@@ -885,7 +984,8 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
raise conn.last_error
elif not conn.connected_event.is_set():
conn.close()
- raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout)
+ raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout,
+ timeout=timeout)
else:
return conn
@@ -934,7 +1034,7 @@ def _wrap_socket_from_context(self):
def _initiate_connection(self, sockaddr):
if self.features.shard_id is not None:
- for port in ShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
+ for port in DefaultShardAwarePortGenerator.generate(self.features.shard_id, self.total_shards):
try:
self._socket.bind(('', port))
break
@@ -1104,9 +1204,15 @@ def handle_pushed(self, response):
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None):
if self.is_defunct:
- raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint)
+ msg = "Connection to %s is defunct" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ raise ConnectionShutdown(msg)
elif self.is_closed:
- raise ConnectionShutdown("Connection to %s is closed" % self.endpoint)
+ msg = "Connection to %s is closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ raise ConnectionShutdown(msg)
elif not self._socket_writable:
raise ConnectionBusy("Connection %s is overloaded" % self.endpoint)
@@ -1137,8 +1243,12 @@ def wait_for_responses(self, *msgs, **kwargs):
failed, the corresponding Exception will be raised.
"""
if self.is_closed or self.is_defunct:
- raise ConnectionShutdown("Connection %s is already closed" % (self, ))
+ msg = "Connection %s is already closed" % (self,)
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ raise ConnectionShutdown(msg)
timeout = kwargs.get('timeout')
+ original_timeout = timeout # preserve for exception reporting
fail_on_error = kwargs.get('fail_on_error', True)
waiter = ResponseWaiter(self, len(msgs), fail_on_error)
@@ -1163,7 +1273,8 @@ def wait_for_responses(self, *msgs, **kwargs):
if timeout is not None:
timeout -= 0.01
if timeout <= 0.0:
- raise OperationTimedOut()
+ raise OperationTimedOut(timeout=original_timeout,
+ in_flight=self.in_flight)
time.sleep(0.01)
try:
@@ -1205,11 +1316,10 @@ def _read_frame_header(self):
version = buf[0] & PROTOCOL_VERSION_MASK
if version not in ProtocolVersion.SUPPORTED_VERSIONS:
raise ProtocolError("This version of the driver does not support protocol version %d" % version)
- frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2
# this frame header struct is everything after the version byte
- header_size = frame_header.size + 1
+ header_size = frame_header_v3.size + 1
if pos >= header_size:
- flags, stream, op, body_len = frame_header.unpack_from(buf, 1)
+ flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1)
if body_len < 0:
raise ProtocolError("Received negative body length: %r" % body_len)
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
@@ -1401,10 +1511,11 @@ def _handle_options_response(self, options_response):
overlap = (set(locally_supported_compressions.keys()) &
set(remote_supported_compressions))
if len(overlap) == 0:
- log.debug("No available compression types supported on both ends."
- " locally supported: %r. remotely supported: %r",
- locally_supported_compressions.keys(),
- remote_supported_compressions)
+ if locally_supported_compressions:
+ log.error("No available compression types supported on both ends."
+ " locally supported: %r. remotely supported: %r",
+ locally_supported_compressions.keys(),
+ remote_supported_compressions)
else:
compression_type = None
if isinstance(self.compression, str):
@@ -1550,7 +1661,8 @@ def set_keyspace_blocking(self, keyspace):
if not keyspace or keyspace == self.keyspace:
return
- query = QueryMessage(query='USE "%s"' % (keyspace,),
+ from cassandra.metadata import escape_name
+ query = QueryMessage(query='USE %s' % (escape_name(keyspace),),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
@@ -1604,7 +1716,8 @@ def set_keyspace_async(self, keyspace, callback):
callback(self, None)
return
- query = QueryMessage(query='USE "%s"' % (keyspace,),
+ from cassandra.metadata import escape_name
+ query = QueryMessage(query='USE %s' % (escape_name(keyspace),),
consistency_level=ConsistencyLevel.ONE)
def process_result(result):
@@ -1686,7 +1799,8 @@ def deliver(self, timeout=None):
if self.error:
raise self.error
elif not self.event.is_set():
- raise OperationTimedOut()
+ raise OperationTimedOut(timeout=timeout,
+ in_flight=self.connection.in_flight)
else:
return self.responses
@@ -1702,7 +1816,19 @@ def __init__(self, connection, owner):
with connection.lock:
if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
- connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
+ request_id = connection.get_request_id()
+ try:
+ connection.send_msg(OptionsMessage(), request_id, self._options_callback)
+ except Exception as exc:
+ if connection.is_control_connection:
+ connection.in_flight -= 1
+ # send_msg() registers the callback before writing to the socket,
+ # so a write failure must unwind that registration here.
+ connection._requests.pop(request_id, None)
+ if request_id not in connection.request_ids:
+ connection.request_ids.append(request_id)
+ self._exception = exc
+ self._event.set()
else:
self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
self._event.set()
@@ -1713,7 +1839,10 @@ def wait(self, timeout):
if self._exception:
raise self._exception
else:
- raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint)
+ raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,),
+ self.connection.endpoint,
+ timeout=timeout,
+ in_flight=self.connection.in_flight)
def _options_callback(self, response):
if isinstance(response, SupportedMessage):
diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py
index 3d85587524..509b606ccc 100644
--- a/cassandra/cqlengine/columns.py
+++ b/cassandra/cqlengine/columns.py
@@ -837,7 +837,30 @@ def to_database(self, value):
class BaseContainerColumn(BaseCollectionColumn):
- pass
+ """
+ Base class for container columns (Set, List, Map).
+
+ Supports optional freezing for immutable collections.
+ """
+
+ frozen = False
+ """
+ bool flag, indicates this collection should be frozen (immutable).
+ Frozen collections use FULL indexes instead of VALUES indexes.
+ """
+
+ def __init__(self, types, frozen=False, **kwargs):
+ """
+ :param types: a sequence of sub types in this collection
+ :param frozen: if True, the collection will be frozen (immutable)
+ """
+ self.frozen = frozen
+ super(BaseContainerColumn, self).__init__(types, **kwargs)
+
+ def _apply_frozen(self):
+ """Apply frozen wrapper to db_type if frozen=True."""
+ if self.frozen:
+ self._freeze_db_type()
class Set(BaseContainerColumn):
@@ -849,18 +872,21 @@ class Set(BaseContainerColumn):
_python_type_hashable = False
- def __init__(self, value_type, strict=True, default=set, **kwargs):
+ def __init__(self, value_type, strict=True, default=set, frozen=False, **kwargs):
"""
:param value_type: a column class indicating the types of the value
:param strict: sets whether non set values will be coerced to set
type on validation, or raise a validation error, defaults to True
+ :param frozen: if True, the collection will be frozen (immutable) and
+ use FULL indexes instead of VALUES indexes
"""
self.strict = strict
- super(Set, self).__init__((value_type,), default=default, **kwargs)
+ super(Set, self).__init__((value_type,), frozen=frozen, default=default, **kwargs)
self.value_col = self.types[0]
if not self.value_col._python_type_hashable:
raise ValidationError("Cannot create a Set with unhashable value type (see PYTHON-494)")
self.db_type = 'set<{0}>'.format(self.value_col.db_type)
+ self._apply_frozen()
def validate(self, value):
val = super(Set, self).validate(value)
@@ -899,13 +925,16 @@ class List(BaseContainerColumn):
_python_type_hashable = False
- def __init__(self, value_type, default=list, **kwargs):
+ def __init__(self, value_type, default=list, frozen=False, **kwargs):
"""
:param value_type: a column class indicating the types of the value
+ :param frozen: if True, the collection will be frozen (immutable) and
+ use FULL indexes instead of VALUES indexes
"""
- super(List, self).__init__((value_type,), default=default, **kwargs)
+ super(List, self).__init__((value_type,), frozen=frozen, default=default, **kwargs)
self.value_col = self.types[0]
self.db_type = 'list<{0}>'.format(self.value_col.db_type)
+ self._apply_frozen()
def validate(self, value):
val = super(List, self).validate(value)
@@ -937,12 +966,14 @@ class Map(BaseContainerColumn):
_python_type_hashable = False
- def __init__(self, key_type, value_type, default=dict, **kwargs):
+ def __init__(self, key_type, value_type, default=dict, frozen=False, **kwargs):
"""
:param key_type: a column class indicating the types of the key
:param value_type: a column class indicating the types of the value
+ :param frozen: if True, the collection will be frozen (immutable) and
+ use FULL indexes instead of VALUES indexes
"""
- super(Map, self).__init__((key_type, value_type), default=default, **kwargs)
+ super(Map, self).__init__((key_type, value_type), frozen=frozen, default=default, **kwargs)
self.key_col = self.types[0]
self.value_col = self.types[1]
@@ -950,6 +981,7 @@ def __init__(self, key_type, value_type, default=dict, **kwargs):
raise ValidationError("Cannot create a Map with unhashable key type (see PYTHON-494)")
self.db_type = 'map<{0}, {1}>'.format(self.key_col.db_type, self.value_col.db_type)
+ self._apply_frozen()
def validate(self, value):
val = super(Map, self).validate(value)
diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py
index 4ac4192a80..684bc50b8a 100644
--- a/cassandra/cqlengine/management.py
+++ b/cassandra/cqlengine/management.py
@@ -56,7 +56,7 @@ def _get_context(keyspaces, connections):
def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None):
"""
- Creates a keyspace with SimpleStrategy for replica placement
+ Creates a keyspace with NetworkTopologyStrategy for replica placement
If the keyspace already exists, it will not be modified.
@@ -66,11 +66,11 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True, connec
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
:param str name: name of keyspace to create
- :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy`
+ :param int replication_factor: keyspace replication factor, used with :attr:`~.NetworkTopologyStrategy`
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
- _create_keyspace(name, durable_writes, 'SimpleStrategy',
+ _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy',
{'replication_factor': replication_factor}, connections=connections)
@@ -282,7 +282,11 @@ def _sync_table(model, connection=None):
qs = ['CREATE INDEX']
qs += ['ON {0}'.format(cf_name)]
- qs += ['("{0}")'.format(column.db_field_name)]
+ # Use FULL index for frozen collections, VALUES index (implicit) for non-frozen
+ if isinstance(column, columns.BaseContainerColumn) and column.frozen:
+ qs += ['(FULL("{0}"))'.format(column.db_field_name)]
+ else:
+ qs += ['("{0}")'.format(column.db_field_name)]
qs = ' '.join(qs)
execute(qs, connection=connection)
diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py
index d1580f00ff..547a13c979 100644
--- a/cassandra/cqltypes.py
+++ b/cassandra/cqltypes.py
@@ -249,6 +249,8 @@ def lookup_casstype(casstype):
"""
if isinstance(casstype, (CassandraType, CassandraTypeType)):
return casstype
+ if '(' not in casstype:
+ return lookup_casstype_simple(casstype)
try:
return parse_casstype_args(casstype)
except (ValueError, AssertionError, IndexError) as e:
@@ -636,24 +638,24 @@ def interpret_datestring(val):
except ValueError:
continue
# scale seconds to millis for the raw value
- return (calendar.timegm(tval) + offset) * 1e3
+ return (calendar.timegm(tval) + offset) * 1000
else:
raise ValueError("can't interpret %r as a date" % (val,))
@staticmethod
def deserialize(byts, protocol_version):
- timestamp = int64_unpack(byts) / 1000.0
- return util.datetime_from_timestamp(timestamp)
+ timestamp_ms = int64_unpack(byts)
+ return util.datetime_from_ms_timestamp(timestamp_ms)
@staticmethod
def serialize(v, protocol_version):
try:
# v is datetime
timestamp_seconds = calendar.timegm(v.utctimetuple())
- timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3
+ timestamp = timestamp_seconds * 1000 + getattr(v, 'microsecond', 0) // 1000
except AttributeError:
try:
- timestamp = calendar.timegm(v.timetuple()) * 1e3
+ timestamp = calendar.timegm(v.timetuple()) * 1000
except AttributeError:
# Ints and floats are valid timestamps too
if type(v) not in _number_types:
@@ -812,18 +814,13 @@ class _SimpleParameterizedType(_ParameterizedType):
@classmethod
def deserialize_safe(cls, byts, protocol_version):
subtype, = cls.subtypes
- if protocol_version >= 3:
- unpack = int32_unpack
- length = 4
- else:
- unpack = uint16_unpack
- length = 2
- numelements = unpack(byts[:length])
+ length = 4
+ numelements = int32_unpack(byts[:length])
p = length
result = []
inner_proto = max(3, protocol_version)
for _ in range(numelements):
- itemlen = unpack(byts[p:p + length])
+ itemlen = int32_unpack(byts[p:p + length])
p += length
if itemlen < 0:
result.append(None)
@@ -839,16 +836,15 @@ def serialize_safe(cls, items, protocol_version):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
- pack = int32_pack if protocol_version >= 3 else uint16_pack
buf = io.BytesIO()
- buf.write(pack(len(items)))
+ buf.write(int32_pack(len(items)))
inner_proto = max(3, protocol_version)
for item in items:
if item is None:
- buf.write(pack(-1))
+ buf.write(int32_pack(-1))
else:
itembytes = subtype.to_binary(item, inner_proto)
- buf.write(pack(len(itembytes)))
+ buf.write(int32_pack(len(itembytes)))
buf.write(itembytes)
return buf.getvalue()
@@ -872,18 +868,13 @@ class MapType(_ParameterizedType):
@classmethod
def deserialize_safe(cls, byts, protocol_version):
key_type, value_type = cls.subtypes
- if protocol_version >= 3:
- unpack = int32_unpack
- length = 4
- else:
- unpack = uint16_unpack
- length = 2
- numelements = unpack(byts[:length])
+ length = 4
+ numelements = int32_unpack(byts[:length])
p = length
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
inner_proto = max(3, protocol_version)
for _ in range(numelements):
- key_len = unpack(byts[p:p + length])
+ key_len = int32_unpack(byts[p:p + length])
p += length
if key_len < 0:
keybytes = None
@@ -893,7 +884,7 @@ def deserialize_safe(cls, byts, protocol_version):
p += key_len
key = key_type.from_binary(keybytes, inner_proto)
- val_len = unpack(byts[p:p + length])
+ val_len = int32_unpack(byts[p:p + length])
p += length
if val_len < 0:
val = None
@@ -908,9 +899,8 @@ def deserialize_safe(cls, byts, protocol_version):
@classmethod
def serialize_safe(cls, themap, protocol_version):
key_type, value_type = cls.subtypes
- pack = int32_pack if protocol_version >= 3 else uint16_pack
buf = io.BytesIO()
- buf.write(pack(len(themap)))
+ buf.write(int32_pack(len(themap)))
try:
items = themap.items()
except AttributeError:
@@ -919,16 +909,16 @@ def serialize_safe(cls, themap, protocol_version):
for key, val in items:
if key is not None:
keybytes = key_type.to_binary(key, inner_proto)
- buf.write(pack(len(keybytes)))
+ buf.write(int32_pack(len(keybytes)))
buf.write(keybytes)
else:
- buf.write(pack(-1))
+ buf.write(int32_pack(-1))
if val is not None:
valbytes = value_type.to_binary(val, inner_proto)
- buf.write(pack(len(valbytes)))
+ buf.write(int32_pack(len(valbytes)))
buf.write(valbytes)
else:
- buf.write(pack(-1))
+ buf.write(int32_pack(-1))
return buf.getvalue()
diff --git a/cassandra/cython_utils.pxd b/cassandra/cython_utils.pxd
index 4a1e71dba5..7469657b04 100644
--- a/cassandra/cython_utils.pxd
+++ b/cassandra/cython_utils.pxd
@@ -1,2 +1,3 @@
from libc.stdint cimport int64_t
cdef datetime_from_timestamp(double timestamp)
+cdef datetime_from_ms_timestamp(int64_t timestamp_ms)
diff --git a/cassandra/cython_utils.pyx b/cassandra/cython_utils.pyx
index 7539f33f31..f3421063da 100644
--- a/cassandra/cython_utils.pyx
+++ b/cassandra/cython_utils.pyx
@@ -60,3 +60,22 @@ cdef datetime_from_timestamp(double timestamp):
microseconds += tmp
return DATETIME_EPOC + timedelta_new(days, seconds, microseconds)
+
+
+cdef datetime_from_ms_timestamp(int64_t timestamp_ms):
+ """
+ Creates a datetime from a timestamp in milliseconds using integer
+ arithmetic to preserve precision for large values.
+ """
+ cdef int64_t total_seconds = timestamp_ms // 1000
+ cdef int microseconds = ((timestamp_ms % 1000) * 1000)
+ # For negative timestamps, ensure microseconds is non-negative
+ if microseconds < 0:
+ total_seconds -= 1
+ microseconds += 1000000
+ cdef int days = (total_seconds // 86400)
+ cdef int seconds = (total_seconds % 86400)
+ if seconds < 0:
+ days -= 1
+ seconds += 86400
+ return DATETIME_EPOC + timedelta_new(days, seconds, microseconds)
diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx
index 7c256674b0..98e8676bbc 100644
--- a/cassandra/deserializers.pyx
+++ b/cassandra/deserializers.pyx
@@ -17,7 +17,7 @@ from libc.stdint cimport int32_t, uint16_t
include 'cython_marshal.pyx'
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
-from cassandra.cython_utils cimport datetime_from_timestamp
+from cassandra.cython_utils cimport datetime_from_timestamp, datetime_from_ms_timestamp
from cython.view cimport array as cython_array
from cassandra.tuple cimport tuple_new, tuple_set
@@ -135,8 +135,8 @@ cdef class DesCounterColumnType(DesLongType):
cdef class DesDateType(Deserializer):
cdef deserialize(self, Buffer *buf, int protocol_version):
- cdef double timestamp = unpack_num[int64_t](buf) / 1000.0
- return datetime_from_timestamp(timestamp)
+ cdef int64_t timestamp_ms = unpack_num[int64_t](buf)
+ return datetime_from_ms_timestamp(timestamp_ms)
cdef class TimestampType(DesDateType):
@@ -208,15 +208,9 @@ cdef class _DesSingleParamType(_DesParameterizedType):
cdef class DesListType(_DesSingleParamType):
cdef deserialize(self, Buffer *buf, int protocol_version):
- cdef uint16_t v2_and_below = 2
- cdef int32_t v3_and_above = 3
- if protocol_version >= 3:
- result = _deserialize_list_or_set[int32_t](
- v3_and_above, buf, protocol_version, self.deserializer)
- else:
- result = _deserialize_list_or_set[uint16_t](
- v2_and_below, buf, protocol_version, self.deserializer)
+ result = _deserialize_list_or_set(
+ buf, protocol_version, self.deserializer)
return result
@@ -225,60 +219,49 @@ cdef class DesSetType(DesListType):
return util.sortedset(DesListType.deserialize(self, buf, protocol_version))
-ctypedef fused itemlen_t:
- uint16_t # protocol <= v2
- int32_t # protocol >= v3
-
-cdef list _deserialize_list_or_set(itemlen_t dummy_version,
- Buffer *buf, int protocol_version,
+cdef list _deserialize_list_or_set(Buffer *buf, int protocol_version,
Deserializer deserializer):
"""
Deserialize a list or set.
-
- The 'dummy' parameter is needed to make fused types work, so that
- we can specialize on the protocol version.
"""
cdef Buffer itemlen_buf
cdef Buffer elem_buf
- cdef itemlen_t numelements
+ cdef int32_t numelements
cdef int offset
cdef list result = []
- _unpack_len[itemlen_t](buf, 0, &numelements)
- offset = sizeof(itemlen_t)
+ _unpack_len(buf, 0, &numelements)
+ offset = sizeof(int32_t)
protocol_version = max(3, protocol_version)
for _ in range(numelements):
- subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version)
+ subelem(buf, &elem_buf, &offset)
result.append(from_binary(deserializer, &elem_buf, protocol_version))
return result
cdef inline int subelem(
- Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1:
+ Buffer *buf, Buffer *elem_buf, int* offset) except -1:
"""
Read the next element from the buffer: first read the size (in bytes) of the
element, then fill elem_buf with a newly sliced buffer of this size (and the
right offset).
"""
- cdef itemlen_t elemlen
+ cdef int32_t elemlen
- _unpack_len[itemlen_t](buf, offset[0], &elemlen)
- offset[0] += sizeof(itemlen_t)
+ _unpack_len(buf, offset[0], &elemlen)
+ offset[0] += sizeof(int32_t)
slice_buffer(buf, elem_buf, offset[0], elemlen)
offset[0] += elemlen
return 0
-cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1:
+cdef int _unpack_len(Buffer *buf, int offset, int32_t *output) except -1:
cdef Buffer itemlen_buf
- slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t))
+ slice_buffer(buf, &itemlen_buf, offset, sizeof(int32_t))
- if itemlen_t is uint16_t:
- output[0] = unpack_num[uint16_t](&itemlen_buf)
- else:
- output[0] = unpack_num[int32_t](&itemlen_buf)
+ output[0] = unpack_num[int32_t](&itemlen_buf)
return 0
@@ -295,42 +278,33 @@ cdef class DesMapType(_DesParameterizedType):
self.val_deserializer = self.deserializers[1]
cdef deserialize(self, Buffer *buf, int protocol_version):
- cdef uint16_t v2_and_below = 0
- cdef int32_t v3_and_above = 0
key_type, val_type = self.cqltype.subtypes
- if protocol_version >= 3:
- result = _deserialize_map[int32_t](
- v3_and_above, buf, protocol_version,
- self.key_deserializer, self.val_deserializer,
- key_type, val_type)
- else:
- result = _deserialize_map[uint16_t](
- v2_and_below, buf, protocol_version,
- self.key_deserializer, self.val_deserializer,
- key_type, val_type)
+ result = _deserialize_map(
+ buf, protocol_version,
+ self.key_deserializer, self.val_deserializer,
+ key_type, val_type)
return result
-cdef _deserialize_map(itemlen_t dummy_version,
- Buffer *buf, int protocol_version,
+cdef _deserialize_map(Buffer *buf, int protocol_version,
Deserializer key_deserializer, Deserializer val_deserializer,
key_type, val_type):
cdef Buffer key_buf, val_buf
cdef Buffer itemlen_buf
- cdef itemlen_t numelements
+ cdef int32_t numelements
cdef int offset
cdef list result = []
- _unpack_len[itemlen_t](buf, 0, &numelements)
- offset = sizeof(itemlen_t)
+ _unpack_len(buf, 0, &numelements)
+ offset = sizeof(int32_t)
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
protocol_version = max(3, protocol_version)
for _ in range(numelements):
- subelem[itemlen_t](buf, &key_buf, &offset, dummy_version)
- subelem[itemlen_t](buf, &val_buf, &offset, numelements)
+ subelem(buf, &key_buf, &offset)
+ subelem(buf, &val_buf, &offset)
key = from_binary(key_deserializer, &key_buf, protocol_version)
val = from_binary(val_deserializer, &val_buf, protocol_version)
themap._insert_unchecked(key, to_bytes(&key_buf), val)
diff --git a/cassandra/encoder.py b/cassandra/encoder.py
index e834550fd3..d803c087ba 100644
--- a/cassandra/encoder.py
+++ b/cassandra/encoder.py
@@ -142,7 +142,7 @@ def cql_encode_datetime(self, val):
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
- return str(int(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
+ return str(timestamp * 1000 + getattr(val, 'microsecond', 0) // 1000)
def cql_encode_date(self, val):
"""
diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py
index 41b744602d..452667c8eb 100644
--- a/cassandra/io/asyncioreactor.py
+++ b/cassandra/io/asyncioreactor.py
@@ -23,8 +23,8 @@
asyncio.run_coroutine_threadsafe
except AttributeError:
raise ImportError(
- 'Cannot use asyncioreactor without access to '
- 'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
+ "Cannot use asyncioreactor without access to "
+ "asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)"
)
@@ -38,12 +38,12 @@ class AsyncioTimer(object):
@property
def end(self):
- raise NotImplementedError('{} is not compatible with TimerManager and '
- 'does not implement .end()')
+ raise NotImplementedError(
+ "{} is not compatible with TimerManager and does not implement .end()"
+ )
def __init__(self, timeout, callback, loop):
- delayed = self._call_delayed_coro(timeout=timeout,
- callback=callback)
+ delayed = self._call_delayed_coro(timeout=timeout, callback=callback)
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)
@staticmethod
@@ -63,17 +63,61 @@ def cancel(self):
def finish(self):
# connection.Timer method not implemented here because we can't inspect
# the Handle returned from call_later
- raise NotImplementedError('{} is not compatible with TimerManager and '
- 'does not implement .finish()')
+ raise NotImplementedError(
+ "{} is not compatible with TimerManager and does not implement .finish()"
+ )
+
+
+class _AsyncioProtocol(asyncio.Protocol):
+ """
+ Protocol adapter for asyncio SSL connections. Bridges asyncio's
+ transport/protocol API back to AsyncioConnection's buffer processing.
+ """
+
+ def __init__(self, connection, loop_args=None):
+ self._connection = connection
+ self.transport = None
+ self.write_ready = asyncio.Event(**(loop_args or {}))
+ self.write_ready.set()
+
+ def connection_made(self, transport):
+ self.transport = transport
+
+ def data_received(self, data):
+ conn = self._connection
+ conn._iobuf.write(data)
+ if conn._iobuf.tell():
+ conn.process_io_buffer()
+
+ def pause_writing(self):
+ self.write_ready.clear()
+
+ def resume_writing(self):
+ self.write_ready.set()
+
+ def connection_lost(self, exc):
+ # Unblock any paused writer so shutdown does not hang
+ self.write_ready.set()
+ conn = self._connection
+ if exc:
+ log.debug("Connection %s lost: %s", conn, exc)
+ conn.defunct(exc)
+ else:
+ log.debug("Connection %s closed by server", conn)
+ conn.close()
+
+ def eof_received(self):
+ return False
class AsyncioConnection(Connection):
"""
- An experimental implementation of :class:`.Connection` that uses the
- ``asyncio`` module in the Python standard library for its event loop.
+ An implementation of :class:`.Connection` that uses the ``asyncio``
+ module in the Python standard library for its event loop.
- Note that it requires ``asyncio`` features that were only introduced in the
- 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
+ Supports SSL connections via asyncio's native TLS transport, which
+ avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
+ low-level socket methods (``sock_sendall``, ``sock_recv``).
"""
_loop = None
@@ -88,26 +132,109 @@ class AsyncioConnection(Connection):
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self._background_tasks = set()
+ self._transport = None
+ self._using_ssl = bool(self.ssl_context)
self._connect_socket()
self._socket.setblocking(0)
loop_args = dict()
if sys.version_info[0] == 3 and sys.version_info[1] < 10:
- loop_args['loop'] = self._loop
+ loop_args["loop"] = self._loop
+ self._protocol = _AsyncioProtocol(self, loop_args) if self._using_ssl else None
+ self._ssl_ready = asyncio.Event(**loop_args) if self._using_ssl else None
self._write_queue = asyncio.Queue(**loop_args)
self._write_queue_lock = asyncio.Lock(**loop_args)
# see initialize_reactor -- loop is running in a separate thread, so we
# have to use a threadsafe call
- self._read_watcher = asyncio.run_coroutine_threadsafe(
- self.handle_read(), loop=self._loop
- )
+ if self._using_ssl:
+ # For SSL: set up asyncio transport/protocol, then start writer
+ self._read_watcher = asyncio.run_coroutine_threadsafe(
+ self._setup_ssl_and_run(), loop=self._loop
+ )
+ else:
+ # For non-SSL: use low-level sock_sendall/sock_recv as before
+ self._read_watcher = asyncio.run_coroutine_threadsafe(
+ self.handle_read(), loop=self._loop
+ )
self._write_watcher = asyncio.run_coroutine_threadsafe(
self.handle_write(), loop=self._loop
)
self._send_options_message()
+ def _connect_socket(self):
+ """
+ Override base class to skip SSL wrapping of the socket.
+ For SSL connections, the plain TCP socket is connected here, and TLS
+ is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
+ """
+ sockerr = None
+ addresses = self._get_socket_addresses()
+ for af, socktype, proto, _, sockaddr in addresses:
+ try:
+ self._socket = self._socket_impl.socket(af, socktype, proto)
+ # Do NOT wrap with ssl_context here -- asyncio will handle TLS
+ self._socket.settimeout(self.connect_timeout)
+ self._initiate_connection(sockaddr)
+ self._socket.settimeout(None)
+
+ local_addr = self._socket.getsockname()
+ log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
+ sockerr = None
+ break
+ except socket.error as err:
+ if self._socket:
+ self._socket.close()
+ self._socket = None
+ sockerr = err
+
+ if sockerr:
+ raise socket.error(
+ sockerr.errno,
+ "Tried connecting to %s. Last error: %s"
+ % ([a[4] for a in addresses], sockerr.strerror or sockerr),
+ )
+
+ if self.sockopts:
+ for args in self.sockopts:
+ self._socket.setsockopt(*args)
+
+ async def _setup_ssl_and_run(self):
+ """
+ Upgrade the plain TCP connection to TLS using asyncio's native SSL
+ transport, then continuously read data via the protocol callbacks.
+ """
+ try:
+ ssl_context = self.ssl_context
+ server_hostname = None
+ if self.ssl_options:
+ server_hostname = self.ssl_options.get("server_hostname", None)
+ if server_hostname is None:
+ # asyncio's create_connection requires server_hostname when
+ # ssl= is set. Use endpoint address for SNI/verification when
+ # check_hostname is enabled; otherwise pass "" to suppress SNI.
+ server_hostname = (
+ self.endpoint.address if ssl_context.check_hostname else ""
+ )
+
+ transport, protocol = await self._loop.create_connection(
+ lambda: self._protocol,
+ sock=self._socket,
+ ssl=ssl_context,
+ server_hostname=server_hostname,
+ )
+ self._transport = transport
+
+ if self._check_hostname:
+ self._validate_hostname()
+ self._ssl_ready.set()
+ except Exception as exc:
+ log.debug("SSL setup failed for %s: %s", self, exc)
+ self.defunct(exc)
+ # Unblock handle_write so it can observe the defunct state and exit
+ self._ssl_ready.set()
+ return
@classmethod
def initialize_reactor(cls):
@@ -126,8 +253,9 @@ def initialize_reactor(cls):
cls._loop = asyncio.new_event_loop()
# daemonize so the loop will be shut down on interpreter
# shutdown
- cls._loop_thread = Thread(target=cls._loop.run_forever,
- daemon=True, name="asyncio_thread")
+ cls._loop_thread = Thread(
+ target=cls._loop.run_forever, daemon=True, name="asyncio_thread"
+ )
cls._loop_thread.start()
@classmethod
@@ -142,9 +270,7 @@ def close(self):
# close from the loop thread to avoid races when removing file
# descriptors
- asyncio.run_coroutine_threadsafe(
- self._close(), loop=self._loop
- )
+ asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop)
async def _close(self):
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
@@ -152,7 +278,10 @@ async def _close(self):
self._write_watcher.cancel()
if self._read_watcher:
self._read_watcher.cancel()
- if self._socket:
+ if self._transport:
+ self._transport.close()
+ self._transport = None
+ elif self._socket:
self._loop.remove_writer(self._socket.fileno())
self._loop.remove_reader(self._socket.fileno())
self._socket.close()
@@ -160,8 +289,10 @@ async def _close(self):
log.debug("Closed socket to %s" % (self.endpoint,))
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
# don't leave in-progress operations hanging
self.connected_event.set()
@@ -170,15 +301,12 @@ def push(self, data):
if len(data) > buff_size:
chunks = []
for i in range(0, len(data), buff_size):
- chunks.append(data[i:i + buff_size])
+ chunks.append(data[i : i + buff_size])
else:
chunks = [data]
if self._loop_thread != threading.current_thread():
- asyncio.run_coroutine_threadsafe(
- self._push_msg(chunks),
- loop=self._loop
- )
+ asyncio.run_coroutine_threadsafe(self._push_msg(chunks), loop=self._loop)
else:
# avoid races/hangs by just scheduling this, not using threadsafe
task = self._loop.create_task(self._push_msg(chunks))
@@ -192,13 +320,25 @@ async def _push_msg(self, chunks):
for chunk in chunks:
self._write_queue.put_nowait(chunk)
-
async def handle_write(self):
+ # For SSL connections, wait until the TLS handshake completes
+ if self._ssl_ready:
+ await self._ssl_ready.wait()
+ if self.is_defunct:
+ return
while True:
try:
next_msg = await self._write_queue.get()
if next_msg:
- await self._loop.sock_sendall(self._socket, next_msg)
+ if self._transport:
+ # SSL: use asyncio transport (handles TLS transparently)
+ await self._protocol.write_ready.wait()
+ if self.is_closed or self.is_defunct or not self._transport:
+ return
+ self._transport.write(next_msg)
+ else:
+ # Non-SSL: use low-level socket API
+ await self._loop.sock_sendall(self._socket, next_msg)
except socket.error as err:
log.debug("Exception in send for %s: %s", self, err)
self.defunct(err)
@@ -221,8 +361,7 @@ async def handle_read(self):
await asyncio.sleep(0)
continue
except socket.error as err:
- log.debug("Exception during socket recv for %s: %s",
- self, err)
+ log.debug("Exception during socket recv for %s: %s", self, err)
self.defunct(err)
return # leave the read loop
except asyncio.CancelledError:
diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py
index 2c75e7139d..02466ad0d2 100644
--- a/cassandra/io/asyncorereactor.py
+++ b/cassandra/io/asyncorereactor.py
@@ -385,12 +385,14 @@ def close(self):
log.debug("Closed socket to %s", self.endpoint)
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
#This happens when the connection is shutdown while waiting for the ReadyMessage
if not self.connected_event.is_set():
- self.last_error = ConnectionShutdown("Connection to %s was closed" % self.endpoint)
+ self.last_error = ConnectionShutdown(msg)
# don't leave in-progress operations hanging
self.connected_event.set()
diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py
index 42874036d5..234a4a574c 100644
--- a/cassandra/io/eventletreactor.py
+++ b/cassandra/io/eventletreactor.py
@@ -145,8 +145,10 @@ def close(self):
log.debug("Closed socket to %s" % (self.endpoint,))
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
# don't leave in-progress operations hanging
self.connected_event.set()
diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py
index 4f1f158aa7..7516fdd6df 100644
--- a/cassandra/io/geventreactor.py
+++ b/cassandra/io/geventreactor.py
@@ -95,8 +95,10 @@ def close(self):
log.debug("Closed socket to %s" % (self.endpoint,))
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
# don't leave in-progress operations hanging
self.connected_event.set()
diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py
index 29039653f4..3da809931f 100644
--- a/cassandra/io/libevreactor.py
+++ b/cassandra/io/libevreactor.py
@@ -13,7 +13,6 @@
# limitations under the License.
import atexit
from collections import deque
-from functools import partial
import logging
import os
import socket
@@ -116,6 +115,10 @@ def _cleanup(self):
if not self._thread:
return
+ # Stop the prepare watcher first to prevent race conditions
+ if self._preparer:
+ self._preparer.stop()
+
for conn in self._live_conns | self._new_conns | self._closed_conns:
conn.close()
for watcher in (conn._write_watcher, conn._read_watcher):
@@ -125,8 +128,9 @@ def _cleanup(self):
self.notify() # wake the timer watcher
# PYTHON-752 Thread might have just been created and not started
+ # Use longer timeout to allow proper cleanup
with self._lock_thread:
- self._thread.join(timeout=1.0)
+ self._thread.join(timeout=5.0)
if self._thread.is_alive():
log.warning(
@@ -165,6 +169,10 @@ def connection_created(self, conn):
def connection_destroyed(self, conn):
with self._conn_set_lock:
+ new_conns = self._new_conns.copy()
+ new_conns.discard(conn)
+ self._new_conns = new_conns
+
new_live_conns = self._live_conns.copy()
new_live_conns.discard(conn)
self._live_conns = new_live_conns
@@ -194,7 +202,8 @@ def _loop_will_run(self, prepare):
self._new_conns = set()
for conn in to_start:
- conn._read_watcher.start()
+ if conn._read_watcher:
+ conn._read_watcher.start()
changed = True
@@ -222,8 +231,20 @@ def _loop_will_run(self, prepare):
self._notifier.send()
+def _atexit_cleanup():
+ """Cleanup function called by atexit that uses the current _global_loop value.
+
+ This wrapper ensures that cleanup receives the actual LibevLoop instance
+ instead of None, which was the value of _global_loop when the module was
+ imported.
+ """
+ global _global_loop
+ if _global_loop is not None:
+ _cleanup(_global_loop)
+
+
_global_loop = None
-atexit.register(partial(_cleanup, _global_loop))
+atexit.register(_atexit_cleanup)
class LibevConnection(Connection):
@@ -292,8 +313,10 @@ def close(self):
# don't leave in-progress operations hanging
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
self.connected_event.set()
def handle_write(self, watcher, revents, errno=None):
diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c
index f32504fa34..fc25f9ceba 100644
--- a/cassandra/io/libevwrapper.c
+++ b/cassandra/io/libevwrapper.c
@@ -118,7 +118,13 @@ IO_dealloc(libevwrapper_IO *self) {
static void io_callback(struct ev_loop *loop, ev_io *watcher, int revents) {
libevwrapper_IO *self = watcher->data;
PyObject *result;
- PyGILState_STATE gstate = PyGILState_Ensure();
+ PyGILState_STATE gstate;
+
+ if (!self || !self->callback) {
+ return; // Skip callback if object is being destroyed
+ }
+
+ gstate = PyGILState_Ensure();
if (revents & EV_ERROR && errno) {
result = PyObject_CallFunction(self->callback, "Obi", self, revents, errno);
} else {
@@ -354,6 +360,10 @@ static void prepare_callback(struct ev_loop *loop, ev_prepare *watcher, int reve
PyObject *result = NULL;
PyGILState_STATE gstate;
+ if (!self || !self->callback) {
+ return; // Skip callback if object is being destroyed
+ }
+
gstate = PyGILState_Ensure();
result = PyObject_CallFunction(self->callback, "O", self);
if (!result) {
@@ -473,6 +483,10 @@ static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents)
PyObject *result = NULL;
PyGILState_STATE gstate;
+ if (!self || !self->callback) {
+ return; // Skip callback if object is being destroyed
+ }
+
gstate = PyGILState_Ensure();
result = PyObject_CallFunction(self->callback, NULL);
if (!result) {
diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py
index e4605a7446..446200bf63 100644
--- a/cassandra/io/twistedreactor.py
+++ b/cassandra/io/twistedreactor.py
@@ -283,8 +283,10 @@ def close(self):
log.debug("Closed socket to %s", self.endpoint)
if not self.is_defunct:
- self.error_all_requests(
- ConnectionShutdown("Connection to %s was closed" % self.endpoint))
+ msg = "Connection to %s was closed" % self.endpoint
+ if self.last_error:
+ msg += ": %s" % (self.last_error,)
+ self.error_all_requests(ConnectionShutdown(msg))
# don't leave in-progress operations hanging
self.connected_event.set()
diff --git a/tests/integration/cqlengine/advanced/__init__.py b/cassandra/lwt_info.py
similarity index 60%
rename from tests/integration/cqlengine/advanced/__init__.py
rename to cassandra/lwt_info.py
index 386372eb4a..d64c08bbcf 100644
--- a/tests/integration/cqlengine/advanced/__init__.py
+++ b/cassandra/lwt_info.py
@@ -1,4 +1,4 @@
-# Copyright DataStax, Inc.
+# Copyright 2020 ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+class _LwtInfo:
+ """
+ Holds LWT-related information parsed from the server's supported features.
+ """
+
+ def __init__(self, lwt_meta_bit_mask):
+ self.lwt_meta_bit_mask = lwt_meta_bit_mask
+
+ def get_lwt_flag(self, flags):
+ return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask
diff --git a/cassandra/marshal.py b/cassandra/marshal.py
index a527a9e1d7..413e1831d4 100644
--- a/cassandra/marshal.py
+++ b/cassandra/marshal.py
@@ -33,11 +33,6 @@ def _make_packer(format_string):
float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d')
-# Special case for cassandra header
-header_struct = struct.Struct('>BBbB')
-header_pack = header_struct.pack
-header_unpack = header_struct.unpack
-
# in protocol version 3 and higher, the stream ID is two bytes
v3_header_struct = struct.Struct('>BBhB')
v3_header_pack = v3_header_struct.pack
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 30bcf81654..43399b7152 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -139,8 +139,9 @@ def export_schema_as_string(self):
def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None,
metadata_request_timeout=None, **kwargs):
- server_version = self.get_host(connection.original_endpoint).release_version
- dse_version = self.get_host(connection.original_endpoint).dse_version
+ host = self.get_host(connection.original_endpoint)
+ server_version = host.release_version if host else None
+ dse_version = host.dse_version if host else None
parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size)
if not target_type:
@@ -153,12 +154,7 @@ def refresh(self, connection, timeout, target_type=None, change_type=None, fetch
meta = parse_method(self.keyspaces, **kwargs)
if meta:
update_method = getattr(self, '_update_' + tt_lower)
- if tt_lower == 'keyspace' and connection.protocol_version < 3:
- # we didn't have 'type' target in legacy protocol versions, so we need to query those too
- user_types = parser.get_types_map(self.keyspaces, **kwargs)
- self._update_keyspace(meta, user_types)
- else:
- update_method(meta)
+ update_method(meta)
else:
drop_method = getattr(self, '_drop_' + tt_lower)
drop_method(**kwargs)
@@ -574,10 +570,11 @@ def __init__(self, options_map):
def make_token_replica_map(self, token_to_host_owner, ring):
replica_map = {}
- for i in range(len(ring)):
+ ring_len = len(ring)
+ for i in range(ring_len):
j, hosts = 0, list()
- while len(hosts) < self.replication_factor and j < len(ring):
- token = ring[(i + j) % len(ring)]
+ while len(hosts) < self.replication_factor and j < ring_len:
+ token = ring[(i + j) % ring_len]
host = token_to_host_owner[token]
if host not in hosts:
hosts.append(host)
@@ -634,10 +631,14 @@ def make_token_replica_map(self, token_to_host_owner, ring):
hosts_per_dc = defaultdict(set)
for i, token in enumerate(ring):
host = token_to_host_owner[token]
- dc_to_token_offset[host.datacenter].append(i)
- if host.datacenter and host.rack:
- dc_racks[host.datacenter].add(host.rack)
- hosts_per_dc[host.datacenter].add(host)
+ host_dc = host.datacenter
+ if host_dc in dc_rf_map:
+ # if the host is in a DC that has a replication factor, add it
+ # to the list of token offsets for that DC
+ dc_to_token_offset[host_dc].append(i)
+ if host.rack:
+ dc_racks[host_dc].add(host.rack)
+ hosts_per_dc[host_dc].add(host)
# A map of DCs to an index into the dc_to_token_offset value for that dc.
# This is how we keep track of advancing around the ring for each DC.
@@ -649,8 +650,6 @@ def make_token_replica_map(self, token_to_host_owner, ring):
# go through each DC and find the replicas in that DC
for dc in dc_to_token_offset.keys():
- if dc not in dc_rf_map:
- continue
# advance our per-DC index until we're up to at least the
# current token in the ring
@@ -662,34 +661,34 @@ def make_token_replica_map(self, token_to_host_owner, ring):
dc_to_current_index[dc] = index
replicas_remaining = dc_rf_map[dc]
- replicas_this_dc = 0
+ num_replicas_this_dc = 0
skipped_hosts = []
racks_placed = set()
- racks_this_dc = dc_racks[dc]
- hosts_this_dc = len(hosts_per_dc[dc])
+ num_racks_this_dc = len(dc_racks[dc])
+ num_hosts_this_dc = len(hosts_per_dc[dc])
for token_offset_index in range(index, index+num_tokens):
- if token_offset_index >= len(token_offsets):
- token_offset_index = token_offset_index - len(token_offsets)
+ if replicas_remaining == 0 or num_replicas_this_dc == num_hosts_this_dc:
+ break
+
+ if token_offset_index >= num_tokens:
+ token_offset_index = token_offset_index - num_tokens
token_offset = token_offsets[token_offset_index]
host = token_to_host_owner[ring[token_offset]]
- if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc:
- break
-
if host in replicas:
continue
- if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc):
+ if host.rack in racks_placed and len(racks_placed) < num_racks_this_dc:
skipped_hosts.append(host)
continue
replicas.append(host)
- replicas_this_dc += 1
+ num_replicas_this_dc += 1
replicas_remaining -= 1
racks_placed.add(host.rack)
- if len(racks_placed) == len(racks_this_dc):
+ if len(racks_placed) == num_racks_this_dc:
for host in skipped_hosts:
if replicas_remaining == 0:
break
@@ -1894,7 +1893,7 @@ def hash_fn(cls, key):
def __init__(self, token):
""" `token` is an int or string representing the token. """
- self.value = int(token)
+ super().__init__(int(token))
class MD5Token(HashToken):
@@ -2077,7 +2076,6 @@ def __init__(self, connection, timeout, fetch_size, metadata_request_timeout):
self.types_result = []
self.functions_result = []
self.aggregates_result = []
- self.scylla_result = []
self.keyspace_table_rows = defaultdict(list)
self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list))
@@ -2085,7 +2083,6 @@ def __init__(self, connection, timeout, fetch_size, metadata_request_timeout):
self.keyspace_func_rows = defaultdict(list)
self.keyspace_agg_rows = defaultdict(list)
self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list))
- self.keyspace_scylla_rows = defaultdict(lambda: defaultdict(list))
def get_all_keyspaces(self):
self._query_all()
@@ -2531,23 +2528,9 @@ def _query_all(self):
self._aggregate_results()
def _aggregate_results(self):
- m = self.keyspace_scylla_rows
- for row in self.scylla_result:
- ksname = row["keyspace_name"]
- cfname = row[self._table_name_col]
- m[ksname][cfname].append(row)
-
m = self.keyspace_table_rows
for row in self.tables_result:
ksname = row["keyspace_name"]
- cfname = row[self._table_name_col]
- # in_memory property is stored in scylla private table
- # add it to table properties if enabled
- try:
- if self.keyspace_scylla_rows[ksname][cfname][0]["in_memory"] == True:
- row["in_memory"] = True
- except (IndexError, KeyError):
- pass
m[ksname].append(row)
m = self.keyspace_table_col_rows
@@ -2593,7 +2576,10 @@ class SchemaParserV3(SchemaParserV22):
_SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions"
_SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates"
_SELECT_VIEWS = "SELECT * FROM system_schema.views"
- _SELECT_SCYLLA = "SELECT * FROM system_schema.scylla_tables"
+
+ def _is_not_scylla(self):
+ """Check if NOT connected to ScyllaDB by checking for shard awareness."""
+ return getattr(getattr(self.connection, 'features', None), 'shard_id', None) is None
_table_name_col = 'table_name'
@@ -2645,40 +2631,44 @@ def get_table(self, keyspaces, keyspace, table):
indexes_query = QueryMessage(
query=maybe_add_timeout_to_query(self._SELECT_INDEXES + where_clause, self.metadata_request_timeout),
consistency_level=cl, fetch_size=fetch_size)
- triggers_query = QueryMessage(
- query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout),
- consistency_level=cl, fetch_size=fetch_size)
- scylla_query = QueryMessage(
- query=maybe_add_timeout_to_query(self._SELECT_SCYLLA + where_clause, self.metadata_request_timeout),
- consistency_level=cl, fetch_size=fetch_size)
+
+ # ScyllaDB doesn't have triggers, skip the query
+ if self._is_not_scylla():
+ triggers_query = QueryMessage(
+ query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout),
+ consistency_level=cl, fetch_size=fetch_size)
# in protocol v4 we don't know if this event is a view or a table, so we look for both
where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder)
view_query = QueryMessage(
query=maybe_add_timeout_to_query(self._SELECT_VIEWS + where_clause, self.metadata_request_timeout),
consistency_level=cl, fetch_size=fetch_size)
- ((cf_success, cf_result), (col_success, col_result),
- (indexes_sucess, indexes_result), (triggers_success, triggers_result),
- (view_success, view_result),
- (scylla_success, scylla_result)) = (
- self.connection.wait_for_responses(
- cf_query, col_query, indexes_query, triggers_query,
- view_query, scylla_query, timeout=self.timeout, fail_on_error=False)
- )
+
+ if self._is_not_scylla():
+ ((cf_success, cf_result), (col_success, col_result),
+ (indexes_sucess, indexes_result), (triggers_success, triggers_result),
+ (view_success, view_result)) = (
+ self.connection.wait_for_responses(
+ cf_query, col_query, indexes_query, triggers_query,
+ view_query, timeout=self.timeout, fail_on_error=False)
+ )
+ else:
+ ((cf_success, cf_result), (col_success, col_result),
+ (indexes_sucess, indexes_result),
+ (view_success, view_result)) = (
+ self.connection.wait_for_responses(
+ cf_query, col_query, indexes_query,
+ view_query, timeout=self.timeout, fail_on_error=False)
+ )
+
table_result = self._handle_results(cf_success, cf_result, query_msg=cf_query)
col_result = self._handle_results(col_success, col_result, query_msg=col_query)
if table_result:
indexes_result = self._handle_results(indexes_sucess, indexes_result, query_msg=indexes_query)
- triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query)
- # in_memory property is stored in scylla private table
- # add it to table properties if enabled
- scylla_result = self._handle_results(scylla_success, scylla_result, expected_failures=(InvalidRequest,),
- query_msg=scylla_query)
- try:
- if scylla_result[0]["in_memory"] == True:
- table_result[0]["in_memory"] = True
- except (IndexError, KeyError):
- pass
+ if self._is_not_scylla():
+ triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query)
+ else:
+ triggers_result = None
return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result)
view_result = self._handle_results(view_success, view_result, query_msg=view_query)
@@ -2727,9 +2717,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row
self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual)
- for trigger_row in trigger_rows:
- trigger_meta = self._build_trigger_metadata(table_meta, trigger_row)
- table_meta.triggers[trigger_meta.name] = trigger_meta
+ if self._is_not_scylla():
+ for trigger_row in trigger_rows:
+ trigger_meta = self._build_trigger_metadata(table_meta, trigger_row)
+ table_meta.triggers[trigger_meta.name] = trigger_meta
for index_row in index_rows:
index_meta = self._build_index_metadata(table_meta, index_row)
@@ -2772,7 +2763,7 @@ def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=Fa
meta.clustering_key.append(meta.columns[r.get('column_name')])
for col_row in (r for r in col_rows
- if r.get('kind', None) not in ('partition_key', 'clustering_key')):
+ if r.get('kind', None) not in ('partition_key', 'clustering')):
column_meta = self._build_column_metadata(meta, col_row)
if is_dense and column_meta.cql_type == types.cql_empty_type:
continue
@@ -2824,6 +2815,7 @@ def _build_trigger_metadata(table_metadata, row):
trigger_meta = TriggerMetadata(table_metadata, name, options)
return trigger_meta
+
def _query_all(self):
cl = ConsistencyLevel.ONE
fetch_size = self.fetch_size
@@ -2840,39 +2832,45 @@ def _query_all(self):
fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout),
fetch_size=fetch_size, consistency_level=cl),
- QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout),
- fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout),
fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout),
fetch_size=fetch_size, consistency_level=cl),
- QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCYLLA, self.metadata_request_timeout),
- fetch_size=fetch_size, consistency_level=cl),
]
+ # ScyllaDB doesn't have triggers, skip the query
+ if self._is_not_scylla():
+ queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout),
+ fetch_size=fetch_size, consistency_level=cl))
+
+ responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False)
+
+ # Unpack common responses (always present)
((ks_success, ks_result),
(table_success, table_result),
(col_success, col_result),
(types_success, types_result),
(functions_success, functions_result),
(aggregates_success, aggregates_result),
- (triggers_success, triggers_result),
(indexes_success, indexes_result),
- (views_success, views_result),
- (scylla_success, scylla_result)) = self.connection.wait_for_responses(
- *queries, timeout=self.timeout, fail_on_error=False
- )
+ (views_success, views_result)) = responses[:8]
+
+ # Unpack triggers response if present (Cassandra/DSE only)
+ if self._is_not_scylla():
+ (triggers_success, triggers_result) = responses[8]
self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0])
self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1])
self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2])
- self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6])
self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3])
self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4])
self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5])
- self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7])
- self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8])
- self.scylla_result = self._handle_results(scylla_success, scylla_result, expected_failures=(InvalidRequest,), query_msg=queries[9])
+ self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6])
+ self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7])
+ if self._is_not_scylla():
+ self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[8])
+ else:
+ self.triggers_result = []
self._aggregate_results()
@@ -2950,8 +2948,6 @@ def _query_all(self):
fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout),
fetch_size=fetch_size, consistency_level=cl),
- QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout),
- fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout),
fetch_size=fetch_size, consistency_level=cl),
QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout),
@@ -2965,8 +2961,15 @@ def _query_all(self):
fetch_size=fetch_size, consistency_level=cl),
]
+ # ScyllaDB doesn't have triggers, skip the query
+ if self._is_not_scylla():
+ queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout),
+ fetch_size=fetch_size, consistency_level=cl))
+
responses = self.connection.wait_for_responses(
*queries, timeout=self.timeout, fail_on_error=False)
+
+ # Unpack common responses (always present)
(
# copied from V3
(ks_success, ks_result),
@@ -2975,39 +2978,45 @@ def _query_all(self):
(types_success, types_result),
(functions_success, functions_result),
(aggregates_success, aggregates_result),
- (triggers_success, triggers_result),
(indexes_success, indexes_result),
(views_success, views_result),
# V4-only responses
(virtual_ks_success, virtual_ks_result),
(virtual_table_success, virtual_table_result),
- (virtual_column_success, virtual_column_result)
- ) = responses
+ (virtual_column_success, virtual_column_result),
+ ) = responses[:11]
+
+ # Unpack triggers response if present (Cassandra/DSE only)
+ if self._is_not_scylla():
+ (triggers_success, triggers_result) = responses[11]
# copied from V3
self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0])
self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1])
self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2])
- self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6])
self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3])
self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4])
self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5])
- self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7])
- self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8])
+ self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6])
+ self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7])
+ if self._is_not_scylla():
+ self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[11])
+ else:
+ self.triggers_result = []
# V4-only results
# These tables don't exist in some DSE versions reporting 4.X so we can
# ignore them if we got an error
self.virtual_keyspaces_result = self._handle_results(
virtual_ks_success, virtual_ks_result,
- expected_failures=(InvalidRequest,), query_msg=queries[9]
+ expected_failures=(InvalidRequest,), query_msg=queries[8]
)
self.virtual_tables_result = self._handle_results(
virtual_table_success, virtual_table_result,
- expected_failures=(InvalidRequest,), query_msg=queries[10]
+ expected_failures=(InvalidRequest,), query_msg=queries[9]
)
self.virtual_columns_result = self._handle_results(
virtual_column_success, virtual_column_result,
- expected_failures=(InvalidRequest,), query_msg=queries[11]
+ expected_failures=(InvalidRequest,), query_msg=queries[10]
)
self._aggregate_results()
@@ -3445,8 +3454,27 @@ def __init__(
self.to_clustering_columns = to_clustering_columns
+def get_column_from_system_local(connection, column_name: str, timeout, metadata_request_timeout) -> str:
+ success, local_result = connection.wait_for_response(
+ QueryMessage(
+ query=maybe_add_timeout_to_query(
+ "SELECT " + column_name + " FROM system.local WHERE key='local'",
+ metadata_request_timeout),
+ consistency_level=ConsistencyLevel.ONE)
+ , timeout=timeout, fail_on_error=False)
+ if not success or not local_result.parsed_rows:
+ return ""
+ local_rows = dict_factory(local_result.column_names, local_result.parsed_rows)
+ local_row = local_rows[0]
+ return local_row.get(column_name)
+
+
def get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size=None):
- version = Version(server_version)
+ if server_version is None and dse_version is None:
+ server_version = get_column_from_system_local(connection, "release_version", timeout, metadata_request_timeout)
+ dse_version = get_column_from_system_local(connection, "dse_version", timeout, metadata_request_timeout)
+
+ version = Version(server_version or "0")
if dse_version:
v = Version(dse_version)
if v >= Version('6.8.0'):
@@ -3497,7 +3525,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
:class:`~.NO_VALID_REPLICA`
Example usage::
-
+
>>> result = group_keys_by_replica(
... session, "system", "peers",
... (("127.0.0.1", ), ("127.0.0.2", )))
diff --git a/cassandra/metrics.py b/cassandra/metrics.py
index 223b0c7c6e..7ff44107af 100644
--- a/cassandra/metrics.py
+++ b/cassandra/metrics.py
@@ -12,19 +12,326 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Driver metrics collection module.
+
+This module provides metrics collection functionality without external dependencies.
+It was originally based on the `scales` library but now uses a self-contained
+implementation.
+"""
+
from itertools import chain
import logging
-
-try:
- from greplin import scales
-except ImportError:
- raise ImportError(
- "The scales library is required for metrics support: "
- "https://pypi.org/project/scales/")
+import math
+import random
+import threading
log = logging.getLogger(__name__)
+# Global stats registry
+_stats_registry = {}
+_registry_lock = threading.Lock()
+
+
+def getStats():
+ """
+ Returns a copy of all registered stats.
+ """
+ with _registry_lock:
+ return {name: stats._get_stats_dict() for name, stats in _stats_registry.items()}
+
+
+class IntStat:
+ """
+ A thread-safe integer counter statistic.
+ """
+ __slots__ = ('name', '_value', '_lock')
+
+ def __init__(self, name):
+ self.name = name
+ self._value = 0
+ self._lock = threading.Lock()
+
+ def __iadd__(self, other):
+ with self._lock:
+ self._value += other
+ return self
+
+ def __int__(self):
+ return self._value
+
+ def __eq__(self, other):
+ if isinstance(other, IntStat):
+ return self._value == other._value
+ return self._value == other
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __lt__(self, other):
+ if isinstance(other, IntStat):
+ return self._value < other._value
+ return self._value < other
+
+ def __le__(self, other):
+ if isinstance(other, IntStat):
+ return self._value <= other._value
+ return self._value <= other
+
+ def __gt__(self, other):
+ if isinstance(other, IntStat):
+ return self._value > other._value
+ return self._value > other
+
+ def __ge__(self, other):
+ if isinstance(other, IntStat):
+ return self._value >= other._value
+ return self._value >= other
+
+ def __hash__(self):
+ return hash(self._value)
+
+ def __repr__(self):
+ return f"IntStat({self.name}={self._value})"
+
+ @property
+ def value(self):
+ return self._value
+
+
+class Stat:
+ """
+ A gauge statistic that evaluates a callable on access.
+ """
+ __slots__ = ('name', '_func')
+
+ def __init__(self, name, func):
+ self.name = name
+ self._func = func
+
+ @property
+ def value(self):
+ return self._func()
+
+ def __repr__(self):
+ return f"Stat({self.name}={self.value})"
+
+
+class PmfStat:
+ """
+ A probability mass function statistic that tracks timing/size distributions.
+
+ Computes count, min, max, mean, stddev, median and various percentiles.
+ Uses reservoir sampling to limit memory usage for large sample counts.
+ """
+ __slots__ = ('name', '_values', '_lock', '_count', '_min', '_max', '_sum', '_sum_sq')
+
+ # Maximum number of values to retain for percentile calculations
+ _MAX_SAMPLES = 10000
+
+ def __init__(self, name):
+ self.name = name
+ self._values = []
+ self._lock = threading.Lock()
+ self._count = 0
+ self._min = float('inf')
+ self._max = float('-inf')
+ self._sum = 0.0
+ self._sum_sq = 0.0
+
+ def addValue(self, value):
+ """Record a new value."""
+ with self._lock:
+ self._count += 1
+ self._sum += value
+ self._sum_sq += value * value
+
+ if value < self._min:
+ self._min = value
+ if value > self._max:
+ self._max = value
+
+ # Reservoir sampling for percentiles
+ if len(self._values) < self._MAX_SAMPLES:
+ self._values.append(value)
+ else:
+ # Replace random element with decreasing probability
+ idx = random.randint(0, self._count - 1)
+ if idx < self._MAX_SAMPLES:
+ self._values[idx] = value
+
+ def _percentile(self, sorted_values, p):
+ """Calculate the p-th percentile from sorted values."""
+ if not sorted_values:
+ return 0.0
+ k = (len(sorted_values) - 1) * p / 100.0
+ f = math.floor(k)
+ c = math.ceil(k)
+ if f == c:
+ return sorted_values[int(k)]
+ return sorted_values[int(f)] * (c - k) + sorted_values[int(c)] * (k - f)
+
+ def _get_stats(self):
+ """Calculate all statistics."""
+ with self._lock:
+ count = self._count
+ if count == 0:
+ return {
+ 'count': 0,
+ 'min': 0.0,
+ 'max': 0.0,
+ 'mean': 0.0,
+ 'stddev': 0.0,
+ 'median': 0.0,
+ '75percentile': 0.0,
+ '95percentile': 0.0,
+ '98percentile': 0.0,
+ '99percentile': 0.0,
+ '999percentile': 0.0,
+ }
+
+ mean = self._sum / count
+
+ # Calculate stddev using Welford's algorithm values
+ variance = (self._sum_sq / count) - (mean * mean)
+ stddev = math.sqrt(max(0, variance)) # max to handle floating point errors
+
+ sorted_values = sorted(self._values)
+
+ return {
+ 'count': count,
+ 'min': self._min,
+ 'max': self._max,
+ 'mean': mean,
+ 'stddev': stddev,
+ 'median': self._percentile(sorted_values, 50),
+ '75percentile': self._percentile(sorted_values, 75),
+ '95percentile': self._percentile(sorted_values, 95),
+ '98percentile': self._percentile(sorted_values, 98),
+ '99percentile': self._percentile(sorted_values, 99),
+ '999percentile': self._percentile(sorted_values, 99.9),
+ }
+
+ def __getitem__(self, key):
+ return self._get_stats()[key]
+
+ def __iter__(self):
+ return iter(self._get_stats())
+
+ def keys(self):
+ return self._get_stats().keys()
+
+ def items(self):
+ return self._get_stats().items()
+
+ def values(self):
+ return self._get_stats().values()
+
+ def __repr__(self):
+ return f"PmfStat({self.name}, count={self._count})"
+
+
+class StatsCollection:
+ """
+ A named collection of statistics.
+ """
+ __slots__ = ('_name', '_stats', '_int_stats', '_pmf_stats', '_gauge_stats')
+
+ def __init__(self, name, *stats):
+ self._name = name
+ self._stats = {}
+ self._int_stats = {}
+ self._pmf_stats = {}
+ self._gauge_stats = {}
+
+ for stat in stats:
+ self._stats[stat.name] = stat
+ if isinstance(stat, IntStat):
+ self._int_stats[stat.name] = stat
+ elif isinstance(stat, PmfStat):
+ self._pmf_stats[stat.name] = stat
+ elif isinstance(stat, Stat):
+ self._gauge_stats[stat.name] = stat
+
+ def __getattr__(self, name):
+ if name.startswith('_'):
+ raise AttributeError(name)
+ try:
+ stats = object.__getattribute__(self, '_stats')
+ if name in stats:
+ return stats[name]
+ except AttributeError:
+ pass
+ raise AttributeError(f"No stat named '{name}'")
+
+ def __setattr__(self, name, value):
+ if name.startswith('_'):
+ object.__setattr__(self, name, value)
+ return
+ # Allow rebinding stats (e.g., for augmented assignment like stats.errors += 1)
+ try:
+ stats = object.__getattribute__(self, '_stats')
+ if name in stats:
+ # For augmented assignment, value should be the same IntStat/PmfStat object
+ # Just verify and allow the rebind
+ return
+ except AttributeError:
+ pass
+ raise AttributeError(f"Cannot set attribute '{name}' on StatsCollection")
+
+ def _get_stats_dict(self):
+ """Return dictionary representation of all stats."""
+ result = {}
+ for name, stat in self._int_stats.items():
+ result[name] = stat.value
+ for name, stat in self._pmf_stats.items():
+ result[name] = stat._get_stats()
+ for name, stat in self._gauge_stats.items():
+ result[name] = stat.value
+ return result
+
+
+def collection(name, *stats):
+ """
+ Create a named collection of statistics and register it globally.
+ """
+ coll = StatsCollection(name, *stats)
+ with _registry_lock:
+ _stats_registry[name] = coll
+ return coll
+
+
+def init(obj, path):
+ """
+ Initialize class-level stats on an instance and register in the global registry.
+
+ This allows class-level PmfStat/IntStat descriptors to be used per-instance.
+ """
+ # Get class-level stats and create instance copies
+ cls = obj.__class__
+ instance_stats = {}
+
+ for attr_name in dir(cls):
+ attr = getattr(cls, attr_name, None)
+ if isinstance(attr, (PmfStat, IntStat)):
+ # Create a new instance of the stat for this object
+ if isinstance(attr, PmfStat):
+ new_stat = PmfStat(attr.name)
+ else:
+ new_stat = IntStat(attr.name)
+ instance_stats[attr_name] = new_stat
+ # Set on instance to shadow class attribute
+ object.__setattr__(obj, attr_name, new_stat)
+
+ # Register under the given path (remove leading /)
+ reg_name = path.lstrip('/')
+ if instance_stats:
+ stats_coll = StatsCollection(reg_name, *instance_stats.values())
+ with _registry_lock:
+ _stats_registry[reg_name] = stats_coll
+
+
class Metrics(object):
"""
A collection of timers and counters for various performance metrics.
@@ -34,7 +341,7 @@ class Metrics(object):
request_timer = None
"""
- A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like
+ A :class:`~cassandra.metrics.PmfStat` timer for requests. This is a dict-like
object with the following keys:
* count - number of requests that have been timed
@@ -52,64 +359,64 @@ class Metrics(object):
connection_errors = None
"""
- A :class:`greplin.scales.IntStat` count of the number of times that a
+ A :class:`~cassandra.metrics.IntStat` count of the number of times that a
request to a Cassandra node has failed due to a connection problem.
"""
write_timeouts = None
"""
- A :class:`greplin.scales.IntStat` count of write requests that resulted
+ A :class:`~cassandra.metrics.IntStat` count of write requests that resulted
in a timeout.
"""
read_timeouts = None
"""
- A :class:`greplin.scales.IntStat` count of read requests that resulted
+ A :class:`~cassandra.metrics.IntStat` count of read requests that resulted
in a timeout.
"""
unavailables = None
"""
- A :class:`greplin.scales.IntStat` count of write or read requests that
+ A :class:`~cassandra.metrics.IntStat` count of write or read requests that
failed due to an insufficient number of replicas being alive to meet
the requested :class:`.ConsistencyLevel`.
"""
other_errors = None
"""
- A :class:`greplin.scales.IntStat` count of all other request failures,
+ A :class:`~cassandra.metrics.IntStat` count of all other request failures,
including failures caused by invalid requests, bootstrapping nodes,
overloaded nodes, etc.
"""
retries = None
"""
- A :class:`greplin.scales.IntStat` count of the number of times a
+ A :class:`~cassandra.metrics.IntStat` count of the number of times a
request was retried based on the :class:`.RetryPolicy` decision.
"""
ignores = None
"""
- A :class:`greplin.scales.IntStat` count of the number of times a
+ A :class:`~cassandra.metrics.IntStat` count of the number of times a
failed request was ignored based on the :class:`.RetryPolicy` decision.
"""
known_hosts = None
"""
- A :class:`greplin.scales.IntStat` count of the number of nodes in
+ A :class:`~cassandra.metrics.IntStat` count of the number of nodes in
the cluster that the driver is aware of, regardless of whether any
connections are opened to those nodes.
"""
connected_to = None
"""
- A :class:`greplin.scales.IntStat` count of the number of nodes that
+ A :class:`~cassandra.metrics.IntStat` count of the number of nodes that
the driver currently has at least one connection open to.
"""
open_connections = None
"""
- A :class:`greplin.scales.IntStat` count of the number connections
+ A :class:`~cassandra.metrics.IntStat` count of the number connections
the driver currently has open.
"""
@@ -120,28 +427,29 @@ def __init__(self, cluster_proxy):
self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter))
Metrics._stats_counter += 1
- self.stats = scales.collection(self.stats_name,
- scales.PmfStat('request_timer'),
- scales.IntStat('connection_errors'),
- scales.IntStat('write_timeouts'),
- scales.IntStat('read_timeouts'),
- scales.IntStat('unavailables'),
- scales.IntStat('other_errors'),
- scales.IntStat('retries'),
- scales.IntStat('ignores'),
+ self.stats = collection(self.stats_name,
+ PmfStat('request_timer'),
+ IntStat('connection_errors'),
+ IntStat('write_timeouts'),
+ IntStat('read_timeouts'),
+ IntStat('unavailables'),
+ IntStat('other_errors'),
+ IntStat('retries'),
+ IntStat('ignores'),
# gauges
- scales.Stat('known_hosts',
+ Stat('known_hosts',
lambda: len(cluster_proxy.metadata.all_hosts())),
- scales.Stat('connected_to',
- lambda: len(set(chain.from_iterable(s._pools.keys() for s in cluster_proxy.sessions)))),
- scales.Stat('open_connections',
- lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions)))
+ Stat('connected_to',
+ lambda: len(set(chain.from_iterable(list(s._pools.keys()) for s in cluster_proxy.sessions)))),
+ Stat('open_connections',
+ lambda: sum(sum(p.open_count for p in list(s._pools.values())) for s in cluster_proxy.sessions)))
# TODO, to be removed in 4.0
# /cassandra contains the metrics of the first cluster registered
- if 'cassandra' not in scales._Stats.stats:
- scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name]
+ with _registry_lock:
+ if 'cassandra' not in _stats_registry:
+ _stats_registry['cassandra'] = _stats_registry[self.stats_name]
self.request_timer = self.stats.request_timer
self.connection_errors = self.stats.connection_errors
@@ -180,22 +488,35 @@ def get_stats(self):
"""
Returns the metrics for the registered cluster instance.
"""
- return scales.getStats()[self.stats_name]
+ return getStats()[self.stats_name]
def set_stats_name(self, stats_name):
"""
Set the metrics stats name.
- The stats_name is a string used to access the metris through scales: scales.getStats()[]
+ The stats_name is a string used to access the metrics through getStats(): getStats()[]
Default is 'cassandra-'.
"""
if self.stats_name == stats_name:
return
- if stats_name in scales._Stats.stats:
- raise ValueError('"{0}" already exists in stats.'.format(stats_name))
+ with _registry_lock:
+ if stats_name in _stats_registry:
+ raise ValueError('"{0}" already exists in stats.'.format(stats_name))
- stats = scales._Stats.stats[self.stats_name]
- del scales._Stats.stats[self.stats_name]
- self.stats_name = stats_name
- scales._Stats.stats[self.stats_name] = stats
+ stats = _stats_registry[self.stats_name]
+ del _stats_registry[self.stats_name]
+ self.stats_name = stats_name
+ _stats_registry[self.stats_name] = stats
+
+ def shutdown(self):
+ """
+ Remove this metrics instance from the global registry.
+ Called when the cluster is shutdown to prevent stale references.
+ """
+ with _registry_lock:
+ if self.stats_name in _stats_registry:
+ del _stats_registry[self.stats_name]
+ # Also clean up the legacy 'cassandra' entry if it points to our stats
+ if _stats_registry.get('cassandra') is self.stats:
+ del _stats_registry['cassandra']
diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx
index 030c2c65c7..0ad34f66e2 100644
--- a/cassandra/numpy_parser.pyx
+++ b/cassandra/numpy_parser.pyx
@@ -181,5 +181,6 @@ def make_native_byteorder(arr):
# accordingly (e.g. from '>i8' to ' time.time()) or \
self._session.cluster.shard_aware_options.disable_shardaware_port:
return None
@@ -756,23 +751,26 @@ def _open_connection_to_missing_shard(self, shard_id):
)
old_conn = None
with self._lock:
- if self.is_shutdown:
- conn.close()
- return
- if conn.features.shard_id in self._connections.keys():
- # Move the current connection to the trash and use the new one from now on
- old_conn = self._connections[conn.features.shard_id]
- log.debug(
- "Replacing overloaded connection (%s) with (%s) for shard %i for host %s",
- id(old_conn),
- id(conn),
- conn.features.shard_id,
- self.host
- )
- if self._keyspace:
- conn.set_keyspace_blocking(self._keyspace)
+ is_shutdown = self.is_shutdown
+ if not is_shutdown:
+ if conn.features.shard_id in self._connections:
+ # Move the current connection to the trash and use the new one from now on
+ old_conn = self._connections[conn.features.shard_id]
+ log.debug(
+ "Replacing overloaded connection (%s) with (%s) for shard %i for host %s",
+ id(old_conn),
+ id(conn),
+ conn.features.shard_id,
+ self.host
+ )
+ if self._keyspace:
+ conn.set_keyspace_blocking(self._keyspace)
+ self._connections[conn.features.shard_id] = conn
+
+ if is_shutdown:
+ conn.close()
+ return
- self._connections[conn.features.shard_id] = conn
if old_conn is not None:
remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids)
if remaining == 0:
@@ -792,14 +790,15 @@ def _open_connection_to_missing_shard(self, shard_id):
remaining,
)
with self._lock:
- if self.is_shutdown:
- old_conn.close()
- else:
+ is_shutdown = self.is_shutdown
+ if not is_shutdown:
self._trash.add(old_conn)
+ if is_shutdown:
+ conn.close()
num_missing_or_needing_replacement = self.num_missing_or_needing_replacement
log.debug(
"Connected to %s/%i shards on host %s (%i missing or needs replacement)",
- len(self._connections.keys()),
+ len(self._connections),
self.host.sharding_info.shards_count,
self.host,
num_missing_or_needing_replacement
@@ -811,7 +810,7 @@ def _open_connection_to_missing_shard(self, shard_id):
len(self._excess_connections)
)
self._close_excess_connections()
- elif self.host.sharding_info.shards_count == len(self._connections.keys()) and self.num_missing_or_needing_replacement == 0:
+ elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0:
log.debug(
"All shards are already covered, closing newly opened excess connection %s for host %s",
id(self),
@@ -912,7 +911,7 @@ def get_state(self):
@property
def num_missing_or_needing_replacement(self):
return self.host.sharding_info.shards_count \
- - sum(1 for c in self._connections.values() if not c.orphaned_threshold_reached)
+ - sum(1 for c in list(self._connections.values()) if not c.orphaned_threshold_reached)
@property
def open_count(self):
@@ -923,362 +922,3 @@ def _excess_connection_limit(self):
return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier
-_MAX_SIMULTANEOUS_CREATION = 1
-_MIN_TRASH_INTERVAL = 10
-
-
-class HostConnectionPool(object):
- """
- Used to pool connections to a host for v1 and v2 native protocol.
- """
-
- host = None
- host_distance = None
-
- is_shutdown = False
- open_count = 0
- _scheduled_for_creation = 0
- _next_trash_allowed_at = 0
- _keyspace = None
-
- def __init__(self, host, host_distance, session):
- self.host = host
- self.host_distance = host_distance
-
- self._session = weakref.proxy(session)
- self._lock = RLock()
- self._conn_available_condition = Condition()
-
- log.debug("Initializing new connection pool for host %s", self.host)
- core_conns = session.cluster.get_core_connections_per_host(host_distance)
- self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
- for i in range(core_conns)]
-
- self._keyspace = session.keyspace
- if self._keyspace:
- for conn in self._connections:
- conn.set_keyspace_blocking(self._keyspace)
-
- self._trash = set()
- self._next_trash_allowed_at = time.time()
- self.open_count = core_conns
- log.debug("Finished initializing new connection pool for host %s", self.host)
-
- def borrow_connection(self, timeout, routing_key=None):
- if self.is_shutdown:
- raise ConnectionException(
- "Pool for %s is shutdown" % (self.host,), self.host)
-
- conns = self._connections
- if not conns:
- # handled specially just for simpler code
- log.debug("Detected empty pool, opening core conns to %s", self.host)
- core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
- with self._lock:
- # we check the length of self._connections again
- # along with self._scheduled_for_creation while holding the lock
- # in case multiple threads hit this condition at the same time
- to_create = core_conns - (len(self._connections) + self._scheduled_for_creation)
- for i in range(to_create):
- self._scheduled_for_creation += 1
- self._session.submit(self._create_new_connection)
-
- # in_flight is incremented by wait_for_conn
- conn = self._wait_for_conn(timeout)
- return conn
- else:
- # note: it would be nice to push changes to these config settings
- # to pools instead of doing a new lookup on every
- # borrow_connection() call
- max_reqs = self._session.cluster.get_max_requests_per_connection(self.host_distance)
- max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
-
- least_busy = min(conns, key=lambda c: c.in_flight)
- request_id = None
- # to avoid another thread closing this connection while
- # trashing it (through the return_connection process), hold
- # the connection lock from this point until we've incremented
- # its in_flight count
- need_to_wait = False
- with least_busy.lock:
- if least_busy.in_flight < least_busy.max_request_id:
- least_busy.in_flight += 1
- request_id = least_busy.get_request_id()
- else:
- # once we release the lock, wait for another connection
- need_to_wait = True
-
- if need_to_wait:
- # wait_for_conn will increment in_flight on the conn
- least_busy, request_id = self._wait_for_conn(timeout)
-
- # if we have too many requests on this connection but we still
- # have space to open a new connection against this host, go ahead
- # and schedule the creation of a new connection
- if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns:
- self._maybe_spawn_new_connection()
-
- return least_busy, request_id
-
- def _maybe_spawn_new_connection(self):
- with self._lock:
- if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION:
- return
- if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance):
- return
- self._scheduled_for_creation += 1
-
- log.debug("Submitting task for creation of new Connection to %s", self.host)
- self._session.submit(self._create_new_connection)
-
- def _create_new_connection(self):
- try:
- self._add_conn_if_under_max()
- except (ConnectionException, socket.error) as exc:
- log.warning("Failed to create new connection to %s: %s", self.host, exc)
- except Exception:
- log.exception("Unexpectedly failed to create new connection")
- finally:
- with self._lock:
- self._scheduled_for_creation -= 1
-
- def _add_conn_if_under_max(self):
- max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
- with self._lock:
- if self.is_shutdown:
- return True
-
- if self.open_count >= max_conns:
- return True
-
- self.open_count += 1
-
- log.debug("Going to open new connection to host %s", self.host)
- try:
- conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
- if self._keyspace:
- conn.set_keyspace_blocking(self._session.keyspace)
- self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
- with self._lock:
- new_connections = self._connections[:] + [conn]
- self._connections = new_connections
- log.debug("Added new connection (%s) to pool for host %s, signaling availablility",
- id(conn), self.host)
- self._signal_available_conn()
- return True
- except (ConnectionException, socket.error) as exc:
- log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
- with self._lock:
- self.open_count -= 1
- if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
- self.shutdown()
- return False
- except AuthenticationFailed:
- with self._lock:
- self.open_count -= 1
- return False
-
- def _await_available_conn(self, timeout):
- with self._conn_available_condition:
- self._conn_available_condition.wait(timeout)
-
- def _signal_available_conn(self):
- with self._conn_available_condition:
- self._conn_available_condition.notify()
-
- def _signal_all_available_conn(self):
- with self._conn_available_condition:
- self._conn_available_condition.notify_all()
-
- def _wait_for_conn(self, timeout):
- start = time.time()
- remaining = timeout
-
- while remaining > 0:
- # wait on our condition for the possibility that a connection
- # is useable
- self._await_available_conn(remaining)
-
- # self.shutdown() may trigger the above Condition
- if self.is_shutdown:
- raise ConnectionException("Pool is shutdown")
-
- conns = self._connections
- if conns:
- least_busy = min(conns, key=lambda c: c.in_flight)
- with least_busy.lock:
- if least_busy.in_flight < least_busy.max_request_id:
- least_busy.in_flight += 1
- return least_busy, least_busy.get_request_id()
-
- remaining = timeout - (time.time() - start)
-
- raise NoConnectionsAvailable()
-
- def return_connection(self, connection, stream_was_orphaned=False):
- with connection.lock:
- if not stream_was_orphaned:
- connection.in_flight -= 1
- in_flight = connection.in_flight
-
- if connection.is_defunct or connection.is_closed:
- if not connection.signaled_error:
- log.debug("Defunct or closed connection (%s) returned to pool, potentially "
- "marking host %s as down", id(connection), self.host)
- is_down = self._session.cluster.signal_connection_failure(
- self.host, connection.last_error, is_host_addition=False)
- connection.signaled_error = True
- if is_down:
- self.shutdown()
- else:
- self._replace(connection)
- else:
- if connection in self._trash:
- with connection.lock:
- if connection.in_flight == 0:
- with self._lock:
- if connection in self._trash:
- self._trash.remove(connection)
- log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
- connection.close()
- return
-
- core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
- min_reqs = self._session.cluster.get_min_requests_per_connection(self.host_distance)
- # we can use in_flight here without holding the connection lock
- # because the fact that in_flight dipped below the min at some
- # point is enough to start the trashing procedure
- if len(self._connections) > core_conns and in_flight <= min_reqs and \
- time.time() >= self._next_trash_allowed_at:
- self._maybe_trash_connection(connection)
- else:
- self._signal_available_conn()
-
- def on_orphaned_stream_released(self):
- """
- Called when a response for an orphaned stream (timed out on the client
- side) was received.
- """
- self._signal_available_conn()
-
- def _maybe_trash_connection(self, connection):
- core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
- did_trash = False
- with self._lock:
- if connection not in self._connections:
- return
-
- if self.open_count > core_conns:
- did_trash = True
- self.open_count -= 1
- new_connections = self._connections[:]
- new_connections.remove(connection)
- self._connections = new_connections
-
- with connection.lock:
- if connection.in_flight == 0:
- log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host)
- connection.close()
-
- # skip adding it to the trash if we're already closing it
- return
-
- self._trash.add(connection)
-
- if did_trash:
- self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
- log.debug("Trashed connection (%s) to %s", id(connection), self.host)
-
- def _replace(self, connection):
- should_replace = False
- with self._lock:
- if connection in self._connections:
- new_connections = self._connections[:]
- new_connections.remove(connection)
- self._connections = new_connections
- self.open_count -= 1
- should_replace = True
-
- if should_replace:
- log.debug("Replacing connection (%s) to %s", id(connection), self.host)
- connection.close()
- self._session.submit(self._retrying_replace)
- else:
- log.debug("Closing connection (%s) to %s", id(connection), self.host)
- connection.close()
-
- def _retrying_replace(self):
- replaced = False
- try:
- replaced = self._add_conn_if_under_max()
- except Exception:
- log.exception("Failed replacing connection to %s", self.host)
- if not replaced:
- log.debug("Failed replacing connection to %s. Retrying.", self.host)
- self._session.submit(self._retrying_replace)
-
- def shutdown(self):
- with self._lock:
- if self.is_shutdown:
- return
- else:
- self.is_shutdown = True
-
- self._signal_all_available_conn()
-
- connections_to_close = []
- with self._lock:
- connections_to_close.extend(self._connections)
- self.open_count -= len(self._connections)
- self._connections.clear()
- connections_to_close.extend(self._trash)
- self._trash.clear()
-
- for conn in connections_to_close:
- conn.close()
-
- def ensure_core_connections(self):
- if self.is_shutdown:
- return
-
- core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
- with self._lock:
- to_create = core_conns - (len(self._connections) + self._scheduled_for_creation)
- for i in range(to_create):
- self._scheduled_for_creation += 1
- self._session.submit(self._create_new_connection)
-
- def _set_keyspace_for_all_conns(self, keyspace, callback):
- """
- Asynchronously sets the keyspace for all connections. When all
- connections have been set, `callback` will be called with two
- arguments: this pool, and a list of any errors that occurred.
- """
- remaining_callbacks = set(self._connections)
- errors = []
-
- if not remaining_callbacks:
- callback(self, errors)
- return
-
- def connection_finished_setting_keyspace(conn, error):
- self.return_connection(conn)
- remaining_callbacks.remove(conn)
- if error:
- errors.append(error)
-
- if not remaining_callbacks:
- callback(self, errors)
-
- self._keyspace = keyspace
- for conn in self._connections:
- conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
-
- def get_connections(self):
- return self._connections
-
- def get_state(self):
- in_flights = [c.in_flight for c in self._connections]
- orphan_requests = [c.orphaned_request_ids for c in self._connections]
- return {'shutdown': self.is_shutdown, 'open_count': self.open_count, \
- 'in_flights': in_flights, 'orphan_requests': orphan_requests}
diff --git a/cassandra/protocol.py b/cassandra/protocol.py
index 29ae404048..4628c7ee0e 100644
--- a/cassandra/protocol.py
+++ b/cassandra/protocol.py
@@ -36,7 +36,7 @@
TupleType, lookup_casstype, SimpleDateType,
TimeType, ByteType, ShortType, DurationType)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
- uint8_pack, int8_unpack, uint64_pack, header_pack,
+ uint8_pack, int8_unpack, uint64_pack,
v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack)
from cassandra.policies import ColDesc
from cassandra import WriteType
@@ -553,7 +553,6 @@ def __init__(self, query_params, consistency_level,
self.paging_state = paging_state
self.timestamp = timestamp
self.skip_meta = skip_meta
- self.continuous_paging_options = continuous_paging_options
self.keyspace = keyspace
def _write_query_params(self, f, protocol_version):
@@ -563,41 +562,17 @@ def _write_query_params(self, f, protocol_version):
flags |= _VALUES_FLAG # also v2+, but we're only setting params internally right now
if self.serial_consistency_level:
- if protocol_version >= 2:
- flags |= _WITH_SERIAL_CONSISTENCY_FLAG
- else:
- raise UnsupportedOperation(
- "Serial consistency levels require the use of protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2 "
- "to support serial consistency levels.")
+ flags |= _WITH_SERIAL_CONSISTENCY_FLAG
if self.fetch_size:
- if protocol_version >= 2:
- flags |= _PAGE_SIZE_FLAG
- else:
- raise UnsupportedOperation(
- "Automatic query paging may only be used with protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2.")
+ flags |= _PAGE_SIZE_FLAG
if self.paging_state:
- if protocol_version >= 2:
- flags |= _WITH_PAGING_STATE_FLAG
- else:
- raise UnsupportedOperation(
- "Automatic query paging may only be used with protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2.")
+ flags |= _WITH_PAGING_STATE_FLAG
if self.timestamp is not None:
flags |= _PROTOCOL_TIMESTAMP_FLAG
- if self.continuous_paging_options:
- if ProtocolVersion.has_continuous_paging_support(protocol_version):
- flags |= _PAGING_OPTIONS_FLAG
- else:
- raise UnsupportedOperation(
- "Continuous paging may only be used with protocol version "
- "ProtocolVersion.DSE_V1 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V1.")
-
if self.keyspace is not None:
if ProtocolVersion.uses_keyspace_flag(protocol_version):
flags |= _WITH_KEYSPACE_FLAG
@@ -625,14 +600,10 @@ def _write_query_params(self, f, protocol_version):
write_long(f, self.timestamp)
if self.keyspace is not None:
write_string(f, self.keyspace)
- if self.continuous_paging_options:
- self._write_paging_options(f, self.continuous_paging_options, protocol_version)
def _write_paging_options(self, f, paging_options, protocol_version):
write_int(f, paging_options.max_pages)
write_int(f, paging_options.max_pages_per_second)
- if ProtocolVersion.has_continuous_paging_next_pages(protocol_version):
- write_int(f, paging_options.max_queue_size)
class QueryMessage(_QueryMessage):
@@ -640,9 +611,10 @@ class QueryMessage(_QueryMessage):
name = 'QUERY'
def __init__(self, query, consistency_level, serial_consistency_level=None,
- fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None):
+ fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None,
+ query_params=None):
self.query = query
- super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
+ super(QueryMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size,
paging_state, timestamp, False, continuous_paging_options, keyspace)
def send_body(self, f, protocol_version):
@@ -664,22 +636,7 @@ def __init__(self, query_id, query_params, consistency_level,
paging_state, timestamp, skip_meta, continuous_paging_options)
def _write_query_params(self, f, protocol_version):
- if protocol_version == 1:
- if self.serial_consistency_level:
- raise UnsupportedOperation(
- "Serial consistency levels require the use of protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2 "
- "to support serial consistency levels.")
- if self.fetch_size or self.paging_state:
- raise UnsupportedOperation(
- "Automatic query paging may only be used with protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2.")
- write_short(f, len(self.query_params))
- for param in self.query_params:
- write_value(f, param)
- write_consistency_level(f, self.consistency_level)
- else:
- super(ExecuteMessage, self)._write_query_params(f, protocol_version)
+ super(ExecuteMessage, self)._write_query_params(f, protocol_version)
def send_body(self, f, protocol_version):
write_string(f, self.query_id)
@@ -730,11 +687,12 @@ class ResultMessage(_MessageType):
bind_metadata = None
pk_indexes = None
schema_change_event = None
+ is_lwt = False
def __init__(self, kind):
self.kind = kind
- def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
+ def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
if self.kind == RESULT_KIND_VOID:
return
elif self.kind == RESULT_KIND_ROWS:
@@ -742,7 +700,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
elif self.kind == RESULT_KIND_SET_KEYSPACE:
self.new_keyspace = read_string(f)
elif self.kind == RESULT_KIND_PREPARED:
- self.recv_results_prepared(f, protocol_version, user_type_map)
+ self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
self.recv_results_schema_change(f, protocol_version)
else:
@@ -752,7 +710,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
kind = read_int(f)
msg = cls(kind)
- msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
+ msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
return msg
def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
@@ -785,13 +743,13 @@ def decode_row(row):
col_md[3].cql_parameterized_type(),
str(e)))
- def recv_results_prepared(self, f, protocol_version, user_type_map):
+ def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
self.query_id = read_binary_string(f)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
self.result_metadata_id = read_binary_string(f)
else:
self.result_metadata_id = None
- self.recv_prepared_metadata(f, protocol_version, user_type_map)
+ self.recv_prepared_metadata(f, protocol_version, protocol_features, user_type_map)
def recv_results_metadata(self, f, user_type_map):
flags = read_int(f)
@@ -829,8 +787,9 @@ def recv_results_metadata(self, f, user_type_map):
self.column_metadata = column_metadata
- def recv_prepared_metadata(self, f, protocol_version, user_type_map):
+ def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_type_map):
flags = read_int(f)
+ self.is_lwt = protocol_features.lwt_info.get_lwt_flag(flags) if protocol_features.lwt_info is not None else False
colcount = read_int(f)
pk_indexes = None
if protocol_version >= 4:
@@ -853,8 +812,7 @@ def recv_prepared_metadata(self, f, protocol_version, user_type_map):
coltype = self.read_type(f, user_type_map)
bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
- if protocol_version >= 2:
- self.recv_results_metadata(f, user_type_map)
+ self.recv_results_metadata(f, user_type_map)
self.bind_metadata = bind_metadata
self.pk_indexes = pk_indexes
@@ -969,39 +927,38 @@ def send_body(self, f, protocol_version):
write_value(f, param)
write_consistency_level(f, self.consistency_level)
- if protocol_version >= 3:
- flags = 0
- if self.serial_consistency_level:
- flags |= _WITH_SERIAL_CONSISTENCY_FLAG
- if self.timestamp is not None:
- flags |= _PROTOCOL_TIMESTAMP_FLAG
- if self.keyspace:
- if ProtocolVersion.uses_keyspace_flag(protocol_version):
- flags |= _WITH_KEYSPACE_FLAG
- else:
- raise UnsupportedOperation(
- "Keyspaces may only be set on queries with protocol version "
- "5 or higher. Consider setting Cluster.protocol_version to 5.")
-
- if ProtocolVersion.uses_int_query_flags(protocol_version):
- write_int(f, flags)
+ flags = 0
+ if self.serial_consistency_level:
+ flags |= _WITH_SERIAL_CONSISTENCY_FLAG
+ if self.timestamp is not None:
+ flags |= _PROTOCOL_TIMESTAMP_FLAG
+ if self.keyspace:
+ if ProtocolVersion.uses_keyspace_flag(protocol_version):
+ flags |= _WITH_KEYSPACE_FLAG
else:
- write_byte(f, flags)
+ raise UnsupportedOperation(
+ "Keyspaces may only be set on queries with protocol version "
+ "5 or higher. Consider setting Cluster.protocol_version to 5.")
+ if ProtocolVersion.uses_int_query_flags(protocol_version):
+ write_int(f, flags)
+ else:
+ write_byte(f, flags)
- if self.serial_consistency_level:
- write_consistency_level(f, self.serial_consistency_level)
- if self.timestamp is not None:
- write_long(f, self.timestamp)
+ if self.serial_consistency_level:
+ write_consistency_level(f, self.serial_consistency_level)
+ if self.timestamp is not None:
+ write_long(f, self.timestamp)
- if ProtocolVersion.uses_keyspace_flag(protocol_version):
- if self.keyspace is not None:
- write_string(f, self.keyspace)
+ if ProtocolVersion.uses_keyspace_flag(protocol_version):
+ if self.keyspace is not None:
+ write_string(f, self.keyspace)
known_event_types = frozenset((
'TOPOLOGY_CHANGE',
'STATUS_CHANGE',
- 'SCHEMA_CHANGE'
+ 'SCHEMA_CHANGE',
+ 'CLIENT_ROUTES_CHANGE'
))
@@ -1032,6 +989,14 @@ def recv_body(cls, f, protocol_version, *args):
return cls(event_type=event_type, event_args=read_method(f, protocol_version))
raise NotSupportedError('Unknown event type %r' % event_type)
+ @classmethod
+ def recv_client_routes_change(cls, f, protocol_version):
+ # "UPDATE_NODES"
+ change_type = read_string(f)
+ connection_ids = read_stringlist(f)
+ host_ids = read_stringlist(f)
+ return dict(change_type=change_type, connection_ids=connection_ids, host_ids=host_ids)
+
@classmethod
def recv_topology_change(cls, f, protocol_version):
# "NEW_NODE" or "REMOVED_NODE"
@@ -1050,25 +1015,17 @@ def recv_status_change(cls, f, protocol_version):
def recv_schema_change(cls, f, protocol_version):
# "CREATED", "DROPPED", or "UPDATED"
change_type = read_string(f)
- if protocol_version >= 3:
- target = read_string(f)
- keyspace = read_string(f)
- event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
- if target != SchemaTargetType.KEYSPACE:
- target_name = read_string(f)
- if target == SchemaTargetType.FUNCTION:
- event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
- elif target == SchemaTargetType.AGGREGATE:
- event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
- else:
- event[target.lower()] = target_name
- else:
- keyspace = read_string(f)
- table = read_string(f)
- if table:
- event = {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table}
+ target = read_string(f)
+ keyspace = read_string(f)
+ event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
+ if target != SchemaTargetType.KEYSPACE:
+ target_name = read_string(f)
+ if target == SchemaTargetType.FUNCTION:
+ event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
+ elif target == SchemaTargetType.AGGREGATE:
+ event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
else:
- event = {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace}
+ event[target.lower()] = target_name
return event
@@ -1092,12 +1049,9 @@ def send_body(self, f, protocol_version):
if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE:
if self.next_pages <= 0:
raise UnsupportedOperation("Continuous paging backpressure requires next_pages > 0")
- elif not ProtocolVersion.has_continuous_paging_next_pages(protocol_version):
- raise UnsupportedOperation(
- "Continuous paging backpressure may only be used with protocol version "
- "ProtocolVersion.DSE_V2 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V2.")
else:
- write_int(f, self.next_pages)
+ raise UnsupportedOperation(
+ "Continuous paging backpressure is not supported.")
class _ProtocolHandler(object):
@@ -1132,20 +1086,10 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
:param compressor: optional compression function to be used on the body
"""
flags = 0
- body = io.BytesIO()
if msg.custom_payload:
if protocol_version < 4:
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
- write_bytesmap(body, msg.custom_payload)
- msg.send_body(body, protocol_version)
- body = body.getvalue()
-
- # With checksumming, the compression is done at the segment frame encoding
- if (not ProtocolVersion.has_checksumming_support(protocol_version)
- and compressor and len(body) > 0):
- body = compressor(body)
- flags |= COMPRESSED_FLAG
if msg.tracing:
flags |= TRACING_FLAG
@@ -1154,9 +1098,31 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
flags |= USE_BETA_FLAG
buff = io.BytesIO()
- cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
- buff.write(body)
+ buff.seek(9)
+
+ # With checksumming, the compression is done at the segment frame encoding
+ if (compressor and not ProtocolVersion.has_checksumming_support(protocol_version)):
+ body = io.BytesIO()
+ if msg.custom_payload:
+ write_bytesmap(body, msg.custom_payload)
+ msg.send_body(body, protocol_version)
+ body = body.getvalue()
+
+ if len(body) > 0:
+ body = compressor(body)
+ flags |= COMPRESSED_FLAG
+
+ buff.write(body)
+ length = len(body)
+ else:
+ if msg.custom_payload:
+ write_bytesmap(buff, msg.custom_payload)
+ msg.send_body(buff, protocol_version)
+
+ length = buff.tell() - 9
+ buff.seek(0)
+ cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, length)
return buff.getvalue()
@staticmethod
@@ -1164,8 +1130,7 @@ def _write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
- pack = v3_header_pack if version >= 3 else header_pack
- f.write(pack(version, flags, stream_id, opcode))
+ f.write(v3_header_pack(version, flags, stream_id, opcode))
write_int(f, length)
@classmethod
diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py
index 4eb7019f84..877998be7d 100644
--- a/cassandra/protocol_features.py
+++ b/cassandra/protocol_features.py
@@ -1,10 +1,13 @@
import logging
from cassandra.shard_info import _ShardingInfo
+from cassandra.lwt_info import _LwtInfo
log = logging.getLogger(__name__)
+LWT_ADD_METADATA_MARK = "SCYLLA_LWT_ADD_METADATA_MARK"
+LWT_OPTIMIZATION_META_BIT_MASK = "LWT_OPTIMIZATION_META_BIT_MASK"
RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"
@@ -13,19 +16,22 @@ class ProtocolFeatures(object):
shard_id = 0
sharding_info = None
tablets_routing_v1 = False
+ lwt_info = None
- def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
+ def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None):
self.rate_limit_error = rate_limit_error
self.shard_id = shard_id
self.sharding_info = sharding_info
self.tablets_routing_v1 = tablets_routing_v1
+ self.lwt_info = lwt_info
@staticmethod
def parse_from_supported(supported):
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
- return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)
+ lwt_info = ProtocolFeatures.parse_lwt_info(supported)
+ return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info)
@staticmethod
def maybe_parse_rate_limit_error(supported):
@@ -49,6 +55,8 @@ def add_startup_options(self, options):
options[RATE_LIMIT_ERROR_EXTENSION] = ""
if self.tablets_routing_v1:
options[TABLETS_ROUTING_V1] = ""
+ if self.lwt_info is not None:
+ options[LWT_ADD_METADATA_MARK] = str(self.lwt_info.lwt_meta_bit_mask)
@staticmethod
def parse_sharding_info(options):
@@ -72,3 +80,18 @@ def parse_sharding_info(options):
@staticmethod
def parse_tablets_info(options):
return TABLETS_ROUTING_V1 in options
+
+ @staticmethod
+ def parse_lwt_info(options):
+ value_list = options.get(LWT_ADD_METADATA_MARK, [None])
+ for value in value_list:
+ if value is None or not value.startswith(LWT_OPTIMIZATION_META_BIT_MASK + "="):
+ continue
+ try:
+ lwt_meta_bit_mask = int(value[len(LWT_OPTIMIZATION_META_BIT_MASK + "="):])
+ return _LwtInfo(lwt_meta_bit_mask)
+ except Exception as e:
+ log.exception(f"Error while parsing {LWT_ADD_METADATA_MARK}: {e}")
+ return None
+
+ return None
diff --git a/cassandra/query.py b/cassandra/query.py
index f3922849ab..6c6878fdb4 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -345,6 +345,9 @@ def _set_serial_consistency_level(self, serial_consistency_level):
def _del_serial_consistency_level(self):
self._serial_consistency_level = None
+ def is_lwt(self):
+ return False
+
serial_consistency_level = property(
_get_serial_consistency_level,
_set_serial_consistency_level,
@@ -454,10 +457,11 @@ class PreparedStatement(object):
routing_key_indexes = None
_routing_key_index_set = None
serial_consistency_level = None # TODO never used?
+ _is_lwt = False
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
keyspace, protocol_version, result_metadata, result_metadata_id,
- column_encryption_policy=None):
+ is_lwt=False, column_encryption_policy=None):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -468,15 +472,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
self.result_metadata_id = result_metadata_id
self.column_encryption_policy = column_encryption_policy
self.is_idempotent = False
+ self._is_lwt = is_lwt
@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, column_encryption_policy=None):
+ result_metadata_id, is_lwt, column_encryption_policy=None):
if not column_metadata:
return PreparedStatement(column_metadata, query_id, None,
query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, column_encryption_policy)
+ result_metadata_id, is_lwt, column_encryption_policy)
if pk_indexes:
routing_key_indexes = pk_indexes
@@ -502,7 +507,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id, column_encryption_policy)
+ result_metadata_id, is_lwt, column_encryption_policy)
def bind(self, values):
"""
@@ -517,6 +522,9 @@ def is_routing_key_index(self, i):
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
return i in self._routing_key_index_set
+ def is_lwt(self):
+ return self._is_lwt
+
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'' %
@@ -682,6 +690,9 @@ def routing_key(self):
return self._routing_key
+ def is_lwt(self):
+ return self.prepared_statement.is_lwt()
+
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'' %
@@ -750,6 +761,7 @@ class BatchStatement(Statement):
_statements_and_parameters = None
_session = None
+ _is_lwt = False
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
consistency_level=None, serial_consistency_level=None,
@@ -834,6 +846,8 @@ def add(self, statement, parameters=None):
query_id = statement.query_id
bound_statement = statement.bind(() if parameters is None else parameters)
self._update_state(bound_statement)
+ if statement.is_lwt():
+ self._is_lwt = True
self._add_statement_and_params(True, query_id, bound_statement.values)
elif isinstance(statement, BoundStatement):
if parameters:
@@ -841,6 +855,8 @@ def add(self, statement, parameters=None):
"Parameters cannot be passed with a BoundStatement "
"to BatchStatement.add()")
self._update_state(statement)
+ if statement.is_lwt():
+ self._is_lwt = True
self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
else:
# it must be a SimpleStatement
@@ -849,6 +865,8 @@ def add(self, statement, parameters=None):
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._update_state(statement)
+ if statement.is_lwt():
+ self._is_lwt = True
self._add_statement_and_params(False, query_string, ())
return self
@@ -882,6 +900,9 @@ def _update_state(self, statement):
self._maybe_set_routing_attributes(statement)
self._update_custom_payload(statement)
+ def is_lwt(self):
+ return self._is_lwt
+
def __len__(self):
return len(self._statements_and_parameters)
diff --git a/cassandra/scylla/cloud.py b/cassandra/scylla/cloud.py
deleted file mode 100644
index c3298b199a..0000000000
--- a/cassandra/scylla/cloud.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright ScyllaDB, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-import ssl
-import tempfile
-import base64
-from ssl import SSLContext
-from contextlib import contextmanager
-from itertools import islice
-
-import yaml
-
-from cassandra.connection import SniEndPointFactory
-from cassandra.auth import AuthProvider, PlainTextAuthProvider
-
-
-@contextmanager
-def file_or_memory(path=None, data=None):
- # since we can't read keys/cert from memory yet
- # see https://github.com/python/cpython/pull/2449 which isn't accepted and PEP-543 that was withdrawn
- # so we use temporary file to load the key
- if data:
- with tempfile.NamedTemporaryFile(mode="wb") as f:
- d = base64.b64decode(data)
- f.write(d)
- if not d.endswith(b"\n"):
- f.write(b"\n")
-
- f.flush()
- yield f.name
-
- if path:
- yield path
-
-
-def nth(iterable, n, default=None):
- "Returns the nth item or a default value"
- return next(islice(iterable, n, None), default)
-
-
-class CloudConfiguration:
- endpoint_factory: SniEndPointFactory
- contact_points: list
- auth_provider: AuthProvider = None
- ssl_options: dict
- ssl_context: SSLContext
- skip_tls_verify: bool
-
- def __init__(self, configuration_file, pyopenssl=False, endpoint_factory=None):
- cloud_config = yaml.safe_load(open(configuration_file))
-
- self.current_context = cloud_config['contexts'][cloud_config['currentContext']]
- self.data_centers = cloud_config['datacenters']
- self.current_data_center = self.data_centers[self.current_context['datacenterName']]
- self.auth_info = cloud_config['authInfos'][self.current_context['authInfoName']]
- self.ssl_options = {}
- self.skip_tls_verify = self.current_data_center.get('insecureSkipTlsVerify', False)
- self.ssl_context = self.create_pyopenssl_context() if pyopenssl else self.create_ssl_context()
-
- proxy_address, port, node_domain = self.get_server(self.current_data_center)
-
- if not endpoint_factory:
- endpoint_factory = SniEndPointFactory(proxy_address, port=int(port), node_domain=node_domain)
- else:
- assert isinstance(endpoint_factory, SniEndPointFactory)
- self.endpoint_factory = endpoint_factory
-
- username, password = self.auth_info.get('username'), self.auth_info.get('password')
- if username and password:
- self.auth_provider = PlainTextAuthProvider(username, password)
-
- @property
- def contact_points(self):
- _contact_points = []
- for data_center in self.data_centers.values():
- _, _, node_domain = self.get_server(data_center)
- _contact_points.append(self.endpoint_factory.create_from_sni(node_domain))
- return _contact_points
-
- def get_server(self, data_center):
- address = data_center.get('server')
- address = address.split(":")
- port = nth(address, 1, default=9142)
- address = nth(address, 0)
- node_domain = data_center.get('nodeDomain')
- assert address and port and node_domain, "server or nodeDomain are missing"
- return address, port, node_domain
-
- def create_ssl_context(self):
- ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
- ssl_context.verify_mode = ssl.CERT_NONE if self.skip_tls_verify else ssl.CERT_REQUIRED
- for data_center in self.data_centers.values():
- with file_or_memory(path=data_center.get('certificateAuthorityPath'),
- data=data_center.get('certificateAuthorityData')) as cafile:
- ssl_context.load_verify_locations(cadata=open(cafile).read())
- with file_or_memory(path=self.auth_info.get('clientCertificatePath'),
- data=self.auth_info.get('clientCertificateData')) as certfile, \
- file_or_memory(path=self.auth_info.get('clientKeyPath'), data=self.auth_info.get('clientKeyData')) as keyfile:
- ssl_context.load_cert_chain(keyfile=keyfile,
- certfile=certfile)
-
- return ssl_context
-
- def create_pyopenssl_context(self):
- try:
- from OpenSSL import SSL
- except ImportError as e:
- raise ImportError(
- "PyOpenSSL must be installed to connect to scylla-cloud with the Eventlet or Twisted event loops") \
- .with_traceback(e.__traceback__)
- ssl_context = SSL.Context(SSL.TLS_METHOD)
- ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: True if self.skip_tls_verify else ok)
- for data_center in self.data_centers.values():
- with file_or_memory(path=data_center.get('certificateAuthorityPath'),
- data=data_center.get('certificateAuthorityData')) as cafile:
- ssl_context.load_verify_locations(cafile)
- with file_or_memory(path=self.auth_info.get('clientCertificatePath'),
- data=self.auth_info.get('clientCertificateData')) as certfile, \
- file_or_memory(path=self.auth_info.get('clientKeyPath'), data=self.auth_info.get('clientKeyData')) as keyfile:
- ssl_context.use_privatekey_file(keyfile)
- ssl_context.use_certificate_file(certfile)
-
- return ssl_context
-
- @classmethod
- def create(cls, configuration_file, pyopenssl=False, endpoint_factory=None):
- return cls(configuration_file, pyopenssl=pyopenssl, endpoint_factory=endpoint_factory)
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
index 457ee93ca4..96e61a50c2 100644
--- a/cassandra/tablets.py
+++ b/cassandra/tablets.py
@@ -1,7 +1,13 @@
+from bisect import bisect_left
+from operator import attrgetter
from threading import Lock
from typing import Optional
from uuid import UUID
+# C-accelerated attrgetter avoids per-call lambda allocation overhead
+_get_first_token = attrgetter("first_token")
+_get_last_token = attrgetter("last_token")
+
class Tablet(object):
"""
@@ -49,12 +55,15 @@ def __init__(self, tablets):
self._tablets = tablets
self._lock = Lock()
+ def table_has_tablets(self, keyspace, table) -> bool:
+ return bool(self._tablets.get((keyspace, table), []))
+
def get_tablet_for_key(self, keyspace, table, t):
tablet = self._tablets.get((keyspace, table), [])
if not tablet:
return None
- id = bisect_left(tablet, t.value, key=lambda tablet: tablet.last_token)
+ id = bisect_left(tablet, t.value, key=_get_last_token)
if id < len(tablet) and t.value > tablet[id].first_token:
return tablet[id]
return None
@@ -91,12 +100,12 @@ def add_tablet(self, keyspace, table, tablet):
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
# find first overlapping range
- start = bisect_left(tablets_for_table, tablet.first_token, key=lambda t: t.first_token)
+ start = bisect_left(tablets_for_table, tablet.first_token, key=_get_first_token)
if start > 0 and tablets_for_table[start - 1].last_token > tablet.first_token:
start = start - 1
# find last overlapping range
- end = bisect_left(tablets_for_table, tablet.last_token, key=lambda t: t.last_token)
+ end = bisect_left(tablets_for_table, tablet.last_token, key=_get_last_token)
if end < len(tablets_for_table) and tablets_for_table[end].first_token >= tablet.last_token:
end = end - 1
@@ -105,39 +114,3 @@ def add_tablet(self, keyspace, table, tablet):
tablets_for_table.insert(start, tablet)
-
-# bisect.bisect_left implementation from Python 3.11, needed untill support for
-# Python < 3.10 is dropped, it is needed to use `key` to extract last_token from
-# Tablet list - better solution performance-wise than materialize list of last_tokens
-def bisect_left(a, x, lo=0, hi=None, *, key=None):
- """Return the index where to insert item x in list a, assuming a is sorted.
-
- The return value i is such that all e in a[:i] have e < x, and all e in
- a[i:] have e >= x. So if x already appears in the list, a.insert(i, x) will
- insert just before the leftmost x already there.
-
- Optional args lo (default 0) and hi (default len(a)) bound the
- slice of a to be searched.
- """
-
- if lo < 0:
- raise ValueError('lo must be non-negative')
- if hi is None:
- hi = len(a)
- # Note, the comparison uses "<" to match the
- # __lt__() logic in list.sort() and in heapq.
- if key is None:
- while lo < hi:
- mid = (lo + hi) // 2
- if a[mid] < x:
- lo = mid + 1
- else:
- hi = mid
- return
- while lo < hi:
- mid = (lo + hi) // 2
- if key(a[mid]) < x:
- lo = mid + 1
- else:
- hi = mid
- return lo
diff --git a/cassandra/util.py b/cassandra/util.py
index 12886d05ab..593c264033 100644
--- a/cassandra/util.py
+++ b/cassandra/util.py
@@ -62,6 +62,16 @@ def datetime_from_timestamp(timestamp):
return dt
+def datetime_from_ms_timestamp(timestamp_ms):
+ """
+ Creates a timezone-agnostic datetime from a timestamp in milliseconds,
+ using integer arithmetic to preserve precision for large values.
+
+ :param timestamp_ms: a unix timestamp, in milliseconds (integer)
+ """
+ return DATETIME_EPOC + datetime.timedelta(milliseconds=timestamp_ms)
+
+
def utc_datetime_from_ms_timestamp(timestamp):
"""
Creates a UTC datetime from a timestamp in milliseconds. See
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 0000000000..733bc65597
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1,2 @@
+# Track uv.lock for reproducible docs builds
+!uv.lock
diff --git a/docs/Makefile b/docs/Makefile
index b1c54c8199..09512be470 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -1,8 +1,9 @@
# Global variables
# You can set these variables from the command line.
-POETRY = poetry
+SHELL = bash
+UV = uv
SPHINXOPTS = -j auto
-SPHINXBUILD = $(POETRY) run sphinx-build
+SPHINXBUILD = $(UV) run --frozen sphinx-build
PAPER =
BUILDDIR = _build
SOURCEDIR = .
@@ -17,18 +18,13 @@ TESTSPHINXOPTS = $(ALLSPHINXOPTS) -W --keep-going
all: dirhtml
# Setup commands
-.PHONY: setupenv
-setupenv:
- pip install -q poetry
- sudo apt-get install gcc python3-dev libev4 libev-dev
-
-.PHONY: setup
-setup:
- $(POETRY) install
+#.PHONY: setupenv
+#setupenv:
+# uv pip install -r <(uv pip compile pyproject.toml)
.PHONY: update
update:
- $(POETRY) update
+ $(UV) lock --upgrade
# Clean commands
.PHONY: pristine
@@ -41,58 +37,58 @@ clean:
# Generate output commands
.PHONY: dirhtml
-dirhtml: setup
+dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
.PHONY: singlehtml
-singlehtml: setup
+singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
.PHONY: epub
-epub: setup
+epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
.PHONY: epub3
-epub3: setup
+epub3:
$(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3
@echo
@echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3."
.PHONY: multiversion
-multiversion: setup
- $(POETRY) run sphinx-multiversion $(SOURCEDIR) $(BUILDDIR)/dirhtml
+multiversion:
+ $(UV) run --frozen sphinx-multiversion $(SOURCEDIR) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
.PHONY: redirects
-redirects: setup
- $(POETRY) run redirects-cli fromfile --yaml-file _utils/redirects.yaml --output-dir $(BUILDDIR)/dirhtml
+redirects:
+ $(UV) run --frozen redirects-cli fromfile --yaml-file _utils/redirects.yaml --output-dir $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
# Preview commands
.PHONY: preview
-preview: setup
- $(POETRY) run sphinx-autobuild -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml --port 5500
+preview:
+ $(UV) run --frozen sphinx-autobuild -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml --port 5500
.PHONY: multiversionpreview
multiversionpreview: multiversion
- $(POETRY) run python -m http.server 5500 --directory $(BUILDDIR)/dirhtml
+ $(UV) run --frozen python -m http.server 5500 --directory $(BUILDDIR)/dirhtml
# Test commands
.PHONY: test
-test: setup
+test:
$(SPHINXBUILD) -b dirhtml $(TESTSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
.PHONY: linkcheck
-linkcheck: setup
+linkcheck:
$(SPHINXBUILD) -b linkcheck $(SOURCEDIR) $(BUILDDIR)/linkcheck
diff --git a/docs/api/cassandra.rst b/docs/api/cassandra.rst
index d46aae56cb..53789b9582 100644
--- a/docs/api/cassandra.rst
+++ b/docs/api/cassandra.rst
@@ -1,5 +1,7 @@
-:mod:`cassandra` - Exceptions and Enums
-=======================================
+cassandra
+=========
+
+Exceptions and Enums
.. module:: cassandra
diff --git a/docs/api/cassandra/auth.rst b/docs/api/cassandra/auth.rst
index 58c964cf89..91bb4e9139 100644
--- a/docs/api/cassandra/auth.rst
+++ b/docs/api/cassandra/auth.rst
@@ -1,5 +1,7 @@
-``cassandra.auth`` - Authentication
-===================================
+cassandra.auth
+==============
+
+Authentication
.. module:: cassandra.auth
diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst
index a9a9d378a4..44b7b63f67 100644
--- a/docs/api/cassandra/cluster.rst
+++ b/docs/api/cassandra/cluster.rst
@@ -1,5 +1,7 @@
-``cassandra.cluster`` - Clusters and Sessions
-=============================================
+cassandra.cluster
+=================
+
+Clusters and Sessions
.. module:: cassandra.cluster
@@ -46,6 +48,8 @@
.. autoattribute:: control_connection_timeout
+ .. autoattribute:: allow_control_connection_query_fallback
+
.. autoattribute:: idle_heartbeat_interval
.. autoattribute:: idle_heartbeat_timeout
@@ -86,22 +90,6 @@
.. automethod:: add_execution_profile
- .. automethod:: set_max_requests_per_connection
-
- .. automethod:: get_max_requests_per_connection
-
- .. automethod:: set_min_requests_per_connection
-
- .. automethod:: get_min_requests_per_connection
-
- .. automethod:: get_core_connections_per_host
-
- .. automethod:: set_core_connections_per_host
-
- .. automethod:: get_max_connections_per_host
-
- .. automethod:: set_max_connections_per_host
-
.. automethod:: get_control_connection_host
.. automethod:: refresh_schema_metadata
@@ -120,6 +108,9 @@
.. automethod:: set_meta_refresh_enabled
+.. autoclass:: ControlConnectionQueryFallback
+ :members:
+
.. autoclass:: ExecutionProfile (load_balancing_policy=